Add handshakes.max_rate to limit new handshakes per second

Nebula is vulnerable to DoS via handshake flooding since each incoming
  handshake performs expensive DH operations. This adds a token bucket
  rate limiter to the handshake manager that caps both inbound and
  outbound new handshakes per second. When the limit is reached, new
  handshakes are silently dropped and counted via the
  handshake_manager.rate_limited metric.

  Configured via handshakes.max_rate (default 0 = unlimited).

  Co-Authored-By: Claude <svc-devxp-claude@slack-corp.com>
This commit is contained in:
Jay Wren
2026-03-30 14:48:04 -04:00
parent 91d1f4675a
commit 3df60ae195
5 changed files with 160 additions and 12 deletions

View File

@@ -342,6 +342,14 @@ logging:
# after receiving the response for lighthouse queries # after receiving the response for lighthouse queries
#trigger_buffer: 64 #trigger_buffer: 64
# max_rate limits the number of new handshakes per second. Both incoming and outgoing new
# handshakes count against this limit. Once the limit is reached, new handshakes are dropped
# until the next second. A value of 0 means unlimited (default).
# This is useful for preventing DoS attacks that attempt to exhaust CPU with handshake crypto.
# Running `openssl speed ecdhp256` on your hardware can be a good rule of thumb for choosing
# a max, as each handshake performs similar DH operations.
#max_rate: 0
# Tunnel manager settings # Tunnel manager settings
#tunnels: #tunnels:
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has

View File

@@ -23,22 +23,25 @@ const (
DefaultHandshakeRetries = 10 DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64 DefaultHandshakeTriggerBuffer = 64
DefaultUseRelays = true DefaultUseRelays = true
DefaultMaxHandshakeRate = 0 // 0 means unlimited
) )
var ( var (
defaultHandshakeConfig = HandshakeConfig{ defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval, tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries, retries: DefaultHandshakeRetries,
triggerBuffer: DefaultHandshakeTriggerBuffer, triggerBuffer: DefaultHandshakeTriggerBuffer,
useRelays: DefaultUseRelays, useRelays: DefaultUseRelays,
maxHandshakeRate: DefaultMaxHandshakeRate,
} }
) )
type HandshakeConfig struct { type HandshakeConfig struct {
tryInterval time.Duration tryInterval time.Duration
retries int64 retries int64
triggerBuffer int triggerBuffer int
useRelays bool useRelays bool
maxHandshakeRate int
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
} }
@@ -58,9 +61,15 @@ type HandshakeManager struct {
messageMetrics *MessageMetrics messageMetrics *MessageMetrics
metricInitiated metrics.Counter metricInitiated metrics.Counter
metricTimedOut metrics.Counter metricTimedOut metrics.Counter
metricRateLimited metrics.Counter
f *Interface f *Interface
l *logrus.Logger l *logrus.Logger
// Rate limiting for new handshakes (token bucket)
rateBucket int // tokens currently available
rateMax int // max tokens (== max handshakes per second), 0 means unlimited
rateLastTick time.Time
// can be used to trigger outbound handshake for the given vpnIp // can be used to trigger outbound handshake for the given vpnIp
trigger chan netip.Addr trigger chan netip.Addr
} }
@@ -116,10 +125,41 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
messageMetrics: config.messageMetrics, messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
metricRateLimited: metrics.GetOrRegisterCounter("handshake_manager.rate_limited", nil),
rateBucket: config.maxHandshakeRate,
rateMax: config.maxHandshakeRate,
rateLastTick: time.Now(),
l: l, l: l,
} }
} }
// handshakeRateAllow checks the token bucket rate limiter and returns true if a
// new handshake is allowed. Must be called with hm.Lock held.
func (hm *HandshakeManager) handshakeRateAllow(now time.Time) bool {
if hm.rateMax == 0 {
return true
}
// Refill tokens based on elapsed time
elapsed := now.Sub(hm.rateLastTick)
if elapsed >= time.Second {
// Add tokens for full seconds elapsed
tokens := int(elapsed/time.Second) * hm.rateMax
hm.rateBucket += tokens
if hm.rateBucket > hm.rateMax {
hm.rateBucket = hm.rateMax
}
hm.rateLastTick = now
}
if hm.rateBucket > 0 {
hm.rateBucket--
return true
}
return false
}
func (hm *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(hm.config.tryInterval) clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop() defer clockSource.Stop()
@@ -149,6 +189,15 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
case header.HandshakeIXPSK0: case header.HandshakeIXPSK0:
switch h.MessageCounter { switch h.MessageCounter {
case 1: case 1:
// Check rate limit for new incoming handshakes
hm.Lock()
allowed := hm.handshakeRateAllow(time.Now())
hm.Unlock()
if !allowed {
hm.metricRateLimited.Inc(1)
hm.l.WithField("from", via).Debug("Handshake rate limit reached, dropping incoming handshake")
return
}
ixHandshakeStage1(hm.f, via, packet, h) ixHandshakeStage1(hm.f, via, packet, h)
case 2: case 2:
@@ -447,6 +496,14 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
return hh.hostinfo return hh.hostinfo
} }
// Check rate limit for new outbound handshakes
if !hm.handshakeRateAllow(time.Now()) {
hm.metricRateLimited.Inc(1)
hm.l.WithField("vpnAddr", vpnAddr).Debug("Handshake rate limit reached, dropping outbound handshake")
hm.Unlock()
return nil
}
hostinfo := &HostInfo{ hostinfo := &HostInfo{
vpnAddrs: []netip.Addr{vpnAddr}, vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0), HandshakePacket: make(map[uint8][]byte, 0),

View File

@@ -65,6 +65,85 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
assert.NotContains(t, blah.vpnIps, ip) assert.NotContains(t, blah.vpnIps, ip)
} }
func Test_HandshakeManagerRateLimit(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse()
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
config := defaultHandshakeConfig
config.maxHandshakeRate = 2
hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, config)
hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
hm.f.pki.cs.Store(cs)
// Should allow up to maxHandshakeRate handshakes
ip1 := netip.MustParseAddr("172.1.1.1")
ip2 := netip.MustParseAddr("172.1.1.2")
ip3 := netip.MustParseAddr("172.1.1.3")
h1 := hm.StartHandshake(ip1, nil)
assert.NotNil(t, h1, "first handshake should be allowed")
h2 := hm.StartHandshake(ip2, nil)
assert.NotNil(t, h2, "second handshake should be allowed")
// Third should be rate limited
h3 := hm.StartHandshake(ip3, nil)
assert.Nil(t, h3, "third handshake should be rate limited")
// After advancing time by 1 second, tokens should refill
hm.Lock()
hm.rateLastTick = hm.rateLastTick.Add(-time.Second)
hm.Unlock()
h3 = hm.StartHandshake(ip3, nil)
assert.NotNil(t, h3, "handshake should be allowed after token refill")
}
func Test_HandshakeManagerRateLimitUnlimited(t *testing.T) {
l := test.NewLogger()
localrange := netip.MustParsePrefix("10.1.1.1/24")
preferredRanges := []netip.Prefix{localrange}
mainHM := newHostMap(l)
mainHM.preferredRanges.Store(&preferredRanges)
lh := newTestLighthouse()
cs := &CertState{
initiatingVersion: cert.Version1,
privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{},
}
// Default config has maxHandshakeRate=0 (unlimited)
hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
hm.f.pki.cs.Store(cs)
// Should allow many handshakes with no limit
// Limited to 10 due to test lighthouse query channel buffer
for i := 0; i < 10; i++ {
ip := netip.MustParseAddr("172.1.1.1").As16()
ip[15] = byte(i + 1)
addr := netip.AddrFrom16(ip)
h := hm.StartHandshake(addr, nil)
assert.NotNil(t, h, "handshake %d should be allowed with unlimited rate", i)
}
}
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
for _, i := range tw.t.wheel { for _, i := range tw.t.wheel {
n := i.Head n := i.Head

View File

@@ -204,10 +204,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
handshakeConfig := HandshakeConfig{ handshakeConfig := HandshakeConfig{
tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
useRelays: useRelays, useRelays: useRelays,
maxHandshakeRate: c.GetInt("handshakes.max_rate", DefaultMaxHandshakeRate),
messageMetrics: messageMetrics, messageMetrics: messageMetrics,
} }

3
ssh.go
View File

@@ -632,6 +632,9 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
} }
hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
if hostInfo == nil {
return w.WriteLine("Handshake rate limit reached")
}
if addr.IsValid() { if addr.IsValid() {
hostInfo.SetRemote(addr) hostInfo.SetRemote(addr)
} }