From 3df60ae195fff53fcc7f7bb2ca5492508890ce50 Mon Sep 17 00:00:00 2001 From: Jay Wren Date: Mon, 30 Mar 2026 14:48:04 -0400 Subject: [PATCH] Add handshakes.max_rate to limit new handshakes per second Nebula is vulnerable to DoS via handshake flooding since each incoming handshake performs expensive DH operations. This adds a token bucket rate limiter to the handshake manager that caps both inbound and outbound new handshakes per second. When the limit is reached, new handshakes are silently dropped and counted via the handshake_manager.rate_limited metric. Configured via handshakes.max_rate (default 0 = unlimited). Co-Authored-By: Claude --- examples/config.yml | 8 ++++ handshake_manager.go | 73 ++++++++++++++++++++++++++++++++---- handshake_manager_test.go | 79 +++++++++++++++++++++++++++++++++++++++ main.go | 9 +++-- ssh.go | 3 ++ 5 files changed, 160 insertions(+), 12 deletions(-) diff --git a/examples/config.yml b/examples/config.yml index 1f9dc2a4..34ed7d02 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -342,6 +342,14 @@ logging: # after receiving the response for lighthouse queries #trigger_buffer: 64 + # max_rate limits the number of new handshakes per second. Both incoming and outgoing new + # handshakes count against this limit. Once the limit is reached, new handshakes are dropped + # until the next second. A value of 0 means unlimited (default). + # This is useful for preventing DoS attacks that attempt to exhaust CPU with handshake crypto. + # Running `openssl speed ecdhp256` on your hardware can be a good rule of thumb for choosing + # a max, as each handshake performs similar DH operations. + #max_rate: 0 + # Tunnel manager settings #tunnels: # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has diff --git a/handshake_manager.go b/handshake_manager.go index 25a59b6d..f1ddc41f 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -23,22 +23,25 @@ const ( DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 DefaultUseRelays = true + DefaultMaxHandshakeRate = 0 // 0 means unlimited ) var ( defaultHandshakeConfig = HandshakeConfig{ - tryInterval: DefaultHandshakeTryInterval, - retries: DefaultHandshakeRetries, - triggerBuffer: DefaultHandshakeTriggerBuffer, - useRelays: DefaultUseRelays, + tryInterval: DefaultHandshakeTryInterval, + retries: DefaultHandshakeRetries, + triggerBuffer: DefaultHandshakeTriggerBuffer, + useRelays: DefaultUseRelays, + maxHandshakeRate: DefaultMaxHandshakeRate, } ) type HandshakeConfig struct { - tryInterval time.Duration - retries int64 - triggerBuffer int - useRelays bool + tryInterval time.Duration + retries int64 + triggerBuffer int + useRelays bool + maxHandshakeRate int messageMetrics *MessageMetrics } @@ -58,9 +61,15 @@ type HandshakeManager struct { messageMetrics *MessageMetrics metricInitiated metrics.Counter metricTimedOut metrics.Counter + metricRateLimited metrics.Counter f *Interface l *logrus.Logger + // Rate limiting for new handshakes (token bucket) + rateBucket int // tokens currently available + rateMax int // max tokens (== max handshakes per second), 0 means unlimited + rateLastTick time.Time + // can be used to trigger outbound handshake for the given vpnIp trigger chan netip.Addr } @@ -116,10 +125,41 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig messageMetrics: config.messageMetrics, metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil), metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil), + metricRateLimited: metrics.GetOrRegisterCounter("handshake_manager.rate_limited", nil), + rateBucket: config.maxHandshakeRate, + rateMax: config.maxHandshakeRate, + rateLastTick: time.Now(), l: l, } } +// handshakeRateAllow checks the token bucket rate limiter and returns true if a +// new handshake is allowed. Must be called with hm.Lock held. +func (hm *HandshakeManager) handshakeRateAllow(now time.Time) bool { + if hm.rateMax == 0 { + return true + } + + // Refill tokens based on elapsed time + elapsed := now.Sub(hm.rateLastTick) + if elapsed >= time.Second { + // Add tokens for full seconds elapsed + tokens := int(elapsed/time.Second) * hm.rateMax + hm.rateBucket += tokens + if hm.rateBucket > hm.rateMax { + hm.rateBucket = hm.rateMax + } + hm.rateLastTick = now + } + + if hm.rateBucket > 0 { + hm.rateBucket-- + return true + } + + return false +} + func (hm *HandshakeManager) Run(ctx context.Context) { clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() @@ -149,6 +189,15 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head case header.HandshakeIXPSK0: switch h.MessageCounter { case 1: + // Check rate limit for new incoming handshakes + hm.Lock() + allowed := hm.handshakeRateAllow(time.Now()) + hm.Unlock() + if !allowed { + hm.metricRateLimited.Inc(1) + hm.l.WithField("from", via).Debug("Handshake rate limit reached, dropping incoming handshake") + return + } ixHandshakeStage1(hm.f, via, packet, h) case 2: @@ -447,6 +496,14 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han return hh.hostinfo } + // Check rate limit for new outbound handshakes + if !hm.handshakeRateAllow(time.Now()) { + hm.metricRateLimited.Inc(1) + hm.l.WithField("vpnAddr", vpnAddr).Debug("Handshake rate limit reached, dropping outbound handshake") + hm.Unlock() + return nil + } + hostinfo := &HostInfo{ vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 2e6d34b5..a735b582 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -65,6 +65,85 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { assert.NotContains(t, blah.vpnIps, ip) } +func Test_HandshakeManagerRateLimit(t *testing.T) { + l := test.NewLogger() + localrange := netip.MustParsePrefix("10.1.1.1/24") + preferredRanges := []netip.Prefix{localrange} + mainHM := newHostMap(l) + mainHM.preferredRanges.Store(&preferredRanges) + + lh := newTestLighthouse() + + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, + } + + config := defaultHandshakeConfig + config.maxHandshakeRate = 2 + + hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, config) + hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l} + hm.f.pki.cs.Store(cs) + + // Should allow up to maxHandshakeRate handshakes + ip1 := netip.MustParseAddr("172.1.1.1") + ip2 := netip.MustParseAddr("172.1.1.2") + ip3 := netip.MustParseAddr("172.1.1.3") + + h1 := hm.StartHandshake(ip1, nil) + assert.NotNil(t, h1, "first handshake should be allowed") + + h2 := hm.StartHandshake(ip2, nil) + assert.NotNil(t, h2, "second handshake should be allowed") + + // Third should be rate limited + h3 := hm.StartHandshake(ip3, nil) + assert.Nil(t, h3, "third handshake should be rate limited") + + // After advancing time by 1 second, tokens should refill + hm.Lock() + hm.rateLastTick = hm.rateLastTick.Add(-time.Second) + hm.Unlock() + + h3 = hm.StartHandshake(ip3, nil) + assert.NotNil(t, h3, "handshake should be allowed after token refill") +} + +func Test_HandshakeManagerRateLimitUnlimited(t *testing.T) { + l := test.NewLogger() + localrange := netip.MustParsePrefix("10.1.1.1/24") + preferredRanges := []netip.Prefix{localrange} + mainHM := newHostMap(l) + mainHM.preferredRanges.Store(&preferredRanges) + + lh := newTestLighthouse() + + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, + } + + // Default config has maxHandshakeRate=0 (unlimited) + hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) + hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l} + hm.f.pki.cs.Store(cs) + + // Should allow many handshakes with no limit + // Limited to 10 due to test lighthouse query channel buffer + for i := 0; i < 10; i++ { + ip := netip.MustParseAddr("172.1.1.1").As16() + ip[15] = byte(i + 1) + addr := netip.AddrFrom16(ip) + h := hm.StartHandshake(addr, nil) + assert.NotNil(t, h, "handshake %d should be allowed with unlimited rate", i) + } +} + func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { for _, i := range tw.t.wheel { n := i.Head diff --git a/main.go b/main.go index 74979417..f39593b1 100644 --- a/main.go +++ b/main.go @@ -204,10 +204,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) handshakeConfig := HandshakeConfig{ - tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), - triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), - useRelays: useRelays, + tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), + triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), + useRelays: useRelays, + maxHandshakeRate: c.GetInt("handshakes.max_rate", DefaultMaxHandshakeRate), messageMetrics: messageMetrics, } diff --git a/ssh.go b/ssh.go index 0a9adb51..9b7b052c 100644 --- a/ssh.go +++ b/ssh.go @@ -632,6 +632,9 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e } hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) + if hostInfo == nil { + return w.WriteLine("Handshake rate limit reached") + } if addr.IsValid() { hostInfo.SetRemote(addr) }