Add handshakes.max_rate to limit new handshakes per second

Nebula is vulnerable to DoS via handshake flooding since each incoming
  handshake performs expensive DH operations. This adds a token bucket
  rate limiter to the handshake manager that caps both inbound and
  outbound new handshakes per second. When the limit is reached, new
  handshakes are silently dropped and counted via the
  handshake_manager.rate_limited metric.

  Configured via handshakes.max_rate (default 0 = unlimited).

  Co-Authored-By: Claude <svc-devxp-claude@slack-corp.com>
This commit is contained in:
Jay Wren
2026-03-30 14:48:04 -04:00
parent 91d1f4675a
commit 3df60ae195
5 changed files with 160 additions and 12 deletions

View File

@@ -23,22 +23,25 @@ const (
DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64
DefaultUseRelays = true
DefaultMaxHandshakeRate = 0 // 0 means unlimited
)
var (
defaultHandshakeConfig = HandshakeConfig{
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
triggerBuffer: DefaultHandshakeTriggerBuffer,
useRelays: DefaultUseRelays,
tryInterval: DefaultHandshakeTryInterval,
retries: DefaultHandshakeRetries,
triggerBuffer: DefaultHandshakeTriggerBuffer,
useRelays: DefaultUseRelays,
maxHandshakeRate: DefaultMaxHandshakeRate,
}
)
type HandshakeConfig struct {
tryInterval time.Duration
retries int64
triggerBuffer int
useRelays bool
tryInterval time.Duration
retries int64
triggerBuffer int
useRelays bool
maxHandshakeRate int
messageMetrics *MessageMetrics
}
@@ -58,9 +61,15 @@ type HandshakeManager struct {
messageMetrics *MessageMetrics
metricInitiated metrics.Counter
metricTimedOut metrics.Counter
metricRateLimited metrics.Counter
f *Interface
l *logrus.Logger
// Rate limiting for new handshakes (token bucket)
rateBucket int // tokens currently available
rateMax int // max tokens (== max handshakes per second), 0 means unlimited
rateLastTick time.Time
// can be used to trigger outbound handshake for the given vpnIp
trigger chan netip.Addr
}
@@ -116,10 +125,41 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
messageMetrics: config.messageMetrics,
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
metricRateLimited: metrics.GetOrRegisterCounter("handshake_manager.rate_limited", nil),
rateBucket: config.maxHandshakeRate,
rateMax: config.maxHandshakeRate,
rateLastTick: time.Now(),
l: l,
}
}
// handshakeRateAllow checks the token bucket rate limiter and returns true if a
// new handshake is allowed. Must be called with hm.Lock held.
func (hm *HandshakeManager) handshakeRateAllow(now time.Time) bool {
if hm.rateMax == 0 {
return true
}
// Refill tokens based on elapsed time
elapsed := now.Sub(hm.rateLastTick)
if elapsed >= time.Second {
// Add tokens for full seconds elapsed
tokens := int(elapsed/time.Second) * hm.rateMax
hm.rateBucket += tokens
if hm.rateBucket > hm.rateMax {
hm.rateBucket = hm.rateMax
}
hm.rateLastTick = now
}
if hm.rateBucket > 0 {
hm.rateBucket--
return true
}
return false
}
func (hm *HandshakeManager) Run(ctx context.Context) {
clockSource := time.NewTicker(hm.config.tryInterval)
defer clockSource.Stop()
@@ -149,6 +189,15 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
case header.HandshakeIXPSK0:
switch h.MessageCounter {
case 1:
// Check rate limit for new incoming handshakes
hm.Lock()
allowed := hm.handshakeRateAllow(time.Now())
hm.Unlock()
if !allowed {
hm.metricRateLimited.Inc(1)
hm.l.WithField("from", via).Debug("Handshake rate limit reached, dropping incoming handshake")
return
}
ixHandshakeStage1(hm.f, via, packet, h)
case 2:
@@ -447,6 +496,14 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
return hh.hostinfo
}
// Check rate limit for new outbound handshakes
if !hm.handshakeRateAllow(time.Now()) {
hm.metricRateLimited.Inc(1)
hm.l.WithField("vpnAddr", vpnAddr).Debug("Handshake rate limit reached, dropping outbound handshake")
hm.Unlock()
return nil
}
hostinfo := &HostInfo{
vpnAddrs: []netip.Addr{vpnAddr},
HandshakePacket: make(map[uint8][]byte, 0),