diff --git a/firewall.go b/firewall.go index adecbe81..948dadcc 100644 --- a/firewall.go +++ b/firewall.go @@ -80,8 +80,8 @@ type firewallMetrics struct { type FirewallConntrack struct { sync.Mutex - Conns map[firewall.Packet]*conn - TimerWheel *TimerWheel[firewall.Packet] + Conns map[firewall.PacketKey]*conn + TimerWheel *TimerWheel[firewall.PacketKey] } // FirewallTable is the entry point for a rule, the evaluation order is: @@ -166,8 +166,8 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur return &Firewall{ Conntrack: &FirewallConntrack{ - Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), + Conns: make(map[firewall.PacketKey]*conn), + TimerWheel: NewTimerWheel[firewall.PacketKey](tmin, tmax), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), @@ -422,12 +422,27 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It // returns nil if the packet should not be dropped. -func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { - // Check if we spoke to this tuple, if we did then allow this packet - if f.inConns(fp, h, caPool, localCache) { +// +// key is the dense conntrack key — used as-is for the inConns fast path +// without touching fp at all. fp is the rich Packet form rule matching +// needs (CIDR lookups, family checks); on the conntrack-miss slow path +// Drop ensures fp is hydrated from key (idempotent if the caller already +// filled fp). On accept-via-conntrack the caller's fp is left untouched. +func (f *Firewall) Drop(key firewall.PacketKey, fp *firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error { + // Check if we spoke to this tuple, if we did then allow this packet. + // Hot path: only the dense key is touched. 
+ if f.inConns(key, h, caPool, localCache) { return nil } + // Conntrack miss → rule matching needs the rich Packet form. Hydrate + // from the key if the caller passed a zero-valued fp (the inbound path + // after ParseInbound). Outbound callers fill fp via newPacket and skip + // this hop. + if !fp.LocalAddr.IsValid() { + key.Hydrate(fp) + } + // Make sure remote address matches nebula certificate, and determine how to treat it if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks @@ -467,13 +482,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // We now know which firewall table to check against - if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) { + if !table.match(*fp, incoming, h.ConnectionState.peerCert, caPool) { f.metrics(incoming).droppedNoRule.Inc(1) return ErrNoMatchingRule } // We always want to conntrack since it is a faster operation - f.addConn(fp, incoming) + f.addConn(key, fp.Protocol, incoming) return nil } @@ -502,9 +517,9 @@ func (f *Firewall) EmitStats() { metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV())) } -func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { +func (f *Firewall) inConns(key firewall.PacketKey, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool { if localCache != nil { - if _, ok := localCache[fp]; ok { + if _, ok := localCache[key]; ok { return true } } @@ -517,7 +532,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, f.evict(ep) } - c, ok := conntrack.Conns[fp] + c, ok := conntrack.Conns[key] if !ok { conntrack.Unlock() @@ -526,7 +541,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, if c.rulesVersion != f.rulesVersion { // This conntrack entry was for an older rule set, validate - // it still passes with the current rule set + // 
it still passes with the current rule set. Rule matching needs + // the rich Packet form, so hydrate from key. + var fp firewall.Packet + key.Hydrate(&fp) + table := f.OutRules if c.incoming { table = f.InRules @@ -542,7 +561,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, "oldRulesVersion", c.rulesVersion, ) } - delete(conntrack.Conns, fp) + delete(conntrack.Conns, key) conntrack.Unlock() return false } @@ -559,7 +578,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, c.rulesVersion = f.rulesVersion } - switch fp.Protocol { + switch key.Protocol { case firewall.ProtoTCP: c.Expires = time.Now().Add(f.TCPTimeout) case firewall.ProtoUDP: @@ -571,17 +590,17 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, conntrack.Unlock() if localCache != nil { - localCache[fp] = struct{}{} + localCache[key] = struct{}{} } return true } -func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { +func (f *Firewall) addConn(key firewall.PacketKey, protocol uint8, incoming bool) { var timeout time.Duration c := &conn{} - switch fp.Protocol { + switch protocol { case firewall.ProtoTCP: timeout = f.TCPTimeout case firewall.ProtoUDP: @@ -592,9 +611,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { conntrack := f.Conntrack conntrack.Lock() - if _, ok := conntrack.Conns[fp]; !ok { + if _, ok := conntrack.Conns[key]; !ok { conntrack.TimerWheel.Advance(time.Now()) - conntrack.TimerWheel.Add(fp, timeout) + conntrack.TimerWheel.Add(key, timeout) } // Record which rulesVersion allowed this connection, so we can retest after @@ -602,16 +621,16 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) { c.incoming = incoming c.rulesVersion = f.rulesVersion c.Expires = time.Now().Add(timeout) - conntrack.Conns[fp] = c + conntrack.Conns[key] = c conntrack.Unlock() } // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to 
the wheel // Caller must own the connMutex lock! -func (f *Firewall) evict(p firewall.Packet) { +func (f *Firewall) evict(key firewall.PacketKey) { // Are we still tracking this conn? conntrack := f.Conntrack - t, ok := conntrack.Conns[p] + t, ok := conntrack.Conns[key] if !ok { return } @@ -621,12 +640,12 @@ func (f *Firewall) evict(p firewall.Packet) { // Timeout is in the future, re-add the timer if newT > 0 { conntrack.TimerWheel.Advance(time.Now()) - conntrack.TimerWheel.Add(p, newT) + conntrack.TimerWheel.Add(key, newT) return } // This conn is done - delete(conntrack.Conns, p) + delete(conntrack.Conns, key) } func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { diff --git a/firewall/cache.go b/firewall/cache.go index 3e34e6ea..87546df8 100644 --- a/firewall/cache.go +++ b/firewall/cache.go @@ -10,8 +10,10 @@ import ( ) // ConntrackCache is used as a local routine cache to know if a given flow -// has been seen in the conntrack table. -type ConntrackCache map[Packet]struct{} +// has been seen in the conntrack table. Keyed on PacketKey (dense form) +// rather than Packet so the lookup hashes raw bytes instead of the +// unique.Handle each netip.Addr in Packet carries. 
+type ConntrackCache map[PacketKey]struct{} type ConntrackCacheTicker struct { cacheV uint64 diff --git a/firewall/cache_test.go b/firewall/cache_test.go index 3baf2326..5522de29 100644 --- a/firewall/cache_test.go +++ b/firewall/cache_test.go @@ -23,7 +23,7 @@ func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheT cache: make(ConntrackCache, cacheLen), } for i := 0; i < cacheLen; i++ { - c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{} + c.cache[PacketKey{TransportTuple: TransportTuple{LocalPort: uint16(i) + 1}}] = struct{}{} } c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path return c diff --git a/firewall/packet.go b/firewall/packet.go index ea7162fe..355d0f05 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -19,14 +19,34 @@ const ( PortFragment = -1 // Special value for matching `port: fragment` ) +// TransportTuple is the dense 5-tuple shape shared between the coalescer's +// flowKey-equivalent and the firewall's PacketKey. Stored in Local/Remote +// orientation so a flow's incoming and outgoing packets share the same +// tuple identity. v4 addresses occupy the low 4 bytes of LocalAddr/ +// RemoteAddr (NOT v4-mapped form) so v4 vs v6 tuples never collide. type TransportTuple struct { - FirstAddr [16]byte - SecondAddr [16]byte - FirstPort uint16 - SecondPort uint16 + LocalAddr [16]byte + RemoteAddr [16]byte + LocalPort uint16 + RemotePort uint16 IsV6 bool } +// PacketKey is the firewall's conntrack and ConntrackCache map key — the +// dense form of the 5-tuple plus the protocol and fragment flag the +// firewall actually discriminates flows on. Kept separate from Packet so +// the conntrack-hit fast path doesn't pay for hashing the unique.Handle +// each netip.Addr carries, and so the inbound parser can skip the +// AddrFrom4/AddrFrom16 calls until rule matching actually needs them. 
+// +// Superset of the coalescer's flowKey shape (same 5-tuple, just in +// Local/Remote orientation rather than wire src/dst). +type PacketKey struct { + TransportTuple + Protocol uint8 + Fragment bool +} + type Packet struct { LocalAddr netip.Addr RemoteAddr netip.Addr @@ -39,6 +59,51 @@ type Packet struct { Fragment bool } +// Key derives a PacketKey from a populated Packet. Used by the outgoing +// path (inside.go) which still parses into a full Packet via newPacket +// before the firewall check; the inbound path skips this hop entirely by +// having its parser write straight into the PacketKey. +func (fp *Packet) Key() PacketKey { + k := PacketKey{ + Protocol: fp.Protocol, + Fragment: fp.Fragment, + } + k.LocalPort = fp.LocalPort + k.RemotePort = fp.RemotePort + k.IsV6 = !fp.LocalAddr.Is4() + if k.IsV6 { + k.LocalAddr = fp.LocalAddr.As16() + k.RemoteAddr = fp.RemoteAddr.As16() + } else { + v4 := fp.LocalAddr.As4() + copy(k.LocalAddr[:4], v4[:]) + v4 = fp.RemoteAddr.As4() + copy(k.RemoteAddr[:4], v4[:]) + } + return k +} + +// Hydrate fills fp's netip.Addr fields and copies the rest from k. Called +// by the firewall slow path when conntrack misses and rule matching needs +// the rich Packet form (CIDR lookups, family checks). The fast path skips +// this entirely. 
+func (k *PacketKey) Hydrate(fp *Packet) { + fp.LocalPort = k.LocalPort + fp.RemotePort = k.RemotePort + fp.Protocol = k.Protocol + fp.Fragment = k.Fragment + if k.IsV6 { + fp.LocalAddr = netip.AddrFrom16(k.LocalAddr) + fp.RemoteAddr = netip.AddrFrom16(k.RemoteAddr) + } else { + var v4 [4]byte + copy(v4[:], k.LocalAddr[:4]) + fp.LocalAddr = netip.AddrFrom4(v4) + copy(v4[:], k.RemoteAddr[:4]) + fp.RemoteAddr = netip.AddrFrom4(v4) + } +} + func (fp *Packet) Copy() *Packet { return &Packet{ LocalAddr: fp.LocalAddr, diff --git a/firewall_test.go b/firewall_test.go index 40b57477..1917dc7c 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -211,44 +211,44 @@ func TestFirewall_Drop(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("1.2.3.10") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on 
match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) } func TestFirewall_DropV6(t *testing.T) { @@ -289,44 +289,44 @@ func TestFirewall_DropV6(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) + assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) // Allow outbound 
because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr p.RemoteAddr = netip.MustParseAddr("fd12::56") - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP) p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", "")) - assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match 
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -533,10 +533,10 @@ func TestFirewall_Drop2(t *testing.T) { cp := cert.NewCAPool() // h1/c1 lacks the proper groups - require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -613,18 +613,18 @@ func TestFirewall_Drop3(t *testing.T) { cp := cert.NewCAPool() // c1 should pass because host match - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) - assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil)) } func TestFirewall_Drop3V6(t *testing.T) 
{ @@ -661,7 +661,7 @@ func TestFirewall_Drop3V6(t *testing.T) { fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) cp := cert.NewCAPool() require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", "")) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -702,12 +702,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) { cp := cert.NewCAPool() // Drop outbound - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil)) // Allow outbound because conntrack - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -716,7 +716,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - require.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -725,7 +725,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { fw.rulesVersion = oldFw.rulesVersion + 1 // Drop outbound because conntrack doesn't match new ruleset - assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule) } func TestFirewall_ICMPPortBehavior(t *testing.T) { @@ -770,12 +770,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, 
false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil)) }) t.Run("nonzero ports", func(t *testing.T) { @@ -783,12 +783,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil)) }) }) @@ -800,12 +800,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero ports, still blocked", func(t *testing.T) { @@ -813,12 +813,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), 
ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) }) t.Run("nonzero, matching ports, still blocked", func(t *testing.T) { @@ -826,12 +826,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 80 p.RemotePort = 80 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule) //now also allow outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) }) }) t.Run("Any proto, any port", func(t *testing.T) { @@ -843,12 +843,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0 p.RemotePort = 0 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil)) }) t.Run("nonzero ports, allowed", func(t *testing.T) { @@ -857,15 +857,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) { p.LocalPort = 0xabcd p.RemotePort = 0x1234 // Drop outbound - assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + assert.Equal(t, fw.Drop(p.Key(), p, false, 
&h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - require.NoError(t, fw.Drop(*p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil)) //now also allow outbound - require.NoError(t, fw.Drop(*p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil)) //different ID is blocked p.RemotePort++ - require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule) + require.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule) }) }) @@ -913,7 +913,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { Protocol: firewall.ProtoUDP, Fragment: false, } - assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP) + assert.Equal(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrInvalidRemoteIP) } func BenchmarkLookup(b *testing.B) { @@ -1327,7 +1327,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) { t.Helper() cp := cert.NewCAPool() resetConntrack(fw) - err := fw.Drop(c.p, true, c.h, cp, nil) + err := fw.Drop(c.p.Key(), &c.p, true, c.h, cp, nil) if c.err == nil { require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) } else { @@ -1519,6 +1519,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end func resetConntrack(fw *Firewall) { fw.Conntrack.Lock() - fw.Conntrack.Conns = map[firewall.Packet]*conn{} + fw.Conntrack.Conns = map[firewall.PacketKey]*conn{} fw.Conntrack.Unlock() } diff --git a/inside.go b/inside.go index 0fa841da..1d388340 100644 --- a/inside.go +++ b/inside.go @@ -105,7 +105,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe return } - dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(fwPacket.Key(), fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q) } else { @@ -400,7 +400,7 @@ func (f *Interface) 
sendMessageNow(t header.MessageType, st header.MessageSubTyp } // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(fp.Key(), fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping cached packet", diff --git a/outside.go b/outside.go index 05dc7a13..e609f4d8 100644 --- a/outside.go +++ b/outside.go @@ -570,10 +570,13 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p applyOuterECN(out, meta.OuterECN, hostinfo, f.l) } - // Single IP+L4 walk feeds both the firewall (via fwPacket) and the - // batcher (via parsedRx). Replaces newPacket — the batcher's CommitInbound - // uses parsedRx instead of re-walking the headers. - err := batch.ParseInbound(out, fwPacket, parsedRx) + // Single IP+L4 walk feeds the firewall conntrack key (parsedRx.Key) + // and the batcher hint (parsedRx.tcp/udp). Replaces newPacket — and + // pointedly does NOT fill fwPacket.LocalAddr/RemoteAddr, since + // firewall.Drop's fast path uses Key alone and only hydrates fwPacket + // from Key on the slow path. 
+ *fwPacket = firewall.Packet{} + err := batch.ParseInbound(out, parsedRx) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, @@ -582,7 +585,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p return } - dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(parsedRx.Key, fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) if dropReason != nil { // NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore // This gives us a buffer to build the reject packet in diff --git a/overlay/batch/inbound.go b/overlay/batch/inbound.go new file mode 100644 index 00000000..67a528fa --- /dev/null +++ b/overlay/batch/inbound.go @@ -0,0 +1,401 @@ +package batch + +import ( + "encoding/binary" + "errors" + + "github.com/slackhq/nebula/firewall" +) + +// IANA protocol numbers we recognise during the inbound parse. Kept local +// (rather than reaching for the firewall constants for every one of these) +// so the byte-comparison hot path doesn't depend on cross-package values. +const ( + ipProtoICMP = 1 + ipProtoIPv6Fragment = 44 + ipProtoESP = 50 + ipProtoAH = 51 + ipProtoICMPv6 = 58 + ipProtoNoNextHdr = 59 + + icmpv6TypeEchoRequest = 128 + icmpv6TypeEchoReply = 129 +) + +// Inbound parse errors. Match outside.go's sentinel set so the unified +// parser can drop in as a replacement for newPacket without callers +// noticing a behavior change. 
+var ( + ErrInboundPacketTooShort = errors.New("packet is too short") + ErrInboundUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrInboundIPv4InvalidHdrLen = errors.New("invalid ipv4 header length") + ErrInboundIPv4TooShort = errors.New("ipv4 packet is too short") + ErrInboundIPv6TooShort = errors.New("ipv6 packet is too short") + ErrInboundIPv6NoPayload = errors.New("could not find payload in ipv6 packet") +) + +// RxKind discriminates how an inbound plaintext packet should be committed +// after its firewall.Packet has been built. RxKindPassthrough means the +// IP shape is valid (firewall could match on it) but the coalescer's +// strict checks reject it — caller should still write it via the +// passthrough lane. +type RxKind uint8 + +const ( + RxKindPassthrough RxKind = iota + RxKindTCP + RxKindUDP +) + +// RxParsed is the unified result of one IP+L4 walk: +// - Key: the firewall's conntrack/cache lookup key. The dense form lets +// firewall.Drop hit conntrack without ever filling the rich Packet's +// netip.Addr fields. On a conntrack miss, Drop hydrates the caller's +// Packet from Key. +// - tcp/udp: the coalescer hint so commitParsed doesn't re-walk the +// headers. Meaningful only when Kind is RxKindTCP / RxKindUDP. +type RxParsed struct { + Kind RxKind + Key firewall.PacketKey + tcp parsedTCP + udp parsedUDP +} + +// ParseInbound walks an inbound plaintext packet once and fills: +// - parsed.Key with the dense, Local/Remote-oriented conntrack key the +// firewall uses (replaces the netip.Addr-rich path through newPacket). +// - parsed.{tcp,udp} with the coalescer hint, when the shape is +// coalesce-eligible. +// +// Eligibility rules match the coalescer's own parseTCPBase/parseUDP: +// - IPv4 strict: IHL == 20, no fragmentation (MF or offset), proto TCP/UDP. +// - IPv6 strict: NextHeader is directly TCP or UDP (no extension headers). 
+// +// Returns the same set of errors newPacket returns for malformed input — +// callers can treat those as drop. +func ParseInbound(pkt []byte, parsed *RxParsed) error { + parsed.Kind = RxKindPassthrough + // Reset Key in full: v4 only writes the low 4 bytes of each address + // field, so without this a v6 call followed by a v4 reusing the same + // RxParsed would inherit the high 12 bytes — breaking the conntrack + // map equality for v4 flows. + parsed.Key = firewall.PacketKey{} + if len(pkt) < 1 { + return ErrInboundPacketTooShort + } + switch pkt[0] >> 4 { + case 4: + return parseInboundV4(pkt, parsed) + case 6: + return parseInboundV6(pkt, parsed) + } + return ErrInboundUnknownIPVersion +} + +// parseInboundV4 mirrors parseV4(incoming=true) for the firewall side and +// also fills the coalescer hint when the shape is strict. +func parseInboundV4(pkt []byte, parsed *RxParsed) error { + if len(pkt) < 20 { + return ErrInboundIPv4TooShort + } + ihl := int(pkt[0]&0x0f) << 2 + if ihl < 20 { + return ErrInboundIPv4InvalidHdrLen + } + flagsfrags := binary.BigEndian.Uint16(pkt[6:8]) + parsed.Key.Fragment = (flagsfrags & 0x1FFF) != 0 + parsed.Key.Protocol = pkt[9] + parsed.Key.IsV6 = false + + // minFwPacketLen (4) is the L4-header prefix the firewall needs to pull + // ports; ICMP needs two extra bytes for the identifier. + minLen := ihl + if !parsed.Key.Fragment { + if parsed.Key.Protocol == firewall.ProtoICMP { + minLen += 4 + 2 + } else { + minLen += 4 + } + } + if len(pkt) < minLen { + return ErrInboundIPv4InvalidHdrLen + } + + // Inbound orientation: wire src → Remote, wire dst → Local. 
+ copy(parsed.Key.RemoteAddr[:4], pkt[12:16]) + copy(parsed.Key.LocalAddr[:4], pkt[16:20]) + + switch { + case parsed.Key.Fragment: + parsed.Key.RemotePort = 0 + parsed.Key.LocalPort = 0 + case parsed.Key.Protocol == firewall.ProtoICMP: + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) + parsed.Key.LocalPort = 0 + default: + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + } + + // Coalescer-eligible? Strict shape: IHL==20, no MF/offset, TCP or UDP. + if ihl != 20 || (flagsfrags&0x3FFF) != 0 { + return nil + } + if parsed.Key.Protocol != ipProtoTCP && parsed.Key.Protocol != ipProtoUDP { + return nil + } + totalLen := int(binary.BigEndian.Uint16(pkt[2:4])) + if totalLen > len(pkt) || totalLen < 20 { + return nil + } + pktTrim := pkt[:totalLen] + + switch parsed.Key.Protocol { + case ipProtoTCP: + fillParsedTCPv4(pktTrim, parsed) + case ipProtoUDP: + fillParsedUDPv4(pktTrim, parsed) + } + return nil +} + +// fillParsedTCPv4 fills parsed.tcp from a strict-shape IPv4+TCP packet +// already validated to have IHL==20 and to be totalLen-trimmed. +func fillParsedTCPv4(pkt []byte, parsed *RxParsed) { + if len(pkt) < 40 { // IPv4(20) + min TCP(20) + return + } + tcpOff := int(pkt[32]>>4) * 4 + if tcpOff < 20 || tcpOff > 60 { + return + } + if len(pkt) < 20+tcpOff { + return + } + p := &parsed.tcp + p.ipHdrLen = 20 + p.tcpHdrLen = tcpOff + p.hdrLen = 20 + tcpOff + p.payLen = len(pkt) - p.hdrLen + p.seq = binary.BigEndian.Uint32(pkt[24:28]) + p.flags = pkt[33] + p.fk.isV6 = false + p.fk.sport = parsed.Key.RemotePort + p.fk.dport = parsed.Key.LocalPort + copy(p.fk.src[:4], pkt[12:16]) + copy(p.fk.dst[:4], pkt[16:20]) + parsed.Kind = RxKindTCP +} + +// fillParsedUDPv4 fills parsed.udp from a strict-shape IPv4+UDP packet. 
+func fillParsedUDPv4(pkt []byte, parsed *RxParsed) { + if len(pkt) < 28 { // IPv4(20) + UDP(8) + return + } + udpLen := int(binary.BigEndian.Uint16(pkt[24:26])) + if udpLen < 8 || udpLen > len(pkt)-20 { + return + } + p := &parsed.udp + p.ipHdrLen = 20 + p.hdrLen = 28 + p.payLen = udpLen - 8 + p.fk.isV6 = false + p.fk.sport = parsed.Key.RemotePort + p.fk.dport = parsed.Key.LocalPort + copy(p.fk.src[:4], pkt[12:16]) + copy(p.fk.dst[:4], pkt[16:20]) + parsed.Kind = RxKindUDP +} + +// parseInboundV6 mirrors parseV6(incoming=true). The coalescer-eligible +// fast path triggers only when NextHeader is directly TCP or UDP — any +// extension header chain falls into the lenient walk below. +func parseInboundV6(pkt []byte, parsed *RxParsed) error { + if len(pkt) < 40 { + return ErrInboundIPv6TooShort + } + parsed.Key.IsV6 = true + copy(parsed.Key.RemoteAddr[:], pkt[8:24]) + copy(parsed.Key.LocalAddr[:], pkt[24:40]) + + if proto := pkt[6]; proto == ipProtoTCP || proto == ipProtoUDP { + // Strict v6: ports are at the IP header end. Always fill key; only + // fill the coalescer hint if the L4 shape passes. + if len(pkt) < 44 { + return ErrInboundIPv6TooShort + } + parsed.Key.Protocol = proto + parsed.Key.Fragment = false + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42]) + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44]) + + payloadLen := int(binary.BigEndian.Uint16(pkt[4:6])) + if 40+payloadLen > len(pkt) { + return nil + } + pktTrim := pkt[:40+payloadLen] + + switch proto { + case ipProtoTCP: + fillParsedTCPv6(pktTrim, parsed) + case ipProtoUDP: + fillParsedUDPv6(pktTrim, parsed) + } + return nil + } + + // Slow path: walk extension header chain just like parseV6 does. 
+ return walkInboundV6Headers(pkt, parsed) +} + +func fillParsedTCPv6(pkt []byte, parsed *RxParsed) { + if len(pkt) < 60 { // IPv6(40) + min TCP(20) + return + } + tcpOff := int(pkt[52]>>4) * 4 + if tcpOff < 20 || tcpOff > 60 { + return + } + if len(pkt) < 40+tcpOff { + return + } + p := &parsed.tcp + p.ipHdrLen = 40 + p.tcpHdrLen = tcpOff + p.hdrLen = 40 + tcpOff + p.payLen = len(pkt) - p.hdrLen + p.seq = binary.BigEndian.Uint32(pkt[44:48]) + p.flags = pkt[53] + p.fk.isV6 = true + p.fk.sport = parsed.Key.RemotePort + p.fk.dport = parsed.Key.LocalPort + copy(p.fk.src[:], pkt[8:24]) + copy(p.fk.dst[:], pkt[24:40]) + parsed.Kind = RxKindTCP +} + +func fillParsedUDPv6(pkt []byte, parsed *RxParsed) { + if len(pkt) < 48 { // IPv6(40) + UDP(8) + return + } + udpLen := int(binary.BigEndian.Uint16(pkt[44:46])) + if udpLen < 8 || udpLen > len(pkt)-40 { + return + } + p := &parsed.udp + p.ipHdrLen = 40 + p.hdrLen = 48 + p.payLen = udpLen - 8 + p.fk.isV6 = true + p.fk.sport = parsed.Key.RemotePort + p.fk.dport = parsed.Key.LocalPort + copy(p.fk.src[:], pkt[8:24]) + copy(p.fk.dst[:], pkt[24:40]) + parsed.Kind = RxKindUDP +} + +// walkInboundV6Headers handles every IPv6 case parseV6 handles that isn't +// the strict "NextHeader == TCP/UDP" fast path: ESP, NoNextHeader, ICMPv6, +// fragment headers (first vs later), AH, generic extension headers. +// Coalescer eligibility is always RxKindPassthrough on this path (parsed +// already initialised that way). 
+func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error {
+	dataLen := len(pkt)
+	protoAt := 6
+	offset := 40
+	next := 0
+	for {
+		if protoAt >= dataLen {
+			break
+		}
+		proto := pkt[protoAt]
+		switch proto {
+		case ipProtoESP, ipProtoNoNextHdr:
+			parsed.Key.Protocol = proto
+			parsed.Key.RemotePort = 0
+			parsed.Key.LocalPort = 0
+			parsed.Key.Fragment = false
+			return nil
+
+		case ipProtoICMPv6:
+			if dataLen < offset+6 {
+				return ErrInboundIPv6TooShort
+			}
+			parsed.Key.Protocol = proto
+			parsed.Key.LocalPort = 0
+			switch pkt[offset] {
+			case icmpv6TypeEchoRequest, icmpv6TypeEchoReply:
+				parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset+4 : offset+6])
+			default:
+				parsed.Key.RemotePort = 0
+			}
+			parsed.Key.Fragment = false
+			return nil
+
+		case ipProtoTCP, ipProtoUDP:
+			// Reachable when an extension-header chain ends at TCP/UDP. The
+			// strict-eligible fast path above already handled the no-extension
+			// case; here we only fill firewall ports and stay passthrough.
+			if dataLen < offset+4 {
+				return ErrInboundIPv6TooShort
+			}
+			parsed.Key.Protocol = proto
+			parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2])
+			parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
+			parsed.Key.Fragment = false
+			return nil
+
+		case ipProtoIPv6Fragment:
+			if dataLen < offset+8 {
+				return ErrInboundIPv6TooShort
+			}
+			fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7)
+			if fragmentOffset != 0 {
+				// Non-first fragment: report the fragment flag and stop.
+ parsed.Key.Protocol = pkt[offset] + parsed.Key.Fragment = true + parsed.Key.RemotePort = 0 + parsed.Key.LocalPort = 0 + return nil + } + next = 8 + + case ipProtoAH: + if dataLen <= offset+1 { + break + } + next = int(pkt[offset+1]+2) << 2 + + default: + if dataLen <= offset+1 { + break + } + next = int(pkt[offset+1]+1) << 3 + } + + if next <= 0 { + next = 8 + } + protoAt = offset + offset = offset + next + } + return ErrInboundIPv6NoPayload +} + +// CommitInbound dispatches pkt to the appropriate lane using parsed.Kind, +// skipping the IP+L4 re-parse that MultiCoalescer.Commit would otherwise +// do. Borrowed slice contract is identical to MultiCoalescer.Commit. +func (m *MultiCoalescer) CommitInbound(pkt []byte, parsed *RxParsed) error { + switch parsed.Kind { + case RxKindTCP: + if m.tcp != nil { + return m.tcp.commitParsed(pkt, parsed.tcp) + } + case RxKindUDP: + if m.udp != nil { + return m.udp.commitParsed(pkt, parsed.udp) + } + } + return m.pt.Commit(pkt) +} diff --git a/overlay/batch/inbound_bench_test.go b/overlay/batch/inbound_bench_test.go new file mode 100644 index 00000000..dbbd479d --- /dev/null +++ b/overlay/batch/inbound_bench_test.go @@ -0,0 +1,394 @@ +package batch + +import ( + "encoding/binary" + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" +) + +// parseV4InboundBaseline mirrors what outside.go's parseV4(incoming=true) +// does, so the "split" bench measures the *current* state: firewall-side +// parse, then m.Commit re-parses inside the coalescer. Two walks per +// packet. Kept faithful in shape (one read per field, AddrFromSlice for +// the addrs) so the CPU profile matches the production parseV4. 
+func parseV4InboundBaseline(pkt []byte, fp *firewall.Packet) bool { + if len(pkt) < 20 { + return false + } + ihl := int(pkt[0]&0x0f) << 2 + if ihl < 20 { + return false + } + flagsfrags := binary.BigEndian.Uint16(pkt[6:8]) + fp.Fragment = (flagsfrags & 0x1FFF) != 0 + fp.Protocol = pkt[9] + minLen := ihl + if !fp.Fragment { + if fp.Protocol == firewall.ProtoICMP { + minLen += 4 + 2 + } else { + minLen += 4 + } + } + if len(pkt) < minLen { + return false + } + fp.RemoteAddr, _ = netip.AddrFromSlice(pkt[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(pkt[16:20]) + switch { + case fp.Fragment: + fp.RemotePort = 0 + fp.LocalPort = 0 + case fp.Protocol == firewall.ProtoICMP: + fp.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) + fp.LocalPort = 0 + default: + fp.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + fp.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + } + return true +} + +// parseV6InboundBaseline is the v6 analogue: replicates parseV6's +// extension-header walk so the split bench captures its true cost. 
+func parseV6InboundBaseline(pkt []byte, fp *firewall.Packet) bool {
+	dataLen := len(pkt)
+	if dataLen < 40 {
+		return false
+	}
+	fp.RemoteAddr, _ = netip.AddrFromSlice(pkt[8:24])
+	fp.LocalAddr, _ = netip.AddrFromSlice(pkt[24:40])
+
+	protoAt := 6
+	offset := 40
+	next := 0
+	for {
+		if protoAt >= dataLen {
+			return false
+		}
+		proto := pkt[protoAt]
+		switch proto {
+		case ipProtoESP, ipProtoNoNextHdr:
+			fp.Protocol = proto
+			fp.RemotePort = 0
+			fp.LocalPort = 0
+			fp.Fragment = false
+			return true
+		case ipProtoICMPv6:
+			if dataLen < offset+6 {
+				return false
+			}
+			fp.Protocol = proto
+			fp.LocalPort = 0
+			switch pkt[offset] {
+			case icmpv6TypeEchoRequest, icmpv6TypeEchoReply:
+				fp.RemotePort = binary.BigEndian.Uint16(pkt[offset+4 : offset+6])
+			default:
+				fp.RemotePort = 0
+			}
+			fp.Fragment = false
+			return true
+		case ipProtoTCP, ipProtoUDP:
+			if dataLen < offset+4 {
+				return false
+			}
+			fp.Protocol = proto
+			fp.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2])
+			fp.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
+			fp.Fragment = false
+			return true
+		case ipProtoIPv6Fragment:
+			if dataLen < offset+8 {
+				return false
+			}
+			fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7)
+			if fragmentOffset != 0 {
+				fp.Protocol = pkt[offset]
+				fp.Fragment = true
+				fp.RemotePort = 0
+				fp.LocalPort = 0
+				return true
+			}
+			next = 8
+		case ipProtoAH:
+			if dataLen <= offset+1 {
+				return false
+			}
+			next = int(pkt[offset+1]+2) << 2
+		default:
+			if dataLen <= offset+1 {
+				return false
+			}
+			next = int(pkt[offset+1]+1) << 3
+		}
+		if next <= 0 {
+			next = 8
+		}
+		protoAt = offset
+		offset = offset + next
+	}
+}
+
+// runRxSplit drives the split path: faithful inbound parse for the firewall
+// side, then m.Commit re-parses to coalesce. v6 controls which baseline
+// parser we run.
+func runRxSplit(b *testing.B, pkts [][]byte, batchSize int, v6 bool) { + b.Helper() + m := NewMultiCoalescer(nopTunWriter{}, true, true) + var fp firewall.Packet + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + var ok bool + if v6 { + ok = parseV6InboundBaseline(pkt, &fp) + } else { + ok = parseV4InboundBaseline(pkt, &fp) + } + if !ok { + b.Fatal("baseline parse failed") + } + if err := m.Commit(pkt); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := m.Flush(); err != nil { + b.Fatal(err) + } + } + } + _ = m.Flush() +} + +// runRxUnified drives the unified path: ParseInbound walks once, filling +// the conntrack key + coalescer hint in parsed; CommitInbound dispatches +// without re-parsing. +func runRxUnified(b *testing.B, pkts [][]byte, batchSize int) { + b.Helper() + m := NewMultiCoalescer(nopTunWriter{}, true, true) + var parsed RxParsed + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + if err := ParseInbound(pkt, &parsed); err != nil { + b.Fatal(err) + } + if err := m.CommitInbound(pkt, &parsed); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := m.Flush(); err != nil { + b.Fatal(err) + } + } + } + _ = m.Flush() +} + +// buildUDPv4Bulk returns N UDP packets on a single 5-tuple suitable for the +// UDP coalescer's append path. 
+func buildUDPv4Bulk(n, payloadLen int) [][]byte { + pkts := make([][]byte, n) + pay := make([]byte, payloadLen) + for i := range n { + pkts[i] = buildUDPv4(1000, 53, pay) + } + return pkts +} + +func buildTCPv6Bulk(n, payloadLen int) [][]byte { + pkts := make([][]byte, n) + pay := make([]byte, payloadLen) + seq := uint32(1000) + for i := range n { + pkts[i] = buildTCPv6(0, seq, tcpAck, pay) + seq += uint32(payloadLen) + } + return pkts +} + +func buildICMPv4Bulk(n int) [][]byte { + pkts := make([][]byte, n) + for i := range pkts { + pkts[i] = buildICMPv4() + } + return pkts +} + +// === TCPv4 === + +func BenchmarkRxSplitTCPv4(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runRxSplit(b, pkts, tcpCoalesceMaxSegs, false) +} + +func BenchmarkRxUnifiedTCPv4(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runRxUnified(b, pkts, tcpCoalesceMaxSegs) +} + +// === TCPv4 interleaved (4 flows) === + +func BenchmarkRxSplitTCPv4Interleaved4(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runRxSplit(b, pkts, len(pkts), false) +} + +func BenchmarkRxUnifiedTCPv4Interleaved4(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runRxUnified(b, pkts, len(pkts)) +} + +// === UDPv4 === + +func BenchmarkRxSplitUDPv4(b *testing.B) { + pkts := buildUDPv4Bulk(udpCoalesceMaxSegs, 1200) + runRxSplit(b, pkts, udpCoalesceMaxSegs, false) +} + +func BenchmarkRxUnifiedUDPv4(b *testing.B) { + pkts := buildUDPv4Bulk(udpCoalesceMaxSegs, 1200) + runRxUnified(b, pkts, udpCoalesceMaxSegs) +} + +// === TCPv6 === + +func BenchmarkRxSplitTCPv6(b *testing.B) { + pkts := buildTCPv6Bulk(tcpCoalesceMaxSegs, 1200) + runRxSplit(b, pkts, tcpCoalesceMaxSegs, true) +} + +func BenchmarkRxUnifiedTCPv6(b *testing.B) { + pkts := buildTCPv6Bulk(tcpCoalesceMaxSegs, 1200) + runRxUnified(b, pkts, tcpCoalesceMaxSegs) +} + +// === ICMPv4 (passthrough) — measures the unified parser on the coalescer- +// rejected path, 
where both lenient and unified must still fill fp. === + +func BenchmarkRxSplitICMPv4(b *testing.B) { + pkts := buildICMPv4Bulk(64) + runRxSplit(b, pkts, 64, false) +} + +func BenchmarkRxUnifiedICMPv4(b *testing.B) { + pkts := buildICMPv4Bulk(64) + runRxUnified(b, pkts, 64) +} + +// === Firewall fast-path (conntrack-hit) — exercises the savings from the +// dense PacketKey: smaller hash key for the per-routine ConntrackCache, +// and skipping the AddrFrom4 calls that the old path needed to fill the +// netip.Addr-rich firewall.Packet up-front. === +// +// The "split" baseline simulates the legacy path: parseV4InboundBaseline +// fills a netip.Addr-rich Packet, then we probe a localCache keyed on +// Packet. The "unified" path: ParseInbound fills only the dense PacketKey, +// and we probe a localCache keyed on PacketKey. Both paths follow with +// the coalescer Commit so the bench captures end-to-end RX-side cost. + +// runRxSplitWithCache mirrors runRxSplit but runs the legacy-style +// firewall fast path (localCache keyed on firewall.Packet) on every +// packet so we can compare against the unified path. +func runRxSplitWithCache(b *testing.B, pkts [][]byte, batchSize int) { + b.Helper() + m := NewMultiCoalescer(nopTunWriter{}, true, true) + var fp firewall.Packet + + // Pre-warm a per-packet cache keyed on the netip.Addr-rich Packet form. 
+ cache := make(map[firewall.Packet]struct{}, len(pkts)) + for _, pkt := range pkts { + var seedFp firewall.Packet + if !parseV4InboundBaseline(pkt, &seedFp) { + b.Fatal("seed parse failed") + } + cache[seedFp] = struct{}{} + } + + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + if !parseV4InboundBaseline(pkt, &fp) { + b.Fatal("baseline parse failed") + } + if _, ok := cache[fp]; !ok { + b.Fatal("cache miss") + } + if err := m.Commit(pkt); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := m.Flush(); err != nil { + b.Fatal(err) + } + } + } + _ = m.Flush() +} + +// runRxUnifiedWithCache: unified path with a PacketKey-keyed localCache. +// Each iteration: ParseInbound → conntrack-cache hit → CommitInbound. +func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) { + b.Helper() + m := NewMultiCoalescer(nopTunWriter{}, true, true) + var parsed RxParsed + + cache := make(firewall.ConntrackCache, len(pkts)) + for _, pkt := range pkts { + var seed RxParsed + if err := ParseInbound(pkt, &seed); err != nil { + b.Fatal(err) + } + cache[seed.Key] = struct{}{} + } + + b.ReportAllocs() + b.SetBytes(int64(len(pkts[0]))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + pkt := pkts[i%len(pkts)] + if err := ParseInbound(pkt, &parsed); err != nil { + b.Fatal(err) + } + if _, ok := cache[parsed.Key]; !ok { + b.Fatal("cache miss") + } + if err := m.CommitInbound(pkt, &parsed); err != nil { + b.Fatal(err) + } + if (i+1)%batchSize == 0 { + if err := m.Flush(); err != nil { + b.Fatal(err) + } + } + } + _ = m.Flush() +} + +func BenchmarkRxSplitTCPv4WithCache(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runRxSplitWithCache(b, pkts, tcpCoalesceMaxSegs) +} + +func BenchmarkRxUnifiedTCPv4WithCache(b *testing.B) { + pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200) + runRxUnifiedWithCache(b, pkts, tcpCoalesceMaxSegs) +} + +func 
BenchmarkRxSplitInterleaved4WithCache(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runRxSplitWithCache(b, pkts, len(pkts)) +} + +func BenchmarkRxUnifiedInterleaved4WithCache(b *testing.B) { + pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200) + runRxUnifiedWithCache(b, pkts, len(pkts)) +} diff --git a/overlay/batch/inbound_test.go b/overlay/batch/inbound_test.go new file mode 100644 index 00000000..a918a4e4 --- /dev/null +++ b/overlay/batch/inbound_test.go @@ -0,0 +1,174 @@ +package batch + +import ( + "net/netip" + "testing" + + "github.com/slackhq/nebula/firewall" +) + +// TestParseInboundParity asserts that ParseInbound + Key.Hydrate produces +// the same firewall.Packet that the lenient baseline parsers (which +// mirror outside.go's parseV4/parseV6 with incoming=true) produce for +// every shape we care about. Catches drift between the unified +// parse-then-hydrate flow and the production newPacket behavior so +// swapping one for the other is observably safe. 
+func TestParseInboundParity(t *testing.T) { + cases := []struct { + name string + pkt []byte + v6 bool + }{ + {"tcp_v4", buildTCPv4Ports(1234, 443, 1000, tcpAck, []byte("payload")), false}, + {"tcp_v4_psh", buildTCPv4Ports(1234, 443, 2000, tcpAckPsh, make([]byte, 1200)), false}, + {"udp_v4", buildUDPv4(40000, 53, []byte("dnsquery")), false}, + {"icmp_v4", buildICMPv4(), false}, + {"tcp_v6", buildTCPv6(0, 5000, tcpAck, make([]byte, 800)), true}, + {"udp_v6", buildUDPv6(40001, 53, []byte("v6dns")), true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var fpUnified, fpBaseline firewall.Packet + var parsed RxParsed + + if err := ParseInbound(tc.pkt, &parsed); err != nil { + t.Fatalf("ParseInbound: %v", err) + } + parsed.Key.Hydrate(&fpUnified) + var ok bool + if tc.v6 { + ok = parseV6InboundBaseline(tc.pkt, &fpBaseline) + } else { + ok = parseV4InboundBaseline(tc.pkt, &fpBaseline) + } + if !ok { + t.Fatalf("baseline parse failed") + } + + if fpUnified != fpBaseline { + t.Errorf("firewall.Packet mismatch:\n unified: %+v\n baseline: %+v", fpUnified, fpBaseline) + } + }) + } +} + +// TestParseInboundFlowKey checks that the coalescer hint the unified parser +// produces matches what parseTCPBase/parseUDP would produce on the same +// packet — same flowKey, ipHdrLen, payLen, etc. The hint is only valid +// when Kind is RxKindTCP/RxKindUDP. 
+func TestParseInboundFlowKey(t *testing.T) { + t.Run("tcp_v4", func(t *testing.T) { + pkt := buildTCPv4Ports(1234, 443, 5000, tcpAck, make([]byte, 800)) + var parsed RxParsed + if err := ParseInbound(pkt, &parsed); err != nil { + t.Fatal(err) + } + if parsed.Kind != RxKindTCP { + t.Fatalf("kind=%v want TCP", parsed.Kind) + } + ref, ok := parseTCPBase(pkt) + if !ok { + t.Fatal("parseTCPBase failed") + } + if parsed.tcp != ref { + t.Errorf("parsedTCP mismatch:\n unified: %+v\n ref: %+v", parsed.tcp, ref) + } + }) + + t.Run("udp_v4", func(t *testing.T) { + pkt := buildUDPv4(40000, 53, []byte("dnsquery")) + var parsed RxParsed + if err := ParseInbound(pkt, &parsed); err != nil { + t.Fatal(err) + } + if parsed.Kind != RxKindUDP { + t.Fatalf("kind=%v want UDP", parsed.Kind) + } + ref, ok := parseUDP(pkt) + if !ok { + t.Fatal("parseUDP failed") + } + if parsed.udp != ref { + t.Errorf("parsedUDP mismatch:\n unified: %+v\n ref: %+v", parsed.udp, ref) + } + }) + + t.Run("tcp_v6", func(t *testing.T) { + pkt := buildTCPv6(0, 9000, tcpAck, make([]byte, 800)) + var parsed RxParsed + if err := ParseInbound(pkt, &parsed); err != nil { + t.Fatal(err) + } + if parsed.Kind != RxKindTCP { + t.Fatalf("kind=%v want TCP", parsed.Kind) + } + ref, ok := parseTCPBase(pkt) + if !ok { + t.Fatal("parseTCPBase failed") + } + if parsed.tcp != ref { + t.Errorf("parsedTCP mismatch:\n unified: %+v\n ref: %+v", parsed.tcp, ref) + } + }) +} + +// TestParseInboundICMPPassthrough confirms ICMP packets populate the +// conntrack key (including the ICMP identifier in RemotePort) but stay +// RxKindPassthrough so the batcher writes them verbatim. After Hydrate +// the firewall.Packet form should match what the legacy parseV4 produced. +func TestParseInboundICMPPassthrough(t *testing.T) { + pkt := buildICMPv4() + // Stamp a non-zero identifier into the ICMP header so we can check + // RemotePort gets it. 
+ pkt[20] = 8 // type=echo + pkt[24] = 0xab + pkt[25] = 0xcd + + var parsed RxParsed + if err := ParseInbound(pkt, &parsed); err != nil { + t.Fatal(err) + } + if parsed.Kind != RxKindPassthrough { + t.Errorf("kind=%v want Passthrough", parsed.Kind) + } + var fp firewall.Packet + parsed.Key.Hydrate(&fp) + if fp.Protocol != firewall.ProtoICMP { + t.Errorf("Protocol=%d want %d", fp.Protocol, firewall.ProtoICMP) + } + if fp.RemotePort != 0xabcd { + t.Errorf("RemotePort=0x%x want 0xabcd", fp.RemotePort) + } + if fp.LocalPort != 0 { + t.Errorf("LocalPort=%d want 0", fp.LocalPort) + } + wantRemote := netip.MustParseAddr("10.0.0.1") + wantLocal := netip.MustParseAddr("10.0.0.2") + if fp.RemoteAddr != wantRemote || fp.LocalAddr != wantLocal { + t.Errorf("addrs: remote=%v local=%v want %v/%v", fp.RemoteAddr, fp.LocalAddr, wantRemote, wantLocal) + } +} + +// TestParseInboundV4Fragment confirms a fragmented v4 packet fills the +// conntrack key with Fragment=true and falls into Passthrough on the +// coalescer side. +func TestParseInboundV4Fragment(t *testing.T) { + // Build a TCP packet then twiddle the IP flags to make it look like a + // non-first fragment (offset != 0). + pkt := buildTCPv4Ports(1234, 443, 1000, tcpAck, []byte("payload")) + // Set a non-zero fragment offset (bytes 6-7, low 13 bits). + pkt[6] = 0x00 + pkt[7] = 0x10 // offset = 16 (in 8-byte units) + + var parsed RxParsed + if err := ParseInbound(pkt, &parsed); err != nil { + t.Fatal(err) + } + if !parsed.Key.Fragment { + t.Error("Fragment=false, want true") + } + if parsed.Kind != RxKindPassthrough { + t.Errorf("kind=%v want Passthrough", parsed.Kind) + } +}