Files
nebula/handshake_manager_test.go
Jay Wren 7794e93762 Address PR feedback: remove outbound rate limit, improve config docs
Remove rate limiting from StartHandshake (outbound) since DoS
protection only needs to limit inbound handshakes. This also avoids
returning nil from StartHandshake which historically always returned
non-nil. Update config comment to note openssl speed is single-core
and suggest scaling by routines.

Co-Authored-By: Claude <svc-devxp-claude@slack-corp.com>
2026-04-10 14:36:32 -04:00

165 lines
4.6 KiB
Go

package nebula
import (
"net/netip"
"testing"
"time"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert"
)
// Test_NewHandshakeManagerVpnIp checks that StartHandshake creates a single
// pending HostInfo per vpn address, that the pending entry does not appear in
// the main hostmap, and that the outbound handshake timer removes it once it
// has ticked past the retry window.
func Test_NewHandshakeManagerVpnIp(t *testing.T) {
	l := test.NewLogger()
	localrange := netip.MustParsePrefix("10.1.1.1/24")
	ip := netip.MustParseAddr("172.1.1.2")
	preferredRanges := []netip.Prefix{localrange}
	mainHM := newHostMap(l)
	mainHM.preferredRanges.Store(&preferredRanges)

	lh := newTestLighthouse()

	cs := &CertState{
		initiatingVersion: cert.Version1,
		privateKey:        []byte{},
		v1Cert:            &dummyCert{version: cert.Version1},
		v1HandshakeBytes:  []byte{},
	}

	hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
	hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
	hm.f.pki.cs.Store(cs)

	now := time.Now()
	hm.NextOutboundHandshakeTimerTick(now)

	hostinfo := hm.StartHandshake(ip, nil)
	again := hm.StartHandshake(ip, nil)

	// Fail the test cleanly instead of panicking on the dereference below.
	if !assert.NotNil(t, hostinfo) {
		return
	}

	// A second StartHandshake for the same address must return the same pending entry.
	assert.Same(t, hostinfo, again)

	hostinfo.remotes = NewRemoteList([]netip.Addr{}, nil)

	// Adding something to pending should not affect the main hostmap
	assert.Empty(t, mainHM.Hosts)

	// Confirm they are in the pending index list
	assert.Contains(t, hm.vpnIps, ip)

	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right.
	// The counter is named `tick` so it does not shadow `hostinfo`.
	for tick := 1; tick <= DefaultHandshakeRetries+1; tick++ {
		now = now.Add(time.Duration(tick) * DefaultHandshakeTryInterval)
		hm.NextOutboundHandshakeTimerTick(now)
	}

	// Confirm they are still in the pending index list
	assert.Contains(t, hm.vpnIps, ip)

	// Tick 1 more time, a minute will certainly flush it out
	hm.NextOutboundHandshakeTimerTick(now.Add(time.Minute))

	// Confirm they have been removed
	assert.NotContains(t, hm.vpnIps, ip)
}
// Test_HandshakeManagerRateLimit configures maxHandshakeRate = 2 and checks
// handshakeRateAllow at one instant: the first two calls succeed and the third
// is refused. A call one second later succeeds again.
func Test_HandshakeManagerRateLimit(t *testing.T) {
	logger := test.NewLogger()
	ranges := []netip.Prefix{netip.MustParsePrefix("10.1.1.1/24")}
	hostMap := newHostMap(logger)
	hostMap.preferredRanges.Store(&ranges)

	cfg := defaultHandshakeConfig
	cfg.maxHandshakeRate = 2
	hm := NewHandshakeManager(logger, hostMap, newTestLighthouse(), &udp.NoopConn{}, cfg)
	hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: logger}

	// allowAt queries the limiter for instant `at` while holding hm's lock.
	allowAt := func(at time.Time) bool {
		hm.Lock()
		defer hm.Unlock()
		return hm.handshakeRateAllow(at)
	}

	start := time.Now()

	// Burst: up to maxHandshakeRate requests at the same instant.
	assert.True(t, allowAt(start), "first handshake should be allowed")
	assert.True(t, allowAt(start), "second handshake should be allowed")
	assert.False(t, allowAt(start), "third handshake should be rate limited")

	// One second on, at least one token should be available again.
	assert.True(t, allowAt(start.Add(time.Second)), "handshake should be allowed after token refill")
}
// Test_HandshakeManagerRateLimitUnlimited checks that the default handshake
// config (maxHandshakeRate == 0, i.e. unlimited) never refuses a handshake.
// It tests the limiter directly via handshakeRateAllow. It also checks that
// StartHandshake hands back a HostInfo for a run of distinct IPv4 vpn addresses.
func Test_HandshakeManagerRateLimitUnlimited(t *testing.T) {
	l := test.NewLogger()
	localrange := netip.MustParsePrefix("10.1.1.1/24")
	preferredRanges := []netip.Prefix{localrange}
	mainHM := newHostMap(l)
	mainHM.preferredRanges.Store(&preferredRanges)

	lh := newTestLighthouse()

	cs := &CertState{
		initiatingVersion: cert.Version1,
		privateKey:        []byte{},
		v1Cert:            &dummyCert{version: cert.Version1},
		v1HandshakeBytes:  []byte{},
	}

	// Default config has maxHandshakeRate=0 (unlimited)
	hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
	hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
	hm.f.pki.cs.Store(cs)

	// StartHandshake does not consult the limiter, so the limiter itself has to
	// be probed to test the unlimited setting. Many requests at one instant
	// must all be admitted.
	now := time.Now()
	hm.Lock()
	for i := 0; i < 100; i++ {
		assert.True(t, hm.handshakeRateAllow(now), "handshake %d should be allowed with unlimited rate", i)
	}
	hm.Unlock()

	// Limited to 10 due to test lighthouse query channel buffer
	// Build real IPv4 addresses. As16()+AddrFrom16 would yield IPv4-mapped
	// IPv6 addresses (::ffff:172.1.1.x) instead.
	base := netip.MustParseAddr("172.1.1.1").As4()
	for i := 0; i < 10; i++ {
		octets := base // arrays are values, so this is a fresh copy
		octets[3] = byte(i + 1)
		addr := netip.AddrFrom4(octets)
		h := hm.StartHandshake(addr, nil)
		assert.NotNil(t, h, "handshake %d should be allowed with unlimited rate", i)
	}
}
// testCountTimerWheelEntries returns the total number of entries held in tw,
// found by following the Head/Next linked list of every wheel bucket.
// It reads tw.t directly without taking any lock.
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
	for idx := range tw.t.wheel {
		for node := tw.t.wheel[idx].Head; node != nil; node = node.Next {
			c++
		}
	}
	return c
}
// mockEncWriter is a no-op message writer for tests: every send method
// discards its arguments, GetHostInfo always returns nil, and GetCertState
// returns a fresh CertState with initiatingVersion set to cert.Version2.
type mockEncWriter struct{}

// SendMessageToVpnAddr discards the message.
func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) {
}

// SendVia discards the message.
func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) {
}

// SendMessageToHostInfo discards the message.
func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) {
}

// Handshake does nothing.
func (mw *mockEncWriter) Handshake(_ netip.Addr) {}

// GetHostInfo always reports that no HostInfo is known.
func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
	return nil
}

// GetCertState returns a new CertState whose only set field is
// initiatingVersion (cert.Version2).
func (mw *mockEncWriter) GetCertState() *CertState {
	return &CertState{initiatingVersion: cert.Version2}
}