diff --git a/examples/config.yml b/examples/config.yml
index 1f9dc2a4..34ed7d02 100644
--- a/examples/config.yml
+++ b/examples/config.yml
@@ -342,6 +342,14 @@ logging:
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
 
+  # max_rate limits the number of new handshakes per second. Both incoming and outgoing new
+  # handshakes count against this limit. Once the limit is reached, new handshakes are dropped
+  # until the next second. A value of 0 (the default) or less means unlimited.
+  # This is useful for preventing DoS attacks that attempt to exhaust CPU with handshake crypto.
+  # Running `openssl speed ecdhp256` on your hardware can be a good rule of thumb for choosing
+  # a max, as each handshake performs similar DH operations.
+  #max_rate: 0
+
 # Tunnel manager settings
 #tunnels:
   # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
diff --git a/handshake_manager.go b/handshake_manager.go
index 25a59b6d..f1ddc41f 100644
--- a/handshake_manager.go
+++ b/handshake_manager.go
@@ -23,22 +23,25 @@ const (
 	DefaultHandshakeRetries       = 10
 	DefaultHandshakeTriggerBuffer = 64
 	DefaultUseRelays              = true
+	DefaultMaxHandshakeRate       = 0 // 0 means unlimited
 )
 
 var (
 	defaultHandshakeConfig = HandshakeConfig{
-		tryInterval:   DefaultHandshakeTryInterval,
-		retries:       DefaultHandshakeRetries,
-		triggerBuffer: DefaultHandshakeTriggerBuffer,
-		useRelays:     DefaultUseRelays,
+		tryInterval:      DefaultHandshakeTryInterval,
+		retries:          DefaultHandshakeRetries,
+		triggerBuffer:    DefaultHandshakeTriggerBuffer,
+		useRelays:        DefaultUseRelays,
+		maxHandshakeRate: DefaultMaxHandshakeRate,
 	}
 )
 
 type HandshakeConfig struct {
-	tryInterval   time.Duration
-	retries       int64
-	triggerBuffer int
-	useRelays     bool
+	tryInterval      time.Duration
+	retries          int64
+	triggerBuffer    int
+	useRelays        bool
+	maxHandshakeRate int
 
 	messageMetrics *MessageMetrics
 }
@@ -58,9 +61,15 @@ type HandshakeManager struct {
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
+	metricRateLimited      metrics.Counter
 	f                      *Interface
 	l                      *logrus.Logger
 
+	// Rate limiting for new handshakes (token bucket)
+	rateBucket   int // tokens currently available
+	rateMax      int // max tokens (== max handshakes per second), <= 0 means unlimited
+	rateLastTick time.Time
+
 	// can be used to trigger outbound handshake for the given vpnIp
 	trigger chan netip.Addr
 }
@@ -116,10 +125,38 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
+		metricRateLimited:      metrics.GetOrRegisterCounter("handshake_manager.rate_limited", nil),
+		rateBucket:             config.maxHandshakeRate,
+		rateMax:                config.maxHandshakeRate,
+		rateLastTick:           time.Now(),
 		l:                      l,
 	}
 }
 
+// handshakeRateAllow checks the token bucket rate limiter and returns true if a
+// new handshake is allowed. A rateMax of 0 or less disables the limiter.
+// Must be called with hm.Lock held.
+func (hm *HandshakeManager) handshakeRateAllow(now time.Time) bool {
+	if hm.rateMax <= 0 {
+		return true
+	}
+
+	// The bucket holds at most one second's worth of tokens, so once a full
+	// second has elapsed since the last refill it is simply topped up to rateMax.
+	// Assigning, rather than adding elapsed-seconds*rateMax, cannot overflow int.
+	if now.Sub(hm.rateLastTick) >= time.Second {
+		hm.rateBucket = hm.rateMax
+		hm.rateLastTick = now
+	}
+
+	if hm.rateBucket > 0 {
+		hm.rateBucket--
+		return true
+	}
+
+	return false
+}
+
 func (hm *HandshakeManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(hm.config.tryInterval)
 	defer clockSource.Stop()
@@ -149,6 +186,15 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
 	case header.HandshakeIXPSK0:
 		switch h.MessageCounter {
 		case 1:
+			// Check rate limit for new incoming handshakes
+			hm.Lock()
+			allowed := hm.handshakeRateAllow(time.Now())
+			hm.Unlock()
+			if !allowed {
+				hm.metricRateLimited.Inc(1)
+				hm.l.WithField("from", via).Debug("Handshake rate limit reached, dropping incoming handshake")
+				return
+			}
 			ixHandshakeStage1(hm.f, via, packet, h)
 
 		case 2:
@@ -447,6 +493,14 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
 		return hh.hostinfo
 	}
 
+	// Check rate limit for new outbound handshakes. A nil return means the handshake was dropped.
+	if !hm.handshakeRateAllow(time.Now()) {
+		hm.metricRateLimited.Inc(1)
+		hm.l.WithField("vpnAddr", vpnAddr).Debug("Handshake rate limit reached, dropping outbound handshake")
+		hm.Unlock()
+		return nil
+	}
+
 	hostinfo := &HostInfo{
 		vpnAddrs:        []netip.Addr{vpnAddr},
 		HandshakePacket: make(map[uint8][]byte, 0),
diff --git a/handshake_manager_test.go b/handshake_manager_test.go
index 2e6d34b5..a735b582 100644
--- a/handshake_manager_test.go
+++ b/handshake_manager_test.go
@@ -65,6 +65,93 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.NotContains(t, blah.vpnIps, ip)
 }
 
+func Test_HandshakeManagerRateLimit(t *testing.T) {
+	l := test.NewLogger()
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	preferredRanges := []netip.Prefix{localrange}
+	mainHM := newHostMap(l)
+	mainHM.preferredRanges.Store(&preferredRanges)
+
+	lh := newTestLighthouse()
+
+	cs := &CertState{
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
+	}
+
+	config := defaultHandshakeConfig
+	config.maxHandshakeRate = 2
+
+	hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, config)
+	hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
+	hm.f.pki.cs.Store(cs)
+
+	// Should allow up to maxHandshakeRate handshakes
+	ip1 := netip.MustParseAddr("172.1.1.1")
+	ip2 := netip.MustParseAddr("172.1.1.2")
+	ip3 := netip.MustParseAddr("172.1.1.3")
+
+	h1 := hm.StartHandshake(ip1, nil)
+	assert.NotNil(t, h1, "first handshake should be allowed")
+
+	h2 := hm.StartHandshake(ip2, nil)
+	assert.NotNil(t, h2, "second handshake should be allowed")
+
+	// Third should be rate limited
+	h3 := hm.StartHandshake(ip3, nil)
+	assert.Nil(t, h3, "third handshake should be rate limited")
+
+	// After advancing time by 1 second, tokens should refill
+	hm.Lock()
+	hm.rateLastTick = hm.rateLastTick.Add(-time.Second)
+	hm.Unlock()
+
+	h3 = hm.StartHandshake(ip3, nil)
+	assert.NotNil(t, h3, "handshake should be allowed after token refill")
+}
+
+func Test_HandshakeManagerRateLimitUnlimited(t *testing.T) {
+	l := test.NewLogger()
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	preferredRanges := []netip.Prefix{localrange}
+	mainHM := newHostMap(l)
+	mainHM.preferredRanges.Store(&preferredRanges)
+
+	lh := newTestLighthouse()
+
+	cs := &CertState{
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
+	}
+
+	// Default config has maxHandshakeRate=0 (unlimited)
+	hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
+	hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
+	hm.f.pki.cs.Store(cs)
+
+	// Should allow many handshakes with no limit
+	// Limited to 10 due to test lighthouse query channel buffer
+	for i := 0; i < 10; i++ {
+		ip := netip.MustParseAddr("172.1.1.1").As16()
+		ip[15] = byte(i + 1)
+		addr := netip.AddrFrom16(ip)
+		h := hm.StartHandshake(addr, nil)
+		assert.NotNil(t, h, "handshake %d should be allowed with unlimited rate", i)
+	}
+}
+
+func Test_HandshakeManagerRateLimitNegative(t *testing.T) {
+	// A negative max rate must behave like 0 (unlimited) instead of dropping every handshake
+	hm := &HandshakeManager{rateBucket: -1, rateMax: -1, rateLastTick: time.Now()}
+	for i := 0; i < 10; i++ {
+		assert.True(t, hm.handshakeRateAllow(time.Now()), "handshake %d should be allowed with a negative rate", i)
+	}
+}
+
 func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
 	for _, i := range tw.t.wheel {
 		n := i.Head
diff --git a/main.go b/main.go
index 74979417..f39593b1 100644
--- a/main.go
+++ b/main.go
@@ -204,10 +204,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false)
 
 	handshakeConfig := HandshakeConfig{
-		tryInterval:   c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
-		retries:       int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
-		triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
-		useRelays:     useRelays,
+		tryInterval:      c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
+		retries:          int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
+		triggerBuffer:    c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
+		useRelays:        useRelays,
+		maxHandshakeRate: c.GetInt("handshakes.max_rate", DefaultMaxHandshakeRate),
 
 		messageMetrics: messageMetrics,
 	}
diff --git a/ssh.go b/ssh.go
index 0a9adb51..9b7b052c 100644
--- a/ssh.go
+++ b/ssh.go
@@ -632,6 +632,9 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
 	}
 
 	hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
+	if hostInfo == nil {
+		return w.WriteLine("Handshake rate limit reached")
+	}
 	if addr.IsValid() {
 		hostInfo.SetRemote(addr)
 	}