mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Add handshakes.max_rate to limit new handshakes per second
Nebula is vulnerable to DoS via handshake flooding since each incoming handshake performs expensive DH operations. This adds a token bucket rate limiter to the handshake manager that caps both inbound and outbound new handshakes per second. When the limit is reached, new handshakes are silently dropped and counted via the handshake_manager.rate_limited metric. Configured via handshakes.max_rate (default 0 = unlimited). Co-Authored-By: Claude <svc-devxp-claude@slack-corp.com>
This commit is contained in:
@@ -342,6 +342,14 @@ logging:
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
# max_rate limits the number of new handshakes per second. Both incoming and outgoing new
|
||||
# handshakes count against this limit. Once the limit is reached, new handshakes are dropped
|
||||
# until the next second. A value of 0 means unlimited (default).
|
||||
# This is useful for preventing DoS attacks that attempt to exhaust CPU with handshake crypto.
|
||||
# The throughput reported by `openssl speed ecdhp256` on your hardware is a useful baseline for choosing
|
||||
# a max, as each handshake performs similar DH operations.
|
||||
#max_rate: 0
|
||||
|
||||
# Tunnel manager settings
|
||||
#tunnels:
|
||||
# drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
|
||||
|
||||
@@ -23,6 +23,7 @@ const (
|
||||
DefaultHandshakeRetries = 10
|
||||
DefaultHandshakeTriggerBuffer = 64
|
||||
DefaultUseRelays = true
|
||||
DefaultMaxHandshakeRate = 0 // 0 means unlimited
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -31,6 +32,7 @@ var (
|
||||
retries: DefaultHandshakeRetries,
|
||||
triggerBuffer: DefaultHandshakeTriggerBuffer,
|
||||
useRelays: DefaultUseRelays,
|
||||
maxHandshakeRate: DefaultMaxHandshakeRate,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -39,6 +41,7 @@ type HandshakeConfig struct {
|
||||
retries int64
|
||||
triggerBuffer int
|
||||
useRelays bool
|
||||
maxHandshakeRate int
|
||||
|
||||
messageMetrics *MessageMetrics
|
||||
}
|
||||
@@ -58,9 +61,15 @@ type HandshakeManager struct {
|
||||
messageMetrics *MessageMetrics
|
||||
metricInitiated metrics.Counter
|
||||
metricTimedOut metrics.Counter
|
||||
metricRateLimited metrics.Counter
|
||||
f *Interface
|
||||
l *logrus.Logger
|
||||
|
||||
// Rate limiting for new handshakes (token bucket)
|
||||
rateBucket int // tokens currently available
|
||||
rateMax int // max tokens (== max handshakes per second), 0 means unlimited
|
||||
rateLastTick time.Time
|
||||
|
||||
// can be used to trigger outbound handshake for the given vpnIp
|
||||
trigger chan netip.Addr
|
||||
}
|
||||
@@ -116,10 +125,41 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig
|
||||
messageMetrics: config.messageMetrics,
|
||||
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
|
||||
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
|
||||
metricRateLimited: metrics.GetOrRegisterCounter("handshake_manager.rate_limited", nil),
|
||||
rateBucket: config.maxHandshakeRate,
|
||||
rateMax: config.maxHandshakeRate,
|
||||
rateLastTick: time.Now(),
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
// handshakeRateAllow checks the token bucket rate limiter and returns true if a
|
||||
// new handshake is allowed. Must be called with hm.Lock held.
|
||||
func (hm *HandshakeManager) handshakeRateAllow(now time.Time) bool {
|
||||
if hm.rateMax == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
elapsed := now.Sub(hm.rateLastTick)
|
||||
if elapsed >= time.Second {
|
||||
// Add tokens for full seconds elapsed
|
||||
tokens := int(elapsed/time.Second) * hm.rateMax
|
||||
hm.rateBucket += tokens
|
||||
if hm.rateBucket > hm.rateMax {
|
||||
hm.rateBucket = hm.rateMax
|
||||
}
|
||||
hm.rateLastTick = now
|
||||
}
|
||||
|
||||
if hm.rateBucket > 0 {
|
||||
hm.rateBucket--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) Run(ctx context.Context) {
|
||||
clockSource := time.NewTicker(hm.config.tryInterval)
|
||||
defer clockSource.Stop()
|
||||
@@ -149,6 +189,15 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
|
||||
case header.HandshakeIXPSK0:
|
||||
switch h.MessageCounter {
|
||||
case 1:
|
||||
// Check rate limit for new incoming handshakes
|
||||
hm.Lock()
|
||||
allowed := hm.handshakeRateAllow(time.Now())
|
||||
hm.Unlock()
|
||||
if !allowed {
|
||||
hm.metricRateLimited.Inc(1)
|
||||
hm.l.WithField("from", via).Debug("Handshake rate limit reached, dropping incoming handshake")
|
||||
return
|
||||
}
|
||||
ixHandshakeStage1(hm.f, via, packet, h)
|
||||
|
||||
case 2:
|
||||
@@ -447,6 +496,14 @@ func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*Han
|
||||
return hh.hostinfo
|
||||
}
|
||||
|
||||
// Check rate limit for new outbound handshakes
|
||||
if !hm.handshakeRateAllow(time.Now()) {
|
||||
hm.metricRateLimited.Inc(1)
|
||||
hm.l.WithField("vpnAddr", vpnAddr).Debug("Handshake rate limit reached, dropping outbound handshake")
|
||||
hm.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
vpnAddrs: []netip.Addr{vpnAddr},
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
|
||||
@@ -65,6 +65,85 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
assert.NotContains(t, blah.vpnIps, ip)
|
||||
}
|
||||
|
||||
func Test_HandshakeManagerRateLimit(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
mainHM := newHostMap(l)
|
||||
mainHM.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
cs := &CertState{
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
config := defaultHandshakeConfig
|
||||
config.maxHandshakeRate = 2
|
||||
|
||||
hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, config)
|
||||
hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
|
||||
hm.f.pki.cs.Store(cs)
|
||||
|
||||
// Should allow up to maxHandshakeRate handshakes
|
||||
ip1 := netip.MustParseAddr("172.1.1.1")
|
||||
ip2 := netip.MustParseAddr("172.1.1.2")
|
||||
ip3 := netip.MustParseAddr("172.1.1.3")
|
||||
|
||||
h1 := hm.StartHandshake(ip1, nil)
|
||||
assert.NotNil(t, h1, "first handshake should be allowed")
|
||||
|
||||
h2 := hm.StartHandshake(ip2, nil)
|
||||
assert.NotNil(t, h2, "second handshake should be allowed")
|
||||
|
||||
// Third should be rate limited
|
||||
h3 := hm.StartHandshake(ip3, nil)
|
||||
assert.Nil(t, h3, "third handshake should be rate limited")
|
||||
|
||||
// After advancing time by 1 second, tokens should refill
|
||||
hm.Lock()
|
||||
hm.rateLastTick = hm.rateLastTick.Add(-time.Second)
|
||||
hm.Unlock()
|
||||
|
||||
h3 = hm.StartHandshake(ip3, nil)
|
||||
assert.NotNil(t, h3, "handshake should be allowed after token refill")
|
||||
}
|
||||
|
||||
func Test_HandshakeManagerRateLimitUnlimited(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
localrange := netip.MustParsePrefix("10.1.1.1/24")
|
||||
preferredRanges := []netip.Prefix{localrange}
|
||||
mainHM := newHostMap(l)
|
||||
mainHM.preferredRanges.Store(&preferredRanges)
|
||||
|
||||
lh := newTestLighthouse()
|
||||
|
||||
cs := &CertState{
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
}
|
||||
|
||||
// Default config has maxHandshakeRate=0 (unlimited)
|
||||
hm := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
hm.f = &Interface{handshakeManager: hm, pki: &PKI{}, l: l}
|
||||
hm.f.pki.cs.Store(cs)
|
||||
|
||||
// Should allow many handshakes with no limit
|
||||
// Limited to 10 due to test lighthouse query channel buffer
|
||||
for i := 0; i < 10; i++ {
|
||||
ip := netip.MustParseAddr("172.1.1.1").As16()
|
||||
ip[15] = byte(i + 1)
|
||||
addr := netip.AddrFrom16(ip)
|
||||
h := hm.StartHandshake(addr, nil)
|
||||
assert.NotNil(t, h, "handshake %d should be allowed with unlimited rate", i)
|
||||
}
|
||||
}
|
||||
|
||||
func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) {
|
||||
for _, i := range tw.t.wheel {
|
||||
n := i.Head
|
||||
|
||||
1
main.go
1
main.go
@@ -208,6 +208,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)),
|
||||
triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
|
||||
useRelays: useRelays,
|
||||
maxHandshakeRate: c.GetInt("handshakes.max_rate", DefaultMaxHandshakeRate),
|
||||
|
||||
messageMetrics: messageMetrics,
|
||||
}
|
||||
|
||||
3
ssh.go
3
ssh.go
@@ -632,6 +632,9 @@ func sshCreateTunnel(ifce *Interface, fs any, a []string, w sshd.StringWriter) e
|
||||
}
|
||||
|
||||
hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil)
|
||||
if hostInfo == nil {
|
||||
return w.WriteLine("Handshake rate limit reached")
|
||||
}
|
||||
if addr.IsValid() {
|
||||
hostInfo.SetRemote(addr)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user