mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
checkpt, try to parse packets only once pt2
This commit is contained in:
69
firewall.go
69
firewall.go
@@ -80,8 +80,8 @@ type firewallMetrics struct {
|
|||||||
type FirewallConntrack struct {
|
type FirewallConntrack struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
Conns map[firewall.Packet]*conn
|
Conns map[firewall.PacketKey]*conn
|
||||||
TimerWheel *TimerWheel[firewall.Packet]
|
TimerWheel *TimerWheel[firewall.PacketKey]
|
||||||
}
|
}
|
||||||
|
|
||||||
// FirewallTable is the entry point for a rule, the evaluation order is:
|
// FirewallTable is the entry point for a rule, the evaluation order is:
|
||||||
@@ -166,8 +166,8 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur
|
|||||||
|
|
||||||
return &Firewall{
|
return &Firewall{
|
||||||
Conntrack: &FirewallConntrack{
|
Conntrack: &FirewallConntrack{
|
||||||
Conns: make(map[firewall.Packet]*conn),
|
Conns: make(map[firewall.PacketKey]*conn),
|
||||||
TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax),
|
TimerWheel: NewTimerWheel[firewall.PacketKey](tmin, tmax),
|
||||||
},
|
},
|
||||||
InRules: newFirewallTable(),
|
InRules: newFirewallTable(),
|
||||||
OutRules: newFirewallTable(),
|
OutRules: newFirewallTable(),
|
||||||
@@ -422,12 +422,27 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
|||||||
|
|
||||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||||
// returns nil if the packet should not be dropped.
|
// returns nil if the packet should not be dropped.
|
||||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
//
|
||||||
// Check if we spoke to this tuple, if we did then allow this packet
|
// key is the dense conntrack key — used as-is for the inConns fast path
|
||||||
if f.inConns(fp, h, caPool, localCache) {
|
// without touching fp at all. fp is the rich Packet form rule matching
|
||||||
|
// needs (CIDR lookups, family checks); on the conntrack-miss slow path
|
||||||
|
// Drop ensures fp is hydrated from key (idempotent if the caller already
|
||||||
|
// filled fp). On accept-via-conntrack the caller's fp is left untouched.
|
||||||
|
func (f *Firewall) Drop(key firewall.PacketKey, fp *firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||||
|
// Check if we spoke to this tuple, if we did then allow this packet.
|
||||||
|
// Hot path: only the dense key is touched.
|
||||||
|
if f.inConns(key, h, caPool, localCache) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conntrack miss → rule matching needs the rich Packet form. Hydrate
|
||||||
|
// from the key if the caller passed a zero-valued fp (the inbound path
|
||||||
|
// after ParseInbound). Outbound callers fill fp via newPacket and skip
|
||||||
|
// this hop.
|
||||||
|
if !fp.LocalAddr.IsValid() {
|
||||||
|
key.Hydrate(fp)
|
||||||
|
}
|
||||||
|
|
||||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||||
if h.networks == nil {
|
if h.networks == nil {
|
||||||
// Simple case: Certificate has one address and no unsafe networks
|
// Simple case: Certificate has one address and no unsafe networks
|
||||||
@@ -467,13 +482,13 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We now know which firewall table to check against
|
// We now know which firewall table to check against
|
||||||
if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
|
if !table.match(*fp, incoming, h.ConnectionState.peerCert, caPool) {
|
||||||
f.metrics(incoming).droppedNoRule.Inc(1)
|
f.metrics(incoming).droppedNoRule.Inc(1)
|
||||||
return ErrNoMatchingRule
|
return ErrNoMatchingRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// We always want to conntrack since it is a faster operation
|
// We always want to conntrack since it is a faster operation
|
||||||
f.addConn(fp, incoming)
|
f.addConn(key, fp.Protocol, incoming)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -502,9 +517,9 @@ func (f *Firewall) EmitStats() {
|
|||||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
func (f *Firewall) inConns(key firewall.PacketKey, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
if _, ok := localCache[fp]; ok {
|
if _, ok := localCache[key]; ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -517,7 +532,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
f.evict(ep)
|
f.evict(ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, ok := conntrack.Conns[fp]
|
c, ok := conntrack.Conns[key]
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
@@ -526,7 +541,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
|
|
||||||
if c.rulesVersion != f.rulesVersion {
|
if c.rulesVersion != f.rulesVersion {
|
||||||
// This conntrack entry was for an older rule set, validate
|
// This conntrack entry was for an older rule set, validate
|
||||||
// it still passes with the current rule set
|
// it still passes with the current rule set. Rule matching needs
|
||||||
|
// the rich Packet form, so hydrate from key.
|
||||||
|
var fp firewall.Packet
|
||||||
|
key.Hydrate(&fp)
|
||||||
|
|
||||||
table := f.OutRules
|
table := f.OutRules
|
||||||
if c.incoming {
|
if c.incoming {
|
||||||
table = f.InRules
|
table = f.InRules
|
||||||
@@ -542,7 +561,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
"oldRulesVersion", c.rulesVersion,
|
"oldRulesVersion", c.rulesVersion,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
delete(conntrack.Conns, fp)
|
delete(conntrack.Conns, key)
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -559,7 +578,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
switch fp.Protocol {
|
switch key.Protocol {
|
||||||
case firewall.ProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
c.Expires = time.Now().Add(f.TCPTimeout)
|
c.Expires = time.Now().Add(f.TCPTimeout)
|
||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
@@ -571,17 +590,17 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
|||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
|
|
||||||
if localCache != nil {
|
if localCache != nil {
|
||||||
localCache[fp] = struct{}{}
|
localCache[key] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
func (f *Firewall) addConn(key firewall.PacketKey, protocol uint8, incoming bool) {
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
c := &conn{}
|
c := &conn{}
|
||||||
|
|
||||||
switch fp.Protocol {
|
switch protocol {
|
||||||
case firewall.ProtoTCP:
|
case firewall.ProtoTCP:
|
||||||
timeout = f.TCPTimeout
|
timeout = f.TCPTimeout
|
||||||
case firewall.ProtoUDP:
|
case firewall.ProtoUDP:
|
||||||
@@ -592,9 +611,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||||||
|
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
conntrack.Lock()
|
conntrack.Lock()
|
||||||
if _, ok := conntrack.Conns[fp]; !ok {
|
if _, ok := conntrack.Conns[key]; !ok {
|
||||||
conntrack.TimerWheel.Advance(time.Now())
|
conntrack.TimerWheel.Advance(time.Now())
|
||||||
conntrack.TimerWheel.Add(fp, timeout)
|
conntrack.TimerWheel.Add(key, timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record which rulesVersion allowed this connection, so we can retest after
|
// Record which rulesVersion allowed this connection, so we can retest after
|
||||||
@@ -602,16 +621,16 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
|||||||
c.incoming = incoming
|
c.incoming = incoming
|
||||||
c.rulesVersion = f.rulesVersion
|
c.rulesVersion = f.rulesVersion
|
||||||
c.Expires = time.Now().Add(timeout)
|
c.Expires = time.Now().Add(timeout)
|
||||||
conntrack.Conns[fp] = c
|
conntrack.Conns[key] = c
|
||||||
conntrack.Unlock()
|
conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||||
// Caller must own the connMutex lock!
|
// Caller must own the connMutex lock!
|
||||||
func (f *Firewall) evict(p firewall.Packet) {
|
func (f *Firewall) evict(key firewall.PacketKey) {
|
||||||
// Are we still tracking this conn?
|
// Are we still tracking this conn?
|
||||||
conntrack := f.Conntrack
|
conntrack := f.Conntrack
|
||||||
t, ok := conntrack.Conns[p]
|
t, ok := conntrack.Conns[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -621,12 +640,12 @@ func (f *Firewall) evict(p firewall.Packet) {
|
|||||||
// Timeout is in the future, re-add the timer
|
// Timeout is in the future, re-add the timer
|
||||||
if newT > 0 {
|
if newT > 0 {
|
||||||
conntrack.TimerWheel.Advance(time.Now())
|
conntrack.TimerWheel.Advance(time.Now())
|
||||||
conntrack.TimerWheel.Add(p, newT)
|
conntrack.TimerWheel.Add(key, newT)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// This conn is done
|
// This conn is done
|
||||||
delete(conntrack.Conns, p)
|
delete(conntrack.Conns, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
|
func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool {
|
||||||
|
|||||||
@@ -10,8 +10,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ConntrackCache is used as a local routine cache to know if a given flow
|
// ConntrackCache is used as a local routine cache to know if a given flow
|
||||||
// has been seen in the conntrack table.
|
// has been seen in the conntrack table. Keyed on PacketKey (dense form)
|
||||||
type ConntrackCache map[Packet]struct{}
|
// rather than Packet so the lookup hashes raw bytes instead of the
|
||||||
|
// unique.Handle each netip.Addr in Packet carries.
|
||||||
|
type ConntrackCache map[PacketKey]struct{}
|
||||||
|
|
||||||
type ConntrackCacheTicker struct {
|
type ConntrackCacheTicker struct {
|
||||||
cacheV uint64
|
cacheV uint64
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheT
|
|||||||
cache: make(ConntrackCache, cacheLen),
|
cache: make(ConntrackCache, cacheLen),
|
||||||
}
|
}
|
||||||
for i := 0; i < cacheLen; i++ {
|
for i := 0; i < cacheLen; i++ {
|
||||||
c.cache[Packet{LocalPort: uint16(i) + 1}] = struct{}{}
|
c.cache[PacketKey{TransportTuple: TransportTuple{LocalPort: uint16(i) + 1}}] = struct{}{}
|
||||||
}
|
}
|
||||||
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
|
||||||
return c
|
return c
|
||||||
|
|||||||
@@ -19,14 +19,34 @@ const (
|
|||||||
PortFragment = -1 // Special value for matching `port: fragment`
|
PortFragment = -1 // Special value for matching `port: fragment`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TransportTuple is the dense 5-tuple shape shared between the coalescer's
|
||||||
|
// flowKey-equivalent and the firewall's PacketKey. Stored in Local/Remote
|
||||||
|
// orientation so a flow's incoming and outgoing packets share the same
|
||||||
|
// tuple identity. v4 addresses occupy the low 4 bytes of LocalAddr/
|
||||||
|
// RemoteAddr (NOT v4-mapped form) so v4 vs v6 tuples never collide.
|
||||||
type TransportTuple struct {
|
type TransportTuple struct {
|
||||||
FirstAddr [16]byte
|
LocalAddr [16]byte
|
||||||
SecondAddr [16]byte
|
RemoteAddr [16]byte
|
||||||
FirstPort uint16
|
LocalPort uint16
|
||||||
SecondPort uint16
|
RemotePort uint16
|
||||||
IsV6 bool
|
IsV6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PacketKey is the firewall's conntrack and ConntrackCache map key — the
|
||||||
|
// dense form of the 5-tuple plus the protocol and fragment flag the
|
||||||
|
// firewall actually discriminates flows on. Kept separate from Packet so
|
||||||
|
// the conntrack-hit fast path doesn't pay for hashing the unique.Handle
|
||||||
|
// each netip.Addr carries, and so the inbound parser can skip the
|
||||||
|
// AddrFrom4/AddrFrom16 calls until rule matching actually needs them.
|
||||||
|
//
|
||||||
|
// Superset of the coalescer's flowKey shape (same 5-tuple, just in
|
||||||
|
// Local/Remote orientation rather than wire src/dst).
|
||||||
|
type PacketKey struct {
|
||||||
|
TransportTuple
|
||||||
|
Protocol uint8
|
||||||
|
Fragment bool
|
||||||
|
}
|
||||||
|
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
LocalAddr netip.Addr
|
LocalAddr netip.Addr
|
||||||
RemoteAddr netip.Addr
|
RemoteAddr netip.Addr
|
||||||
@@ -39,6 +59,51 @@ type Packet struct {
|
|||||||
Fragment bool
|
Fragment bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Key derives a PacketKey from a populated Packet. Used by the outgoing
|
||||||
|
// path (inside.go) which still parses into a full Packet via newPacket
|
||||||
|
// before the firewall check; the inbound path skips this hop entirely by
|
||||||
|
// having its parser write straight into the PacketKey.
|
||||||
|
func (fp *Packet) Key() PacketKey {
|
||||||
|
k := PacketKey{
|
||||||
|
Protocol: fp.Protocol,
|
||||||
|
Fragment: fp.Fragment,
|
||||||
|
}
|
||||||
|
k.LocalPort = fp.LocalPort
|
||||||
|
k.RemotePort = fp.RemotePort
|
||||||
|
k.IsV6 = !fp.LocalAddr.Is4()
|
||||||
|
if k.IsV6 {
|
||||||
|
k.LocalAddr = fp.LocalAddr.As16()
|
||||||
|
k.RemoteAddr = fp.RemoteAddr.As16()
|
||||||
|
} else {
|
||||||
|
v4 := fp.LocalAddr.As4()
|
||||||
|
copy(k.LocalAddr[:4], v4[:])
|
||||||
|
v4 = fp.RemoteAddr.As4()
|
||||||
|
copy(k.RemoteAddr[:4], v4[:])
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hydrate fills fp's netip.Addr fields and copies the rest from k. Called
|
||||||
|
// by the firewall slow path when conntrack misses and rule matching needs
|
||||||
|
// the rich Packet form (CIDR lookups, family checks). The fast path skips
|
||||||
|
// this entirely.
|
||||||
|
func (k *PacketKey) Hydrate(fp *Packet) {
|
||||||
|
fp.LocalPort = k.LocalPort
|
||||||
|
fp.RemotePort = k.RemotePort
|
||||||
|
fp.Protocol = k.Protocol
|
||||||
|
fp.Fragment = k.Fragment
|
||||||
|
if k.IsV6 {
|
||||||
|
fp.LocalAddr = netip.AddrFrom16(k.LocalAddr)
|
||||||
|
fp.RemoteAddr = netip.AddrFrom16(k.RemoteAddr)
|
||||||
|
} else {
|
||||||
|
var v4 [4]byte
|
||||||
|
copy(v4[:], k.LocalAddr[:4])
|
||||||
|
fp.LocalAddr = netip.AddrFrom4(v4)
|
||||||
|
copy(v4[:], k.RemoteAddr[:4])
|
||||||
|
fp.RemoteAddr = netip.AddrFrom4(v4)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (fp *Packet) Copy() *Packet {
|
func (fp *Packet) Copy() *Packet {
|
||||||
return &Packet{
|
return &Packet{
|
||||||
LocalAddr: fp.LocalAddr,
|
LocalAddr: fp.LocalAddr,
|
||||||
|
|||||||
106
firewall_test.go
106
firewall_test.go
@@ -211,44 +211,44 @@ func TestFirewall_Drop(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteAddr
|
oldRemote := p.RemoteAddr
|
||||||
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||||
p.RemoteAddr = oldRemote
|
p.RemoteAddr = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropV6(t *testing.T) {
|
func TestFirewall_DropV6(t *testing.T) {
|
||||||
@@ -289,44 +289,44 @@ func TestFirewall_DropV6(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
|
|
||||||
// test remote mismatch
|
// test remote mismatch
|
||||||
oldRemote := p.RemoteAddr
|
oldRemote := p.RemoteAddr
|
||||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||||
p.RemoteAddr = oldRemote
|
p.RemoteAddr = oldRemote
|
||||||
|
|
||||||
// ensure signer doesn't get in the way of group checks
|
// ensure signer doesn't get in the way of group checks
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caSha doesn't drop on match
|
// test caSha doesn't drop on match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
|
|
||||||
// ensure ca name doesn't get in the way of group checks
|
// ensure ca name doesn't get in the way of group checks
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// test caName doesn't drop on match
|
// test caName doesn't drop on match
|
||||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||||
@@ -533,10 +533,10 @@ func TestFirewall_Drop2(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// h1/c1 lacks the proper groups
|
// h1/c1 lacks the proper groups
|
||||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
require.ErrorIs(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||||
// c has the proper groups
|
// c has the proper groups
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3(t *testing.T) {
|
func TestFirewall_Drop3(t *testing.T) {
|
||||||
@@ -613,18 +613,18 @@ func TestFirewall_Drop3(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// c1 should pass because host match
|
// c1 should pass because host match
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil))
|
||||||
// c2 should pass because ca sha match
|
// c2 should pass because ca sha match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h2, cp, nil))
|
||||||
// c3 should fail because no match
|
// c3 should fail because no match
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||||
|
|
||||||
// Test a remote address match
|
// Test a remote address match
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_Drop3V6(t *testing.T) {
|
func TestFirewall_Drop3V6(t *testing.T) {
|
||||||
@@ -661,7 +661,7 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
|||||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||||
@@ -702,12 +702,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
|
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, true, &h, cp, nil))
|
||||||
// Allow outbound because conntrack
|
// Allow outbound because conntrack
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw := fw
|
oldFw := fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -716,7 +716,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
// Allow outbound because conntrack and new rules allow port 10
|
// Allow outbound because conntrack and new rules allow port 10
|
||||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), &p, false, &h, cp, nil))
|
||||||
|
|
||||||
oldFw = fw
|
oldFw = fw
|
||||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||||
@@ -725,7 +725,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
|||||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||||
|
|
||||||
// Drop outbound because conntrack doesn't match new ruleset
|
// Drop outbound because conntrack doesn't match new ruleset
|
||||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), &p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||||
@@ -770,12 +770,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0
|
p.LocalPort = 0
|
||||||
p.RemotePort = 0
|
p.RemotePort = 0
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("nonzero ports", func(t *testing.T) {
|
t.Run("nonzero ports", func(t *testing.T) {
|
||||||
@@ -783,12 +783,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0xabcd
|
p.LocalPort = 0xabcd
|
||||||
p.RemotePort = 0x1234
|
p.RemotePort = 0x1234
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -800,12 +800,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0
|
p.LocalPort = 0
|
||||||
p.RemotePort = 0
|
p.RemotePort = 0
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||||
@@ -813,12 +813,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0xabcd
|
p.LocalPort = 0xabcd
|
||||||
p.RemotePort = 0x1234
|
p.RemotePort = 0x1234
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||||
@@ -826,12 +826,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 80
|
p.LocalPort = 80
|
||||||
p.RemotePort = 80
|
p.RemotePort = 80
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
t.Run("Any proto, any port", func(t *testing.T) {
|
t.Run("Any proto, any port", func(t *testing.T) {
|
||||||
@@ -843,12 +843,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0
|
p.LocalPort = 0
|
||||||
p.RemotePort = 0
|
p.RemotePort = 0
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||||
@@ -857,15 +857,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
|||||||
p.LocalPort = 0xabcd
|
p.LocalPort = 0xabcd
|
||||||
p.RemotePort = 0x1234
|
p.RemotePort = 0x1234
|
||||||
// Drop outbound
|
// Drop outbound
|
||||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
assert.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
// Allow inbound
|
// Allow inbound
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, true, &h, cp, nil))
|
||||||
//now also allow outbound
|
//now also allow outbound
|
||||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
require.NoError(t, fw.Drop(p.Key(), p, false, &h, cp, nil))
|
||||||
//different ID is blocked
|
//different ID is blocked
|
||||||
p.RemotePort++
|
p.RemotePort++
|
||||||
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
require.Equal(t, fw.Drop(p.Key(), p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -913,7 +913,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
|||||||
Protocol: firewall.ProtoUDP,
|
Protocol: firewall.ProtoUDP,
|
||||||
Fragment: false,
|
Fragment: false,
|
||||||
}
|
}
|
||||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
assert.Equal(t, fw.Drop(p.Key(), &p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkLookup(b *testing.B) {
|
func BenchmarkLookup(b *testing.B) {
|
||||||
@@ -1327,7 +1327,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
cp := cert.NewCAPool()
|
cp := cert.NewCAPool()
|
||||||
resetConntrack(fw)
|
resetConntrack(fw)
|
||||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
err := fw.Drop(c.p.Key(), &c.p, true, c.h, cp, nil)
|
||||||
if c.err == nil {
|
if c.err == nil {
|
||||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||||
} else {
|
} else {
|
||||||
@@ -1519,6 +1519,6 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
|||||||
|
|
||||||
func resetConntrack(fw *Firewall) {
|
func resetConntrack(fw *Firewall) {
|
||||||
fw.Conntrack.Lock()
|
fw.Conntrack.Lock()
|
||||||
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
|
fw.Conntrack.Conns = map[firewall.PacketKey]*conn{}
|
||||||
fw.Conntrack.Unlock()
|
fw.Conntrack.Unlock()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(fwPacket.Key(), fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason == nil {
|
if dropReason == nil {
|
||||||
f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
|
f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
|
||||||
} else {
|
} else {
|
||||||
@@ -400,7 +400,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if packet is in outbound fw rules
|
// check if packet is in outbound fw rules
|
||||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
dropReason := f.firewall.Drop(fp.Key(), fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
f.l.Debug("dropping cached packet",
|
f.l.Debug("dropping cached packet",
|
||||||
|
|||||||
13
outside.go
13
outside.go
@@ -570,10 +570,13 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p
|
|||||||
applyOuterECN(out, meta.OuterECN, hostinfo, f.l)
|
applyOuterECN(out, meta.OuterECN, hostinfo, f.l)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single IP+L4 walk feeds both the firewall (via fwPacket) and the
|
// Single IP+L4 walk feeds the firewall conntrack key (parsedRx.Key)
|
||||||
// batcher (via parsedRx). Replaces newPacket — the batcher's CommitInbound
|
// and the batcher hint (parsedRx.tcp/udp). Replaces newPacket — and
|
||||||
// uses parsedRx instead of re-walking the headers.
|
// pointedly does NOT fill fwPacket.LocalAddr/RemoteAddr, since
|
||||||
err := batch.ParseInbound(out, fwPacket, parsedRx)
|
// firewall.Drop's fast path uses Key alone and only hydrates fwPacket
|
||||||
|
// from Key on the slow path.
|
||||||
|
*fwPacket = firewall.Packet{}
|
||||||
|
err := batch.ParseInbound(out, parsedRx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
hostinfo.logger(f.l).Warn("Error while validating inbound packet",
|
||||||
"error", err,
|
"error", err,
|
||||||
@@ -582,7 +585,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
dropReason := f.firewall.Drop(parsedRx.Key, fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||||
if dropReason != nil {
|
if dropReason != nil {
|
||||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||||
// This gives us a buffer to build the reject packet in
|
// This gives us a buffer to build the reject packet in
|
||||||
|
|||||||
401
overlay/batch/inbound.go
Normal file
401
overlay/batch/inbound.go
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IANA protocol numbers we recognise during the inbound parse. Kept local
|
||||||
|
// (rather than reaching for the firewall constants for every one of these)
|
||||||
|
// so the byte-comparison hot path doesn't depend on cross-package values.
|
||||||
|
const (
|
||||||
|
ipProtoICMP = 1
|
||||||
|
ipProtoIPv6Fragment = 44
|
||||||
|
ipProtoESP = 50
|
||||||
|
ipProtoAH = 51
|
||||||
|
ipProtoICMPv6 = 58
|
||||||
|
ipProtoNoNextHdr = 59
|
||||||
|
|
||||||
|
icmpv6TypeEchoRequest = 128
|
||||||
|
icmpv6TypeEchoReply = 129
|
||||||
|
)
|
||||||
|
|
||||||
|
// Inbound parse errors. Match outside.go's sentinel set so the unified
|
||||||
|
// parser can drop in as a replacement for newPacket without callers
|
||||||
|
// noticing a behavior change.
|
||||||
|
var (
|
||||||
|
ErrInboundPacketTooShort = errors.New("packet is too short")
|
||||||
|
ErrInboundUnknownIPVersion = errors.New("packet is an unknown ip version")
|
||||||
|
ErrInboundIPv4InvalidHdrLen = errors.New("invalid ipv4 header length")
|
||||||
|
ErrInboundIPv4TooShort = errors.New("ipv4 packet is too short")
|
||||||
|
ErrInboundIPv6TooShort = errors.New("ipv6 packet is too short")
|
||||||
|
ErrInboundIPv6NoPayload = errors.New("could not find payload in ipv6 packet")
|
||||||
|
)
|
||||||
|
|
||||||
|
// RxKind discriminates how an inbound plaintext packet should be committed
|
||||||
|
// after its firewall.Packet has been built. RxKindPassthrough means the
|
||||||
|
// IP shape is valid (firewall could match on it) but the coalescer's
|
||||||
|
// strict checks reject it — caller should still write it via the
|
||||||
|
// passthrough lane.
|
||||||
|
type RxKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
RxKindPassthrough RxKind = iota
|
||||||
|
RxKindTCP
|
||||||
|
RxKindUDP
|
||||||
|
)
|
||||||
|
|
||||||
|
// RxParsed is the unified result of one IP+L4 walk:
|
||||||
|
// - Key: the firewall's conntrack/cache lookup key. The dense form lets
|
||||||
|
// firewall.Drop hit conntrack without ever filling the rich Packet's
|
||||||
|
// netip.Addr fields. On a conntrack miss, Drop hydrates the caller's
|
||||||
|
// Packet from Key.
|
||||||
|
// - tcp/udp: the coalescer hint so commitParsed doesn't re-walk the
|
||||||
|
// headers. Meaningful only when Kind is RxKindTCP / RxKindUDP.
|
||||||
|
type RxParsed struct {
|
||||||
|
Kind RxKind
|
||||||
|
Key firewall.PacketKey
|
||||||
|
tcp parsedTCP
|
||||||
|
udp parsedUDP
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseInbound walks an inbound plaintext packet once and fills:
|
||||||
|
// - parsed.Key with the dense, Local/Remote-oriented conntrack key the
|
||||||
|
// firewall uses (replaces the netip.Addr-rich path through newPacket).
|
||||||
|
// - parsed.{tcp,udp} with the coalescer hint, when the shape is
|
||||||
|
// coalesce-eligible.
|
||||||
|
//
|
||||||
|
// Eligibility rules match the coalescer's own parseTCPBase/parseUDP:
|
||||||
|
// - IPv4 strict: IHL == 20, no fragmentation (MF or offset), proto TCP/UDP.
|
||||||
|
// - IPv6 strict: NextHeader is directly TCP or UDP (no extension headers).
|
||||||
|
//
|
||||||
|
// Returns the same set of errors newPacket returns for malformed input —
|
||||||
|
// callers can treat those as drop.
|
||||||
|
func ParseInbound(pkt []byte, parsed *RxParsed) error {
|
||||||
|
parsed.Kind = RxKindPassthrough
|
||||||
|
// Reset Key in full: v4 only writes the low 4 bytes of each address
|
||||||
|
// field, so without this a v6 call followed by a v4 reusing the same
|
||||||
|
// RxParsed would inherit the high 12 bytes — breaking the conntrack
|
||||||
|
// map equality for v4 flows.
|
||||||
|
parsed.Key = firewall.PacketKey{}
|
||||||
|
if len(pkt) < 1 {
|
||||||
|
return ErrInboundPacketTooShort
|
||||||
|
}
|
||||||
|
switch pkt[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
return parseInboundV4(pkt, parsed)
|
||||||
|
case 6:
|
||||||
|
return parseInboundV6(pkt, parsed)
|
||||||
|
}
|
||||||
|
return ErrInboundUnknownIPVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInboundV4 mirrors parseV4(incoming=true) for the firewall side and
|
||||||
|
// also fills the coalescer hint when the shape is strict.
|
||||||
|
func parseInboundV4(pkt []byte, parsed *RxParsed) error {
|
||||||
|
if len(pkt) < 20 {
|
||||||
|
return ErrInboundIPv4TooShort
|
||||||
|
}
|
||||||
|
ihl := int(pkt[0]&0x0f) << 2
|
||||||
|
if ihl < 20 {
|
||||||
|
return ErrInboundIPv4InvalidHdrLen
|
||||||
|
}
|
||||||
|
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
|
||||||
|
parsed.Key.Fragment = (flagsfrags & 0x1FFF) != 0
|
||||||
|
parsed.Key.Protocol = pkt[9]
|
||||||
|
parsed.Key.IsV6 = false
|
||||||
|
|
||||||
|
// minFwPacketLen (4) is the L4-header prefix the firewall needs to pull
|
||||||
|
// ports; ICMP needs two extra bytes for the identifier.
|
||||||
|
minLen := ihl
|
||||||
|
if !parsed.Key.Fragment {
|
||||||
|
if parsed.Key.Protocol == firewall.ProtoICMP {
|
||||||
|
minLen += 4 + 2
|
||||||
|
} else {
|
||||||
|
minLen += 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkt) < minLen {
|
||||||
|
return ErrInboundIPv4InvalidHdrLen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inbound orientation: wire src → Remote, wire dst → Local.
|
||||||
|
copy(parsed.Key.RemoteAddr[:4], pkt[12:16])
|
||||||
|
copy(parsed.Key.LocalAddr[:4], pkt[16:20])
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case parsed.Key.Fragment:
|
||||||
|
parsed.Key.RemotePort = 0
|
||||||
|
parsed.Key.LocalPort = 0
|
||||||
|
case parsed.Key.Protocol == firewall.ProtoICMP:
|
||||||
|
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6])
|
||||||
|
parsed.Key.LocalPort = 0
|
||||||
|
default:
|
||||||
|
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
|
||||||
|
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Coalescer-eligible? Strict shape: IHL==20, no MF/offset, TCP or UDP.
|
||||||
|
if ihl != 20 || (flagsfrags&0x3FFF) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed.Key.Protocol != ipProtoTCP && parsed.Key.Protocol != ipProtoUDP {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
totalLen := int(binary.BigEndian.Uint16(pkt[2:4]))
|
||||||
|
if totalLen > len(pkt) || totalLen < 20 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pktTrim := pkt[:totalLen]
|
||||||
|
|
||||||
|
switch parsed.Key.Protocol {
|
||||||
|
case ipProtoTCP:
|
||||||
|
fillParsedTCPv4(pktTrim, parsed)
|
||||||
|
case ipProtoUDP:
|
||||||
|
fillParsedUDPv4(pktTrim, parsed)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillParsedTCPv4 fills parsed.tcp from a strict-shape IPv4+TCP packet
|
||||||
|
// already validated to have IHL==20 and to be totalLen-trimmed.
|
||||||
|
func fillParsedTCPv4(pkt []byte, parsed *RxParsed) {
|
||||||
|
if len(pkt) < 40 { // IPv4(20) + min TCP(20)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tcpOff := int(pkt[32]>>4) * 4
|
||||||
|
if tcpOff < 20 || tcpOff > 60 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(pkt) < 20+tcpOff {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := &parsed.tcp
|
||||||
|
p.ipHdrLen = 20
|
||||||
|
p.tcpHdrLen = tcpOff
|
||||||
|
p.hdrLen = 20 + tcpOff
|
||||||
|
p.payLen = len(pkt) - p.hdrLen
|
||||||
|
p.seq = binary.BigEndian.Uint32(pkt[24:28])
|
||||||
|
p.flags = pkt[33]
|
||||||
|
p.fk.isV6 = false
|
||||||
|
p.fk.sport = parsed.Key.RemotePort
|
||||||
|
p.fk.dport = parsed.Key.LocalPort
|
||||||
|
copy(p.fk.src[:4], pkt[12:16])
|
||||||
|
copy(p.fk.dst[:4], pkt[16:20])
|
||||||
|
parsed.Kind = RxKindTCP
|
||||||
|
}
|
||||||
|
|
||||||
|
// fillParsedUDPv4 fills parsed.udp from a strict-shape IPv4+UDP packet.
|
||||||
|
func fillParsedUDPv4(pkt []byte, parsed *RxParsed) {
|
||||||
|
if len(pkt) < 28 { // IPv4(20) + UDP(8)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
udpLen := int(binary.BigEndian.Uint16(pkt[24:26]))
|
||||||
|
if udpLen < 8 || udpLen > len(pkt)-20 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := &parsed.udp
|
||||||
|
p.ipHdrLen = 20
|
||||||
|
p.hdrLen = 28
|
||||||
|
p.payLen = udpLen - 8
|
||||||
|
p.fk.isV6 = false
|
||||||
|
p.fk.sport = parsed.Key.RemotePort
|
||||||
|
p.fk.dport = parsed.Key.LocalPort
|
||||||
|
copy(p.fk.src[:4], pkt[12:16])
|
||||||
|
copy(p.fk.dst[:4], pkt[16:20])
|
||||||
|
parsed.Kind = RxKindUDP
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseInboundV6 mirrors parseV6(incoming=true). The coalescer-eligible
|
||||||
|
// fast path triggers only when NextHeader is directly TCP or UDP — any
|
||||||
|
// extension header chain falls into the lenient walk below.
|
||||||
|
func parseInboundV6(pkt []byte, parsed *RxParsed) error {
|
||||||
|
if len(pkt) < 40 {
|
||||||
|
return ErrInboundIPv6TooShort
|
||||||
|
}
|
||||||
|
parsed.Key.IsV6 = true
|
||||||
|
copy(parsed.Key.RemoteAddr[:], pkt[8:24])
|
||||||
|
copy(parsed.Key.LocalAddr[:], pkt[24:40])
|
||||||
|
|
||||||
|
if proto := pkt[6]; proto == ipProtoTCP || proto == ipProtoUDP {
|
||||||
|
// Strict v6: ports are at the IP header end. Always fill key; only
|
||||||
|
// fill the coalescer hint if the L4 shape passes.
|
||||||
|
if len(pkt) < 44 {
|
||||||
|
return ErrInboundIPv6TooShort
|
||||||
|
}
|
||||||
|
parsed.Key.Protocol = proto
|
||||||
|
parsed.Key.Fragment = false
|
||||||
|
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42])
|
||||||
|
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44])
|
||||||
|
|
||||||
|
payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
|
||||||
|
if 40+payloadLen > len(pkt) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
pktTrim := pkt[:40+payloadLen]
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case ipProtoTCP:
|
||||||
|
fillParsedTCPv6(pktTrim, parsed)
|
||||||
|
case ipProtoUDP:
|
||||||
|
fillParsedUDPv6(pktTrim, parsed)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Slow path: walk extension header chain just like parseV6 does.
|
||||||
|
return walkInboundV6Headers(pkt, parsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillParsedTCPv6(pkt []byte, parsed *RxParsed) {
|
||||||
|
if len(pkt) < 60 { // IPv6(40) + min TCP(20)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tcpOff := int(pkt[52]>>4) * 4
|
||||||
|
if tcpOff < 20 || tcpOff > 60 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(pkt) < 40+tcpOff {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := &parsed.tcp
|
||||||
|
p.ipHdrLen = 40
|
||||||
|
p.tcpHdrLen = tcpOff
|
||||||
|
p.hdrLen = 40 + tcpOff
|
||||||
|
p.payLen = len(pkt) - p.hdrLen
|
||||||
|
p.seq = binary.BigEndian.Uint32(pkt[44:48])
|
||||||
|
p.flags = pkt[53]
|
||||||
|
p.fk.isV6 = true
|
||||||
|
p.fk.sport = parsed.Key.RemotePort
|
||||||
|
p.fk.dport = parsed.Key.LocalPort
|
||||||
|
copy(p.fk.src[:], pkt[8:24])
|
||||||
|
copy(p.fk.dst[:], pkt[24:40])
|
||||||
|
parsed.Kind = RxKindTCP
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillParsedUDPv6(pkt []byte, parsed *RxParsed) {
|
||||||
|
if len(pkt) < 48 { // IPv6(40) + UDP(8)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
udpLen := int(binary.BigEndian.Uint16(pkt[44:46]))
|
||||||
|
if udpLen < 8 || udpLen > len(pkt)-40 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p := &parsed.udp
|
||||||
|
p.ipHdrLen = 40
|
||||||
|
p.hdrLen = 48
|
||||||
|
p.payLen = udpLen - 8
|
||||||
|
p.fk.isV6 = true
|
||||||
|
p.fk.sport = parsed.Key.RemotePort
|
||||||
|
p.fk.dport = parsed.Key.LocalPort
|
||||||
|
copy(p.fk.src[:], pkt[8:24])
|
||||||
|
copy(p.fk.dst[:], pkt[24:40])
|
||||||
|
parsed.Kind = RxKindUDP
|
||||||
|
}
|
||||||
|
|
||||||
|
// walkInboundV6Headers handles every IPv6 case parseV6 handles that isn't
|
||||||
|
// the strict "NextHeader == TCP/UDP" fast path: ESP, NoNextHeader, ICMPv6,
|
||||||
|
// fragment headers (first vs later), AH, generic extension headers.
|
||||||
|
// Coalescer eligibility is always RxKindPassthrough on this path (parsed
|
||||||
|
// already initialised that way).
|
||||||
|
func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error {
|
||||||
|
dataLen := len(pkt)
|
||||||
|
protoAt := 6
|
||||||
|
offset := 40
|
||||||
|
next := 0
|
||||||
|
for {
|
||||||
|
if protoAt >= dataLen {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
proto := pkt[protoAt]
|
||||||
|
switch proto {
|
||||||
|
case ipProtoESP, ipProtoNoNextHdr:
|
||||||
|
parsed.Key.Protocol = proto
|
||||||
|
parsed.Key.RemotePort = 0
|
||||||
|
parsed.Key.LocalPort = 0
|
||||||
|
parsed.Key.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case ipProtoICMPv6:
|
||||||
|
if dataLen < offset+6 {
|
||||||
|
return ErrInboundIPv6TooShort
|
||||||
|
}
|
||||||
|
parsed.Key.Protocol = proto
|
||||||
|
parsed.Key.LocalPort = 0
|
||||||
|
switch pkt[offset+1] {
|
||||||
|
case icmpv6TypeEchoRequest, icmpv6TypeEchoReply:
|
||||||
|
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset+4 : offset+6])
|
||||||
|
default:
|
||||||
|
parsed.Key.RemotePort = 0
|
||||||
|
}
|
||||||
|
parsed.Key.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case ipProtoTCP, ipProtoUDP:
|
||||||
|
// Reachable when an extension-header chain ends at TCP/UDP. The
|
||||||
|
// strict-eligible fast path above already handled the no-extension
|
||||||
|
// case; here we only fill firewall ports and stay passthrough.
|
||||||
|
if dataLen < offset+4 {
|
||||||
|
return ErrInboundIPv6TooShort
|
||||||
|
}
|
||||||
|
parsed.Key.Protocol = proto
|
||||||
|
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2])
|
||||||
|
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
|
||||||
|
parsed.Key.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case ipProtoIPv6Fragment:
|
||||||
|
if dataLen < offset+8 {
|
||||||
|
return ErrInboundIPv6TooShort
|
||||||
|
}
|
||||||
|
fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7)
|
||||||
|
if fragmentOffset != 0 {
|
||||||
|
// Non-first fragment: report the fragment flag and stop.
|
||||||
|
parsed.Key.Protocol = pkt[offset]
|
||||||
|
parsed.Key.Fragment = true
|
||||||
|
parsed.Key.RemotePort = 0
|
||||||
|
parsed.Key.LocalPort = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
next = 8
|
||||||
|
|
||||||
|
case ipProtoAH:
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next = int(pkt[offset+1]+2) << 2
|
||||||
|
|
||||||
|
default:
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
next = int(pkt[offset+1]+1) << 3
|
||||||
|
}
|
||||||
|
|
||||||
|
if next <= 0 {
|
||||||
|
next = 8
|
||||||
|
}
|
||||||
|
protoAt = offset
|
||||||
|
offset = offset + next
|
||||||
|
}
|
||||||
|
return ErrInboundIPv6NoPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
// CommitInbound dispatches pkt to the appropriate lane using parsed.Kind,
|
||||||
|
// skipping the IP+L4 re-parse that MultiCoalescer.Commit would otherwise
|
||||||
|
// do. Borrowed slice contract is identical to MultiCoalescer.Commit.
|
||||||
|
func (m *MultiCoalescer) CommitInbound(pkt []byte, parsed *RxParsed) error {
|
||||||
|
switch parsed.Kind {
|
||||||
|
case RxKindTCP:
|
||||||
|
if m.tcp != nil {
|
||||||
|
return m.tcp.commitParsed(pkt, parsed.tcp)
|
||||||
|
}
|
||||||
|
case RxKindUDP:
|
||||||
|
if m.udp != nil {
|
||||||
|
return m.udp.commitParsed(pkt, parsed.udp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m.pt.Commit(pkt)
|
||||||
|
}
|
||||||
394
overlay/batch/inbound_bench_test.go
Normal file
394
overlay/batch/inbound_bench_test.go
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseV4InboundBaseline mirrors what outside.go's parseV4(incoming=true)
|
||||||
|
// does, so the "split" bench measures the *current* state: firewall-side
|
||||||
|
// parse, then m.Commit re-parses inside the coalescer. Two walks per
|
||||||
|
// packet. Kept faithful in shape (one read per field, AddrFromSlice for
|
||||||
|
// the addrs) so the CPU profile matches the production parseV4.
|
||||||
|
func parseV4InboundBaseline(pkt []byte, fp *firewall.Packet) bool {
|
||||||
|
if len(pkt) < 20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ihl := int(pkt[0]&0x0f) << 2
|
||||||
|
if ihl < 20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
|
||||||
|
fp.Fragment = (flagsfrags & 0x1FFF) != 0
|
||||||
|
fp.Protocol = pkt[9]
|
||||||
|
minLen := ihl
|
||||||
|
if !fp.Fragment {
|
||||||
|
if fp.Protocol == firewall.ProtoICMP {
|
||||||
|
minLen += 4 + 2
|
||||||
|
} else {
|
||||||
|
minLen += 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(pkt) < minLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(pkt[12:16])
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(pkt[16:20])
|
||||||
|
switch {
|
||||||
|
case fp.Fragment:
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
case fp.Protocol == firewall.ProtoICMP:
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6])
|
||||||
|
fp.LocalPort = 0
|
||||||
|
default:
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseV6InboundBaseline is the v6 analogue: replicates parseV6's
|
||||||
|
// extension-header walk so the split bench captures its true cost.
|
||||||
|
func parseV6InboundBaseline(pkt []byte, fp *firewall.Packet) bool {
|
||||||
|
dataLen := len(pkt)
|
||||||
|
if dataLen < 40 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(pkt[8:24])
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(pkt[24:40])
|
||||||
|
|
||||||
|
protoAt := 6
|
||||||
|
offset := 40
|
||||||
|
next := 0
|
||||||
|
for {
|
||||||
|
if protoAt >= dataLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
proto := pkt[protoAt]
|
||||||
|
switch proto {
|
||||||
|
case ipProtoESP, ipProtoNoNextHdr:
|
||||||
|
fp.Protocol = proto
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
fp.Fragment = false
|
||||||
|
return true
|
||||||
|
case ipProtoICMPv6:
|
||||||
|
if dataLen < offset+6 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fp.Protocol = proto
|
||||||
|
fp.LocalPort = 0
|
||||||
|
switch pkt[offset+1] {
|
||||||
|
case icmpv6TypeEchoRequest, icmpv6TypeEchoReply:
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(pkt[offset+4 : offset+6])
|
||||||
|
default:
|
||||||
|
fp.RemotePort = 0
|
||||||
|
}
|
||||||
|
fp.Fragment = false
|
||||||
|
return true
|
||||||
|
case ipProtoTCP, ipProtoUDP:
|
||||||
|
if dataLen < offset+4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fp.Protocol = proto
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
|
||||||
|
fp.Fragment = false
|
||||||
|
return true
|
||||||
|
case ipProtoIPv6Fragment:
|
||||||
|
if dataLen < offset+8 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7)
|
||||||
|
if fragmentOffset != 0 {
|
||||||
|
fp.Protocol = pkt[offset]
|
||||||
|
fp.Fragment = true
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
next = 8
|
||||||
|
case ipProtoAH:
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
next = int(pkt[offset+1]+2) << 2
|
||||||
|
default:
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
next = int(pkt[offset+1]+1) << 3
|
||||||
|
}
|
||||||
|
if next <= 0 {
|
||||||
|
next = 8
|
||||||
|
}
|
||||||
|
protoAt = offset
|
||||||
|
offset = offset + next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runRxSplit drives the split path: faithful inbound parse for the firewall
|
||||||
|
// side, then m.Commit re-parses to coalesce. v6 controls which baseline
|
||||||
|
// parser we run.
|
||||||
|
func runRxSplit(b *testing.B, pkts [][]byte, batchSize int, v6 bool) {
|
||||||
|
b.Helper()
|
||||||
|
m := NewMultiCoalescer(nopTunWriter{}, true, true)
|
||||||
|
var fp firewall.Packet
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(pkts[0])))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pkt := pkts[i%len(pkts)]
|
||||||
|
var ok bool
|
||||||
|
if v6 {
|
||||||
|
ok = parseV6InboundBaseline(pkt, &fp)
|
||||||
|
} else {
|
||||||
|
ok = parseV4InboundBaseline(pkt, &fp)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
b.Fatal("baseline parse failed")
|
||||||
|
}
|
||||||
|
if err := m.Commit(pkt); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if (i+1)%batchSize == 0 {
|
||||||
|
if err := m.Flush(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = m.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runRxUnified drives the unified path: ParseInbound walks once, filling
|
||||||
|
// the conntrack key + coalescer hint in parsed; CommitInbound dispatches
|
||||||
|
// without re-parsing.
|
||||||
|
func runRxUnified(b *testing.B, pkts [][]byte, batchSize int) {
|
||||||
|
b.Helper()
|
||||||
|
m := NewMultiCoalescer(nopTunWriter{}, true, true)
|
||||||
|
var parsed RxParsed
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(pkts[0])))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pkt := pkts[i%len(pkts)]
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := m.CommitInbound(pkt, &parsed); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if (i+1)%batchSize == 0 {
|
||||||
|
if err := m.Flush(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = m.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUDPv4Bulk returns N UDP packets on a single 5-tuple suitable for the
|
||||||
|
// UDP coalescer's append path.
|
||||||
|
func buildUDPv4Bulk(n, payloadLen int) [][]byte {
|
||||||
|
pkts := make([][]byte, n)
|
||||||
|
pay := make([]byte, payloadLen)
|
||||||
|
for i := range n {
|
||||||
|
pkts[i] = buildUDPv4(1000, 53, pay)
|
||||||
|
}
|
||||||
|
return pkts
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTCPv6Bulk(n, payloadLen int) [][]byte {
|
||||||
|
pkts := make([][]byte, n)
|
||||||
|
pay := make([]byte, payloadLen)
|
||||||
|
seq := uint32(1000)
|
||||||
|
for i := range n {
|
||||||
|
pkts[i] = buildTCPv6(0, seq, tcpAck, pay)
|
||||||
|
seq += uint32(payloadLen)
|
||||||
|
}
|
||||||
|
return pkts
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildICMPv4Bulk(n int) [][]byte {
|
||||||
|
pkts := make([][]byte, n)
|
||||||
|
for i := range pkts {
|
||||||
|
pkts[i] = buildICMPv4()
|
||||||
|
}
|
||||||
|
return pkts
|
||||||
|
}
|
||||||
|
|
||||||
|
// === TCPv4 ===
|
||||||
|
|
||||||
|
func BenchmarkRxSplitTCPv4(b *testing.B) {
|
||||||
|
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplit(b, pkts, tcpCoalesceMaxSegs, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedTCPv4(b *testing.B) {
|
||||||
|
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnified(b, pkts, tcpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === TCPv4 interleaved (4 flows) ===
|
||||||
|
|
||||||
|
func BenchmarkRxSplitTCPv4Interleaved4(b *testing.B) {
|
||||||
|
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplit(b, pkts, len(pkts), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedTCPv4Interleaved4(b *testing.B) {
|
||||||
|
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnified(b, pkts, len(pkts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// === UDPv4 ===
|
||||||
|
|
||||||
|
func BenchmarkRxSplitUDPv4(b *testing.B) {
|
||||||
|
pkts := buildUDPv4Bulk(udpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplit(b, pkts, udpCoalesceMaxSegs, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedUDPv4(b *testing.B) {
|
||||||
|
pkts := buildUDPv4Bulk(udpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnified(b, pkts, udpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === TCPv6 ===
|
||||||
|
|
||||||
|
func BenchmarkRxSplitTCPv6(b *testing.B) {
|
||||||
|
pkts := buildTCPv6Bulk(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplit(b, pkts, tcpCoalesceMaxSegs, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedTCPv6(b *testing.B) {
|
||||||
|
pkts := buildTCPv6Bulk(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnified(b, pkts, tcpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === ICMPv4 (passthrough) — measures the unified parser on the coalescer-
|
||||||
|
// rejected path, where both lenient and unified must still fill fp. ===
|
||||||
|
|
||||||
|
func BenchmarkRxSplitICMPv4(b *testing.B) {
|
||||||
|
pkts := buildICMPv4Bulk(64)
|
||||||
|
runRxSplit(b, pkts, 64, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedICMPv4(b *testing.B) {
|
||||||
|
pkts := buildICMPv4Bulk(64)
|
||||||
|
runRxUnified(b, pkts, 64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Firewall fast-path (conntrack-hit) — exercises the savings from the
|
||||||
|
// dense PacketKey: smaller hash key for the per-routine ConntrackCache,
|
||||||
|
// and skipping the AddrFrom4 calls that the old path needed to fill the
|
||||||
|
// netip.Addr-rich firewall.Packet up-front. ===
|
||||||
|
//
|
||||||
|
// The "split" baseline simulates the legacy path: parseV4InboundBaseline
|
||||||
|
// fills a netip.Addr-rich Packet, then we probe a localCache keyed on
|
||||||
|
// Packet. The "unified" path: ParseInbound fills only the dense PacketKey,
|
||||||
|
// and we probe a localCache keyed on PacketKey. Both paths follow with
|
||||||
|
// the coalescer Commit so the bench captures end-to-end RX-side cost.
|
||||||
|
|
||||||
|
// runRxSplitWithCache mirrors runRxSplit but runs the legacy-style
|
||||||
|
// firewall fast path (localCache keyed on firewall.Packet) on every
|
||||||
|
// packet so we can compare against the unified path.
|
||||||
|
func runRxSplitWithCache(b *testing.B, pkts [][]byte, batchSize int) {
|
||||||
|
b.Helper()
|
||||||
|
m := NewMultiCoalescer(nopTunWriter{}, true, true)
|
||||||
|
var fp firewall.Packet
|
||||||
|
|
||||||
|
// Pre-warm a per-packet cache keyed on the netip.Addr-rich Packet form.
|
||||||
|
cache := make(map[firewall.Packet]struct{}, len(pkts))
|
||||||
|
for _, pkt := range pkts {
|
||||||
|
var seedFp firewall.Packet
|
||||||
|
if !parseV4InboundBaseline(pkt, &seedFp) {
|
||||||
|
b.Fatal("seed parse failed")
|
||||||
|
}
|
||||||
|
cache[seedFp] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(pkts[0])))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pkt := pkts[i%len(pkts)]
|
||||||
|
if !parseV4InboundBaseline(pkt, &fp) {
|
||||||
|
b.Fatal("baseline parse failed")
|
||||||
|
}
|
||||||
|
if _, ok := cache[fp]; !ok {
|
||||||
|
b.Fatal("cache miss")
|
||||||
|
}
|
||||||
|
if err := m.Commit(pkt); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if (i+1)%batchSize == 0 {
|
||||||
|
if err := m.Flush(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = m.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// runRxUnifiedWithCache: unified path with a PacketKey-keyed localCache.
|
||||||
|
// Each iteration: ParseInbound → conntrack-cache hit → CommitInbound.
|
||||||
|
func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) {
|
||||||
|
b.Helper()
|
||||||
|
m := NewMultiCoalescer(nopTunWriter{}, true, true)
|
||||||
|
var parsed RxParsed
|
||||||
|
|
||||||
|
cache := make(firewall.ConntrackCache, len(pkts))
|
||||||
|
for _, pkt := range pkts {
|
||||||
|
var seed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &seed); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
cache[seed.Key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(pkts[0])))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pkt := pkts[i%len(pkts)]
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, ok := cache[parsed.Key]; !ok {
|
||||||
|
b.Fatal("cache miss")
|
||||||
|
}
|
||||||
|
if err := m.CommitInbound(pkt, &parsed); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if (i+1)%batchSize == 0 {
|
||||||
|
if err := m.Flush(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = m.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxSplitTCPv4WithCache(b *testing.B) {
|
||||||
|
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplitWithCache(b, pkts, tcpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedTCPv4WithCache(b *testing.B) {
|
||||||
|
pkts := buildTCPv4BulkFlow(tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnifiedWithCache(b, pkts, tcpCoalesceMaxSegs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxSplitInterleaved4WithCache(b *testing.B) {
|
||||||
|
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxSplitWithCache(b, pkts, len(pkts))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRxUnifiedInterleaved4WithCache(b *testing.B) {
|
||||||
|
pkts := buildTCPv4Interleaved(4, tcpCoalesceMaxSegs, 1200)
|
||||||
|
runRxUnifiedWithCache(b, pkts, len(pkts))
|
||||||
|
}
|
||||||
174
overlay/batch/inbound_test.go
Normal file
174
overlay/batch/inbound_test.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package batch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseInboundParity asserts that ParseInbound + Key.Hydrate produces
|
||||||
|
// the same firewall.Packet that the lenient baseline parsers (which
|
||||||
|
// mirror outside.go's parseV4/parseV6 with incoming=true) produce for
|
||||||
|
// every shape we care about. Catches drift between the unified
|
||||||
|
// parse-then-hydrate flow and the production newPacket behavior so
|
||||||
|
// swapping one for the other is observably safe.
|
||||||
|
func TestParseInboundParity(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
pkt []byte
|
||||||
|
v6 bool
|
||||||
|
}{
|
||||||
|
{"tcp_v4", buildTCPv4Ports(1234, 443, 1000, tcpAck, []byte("payload")), false},
|
||||||
|
{"tcp_v4_psh", buildTCPv4Ports(1234, 443, 2000, tcpAckPsh, make([]byte, 1200)), false},
|
||||||
|
{"udp_v4", buildUDPv4(40000, 53, []byte("dnsquery")), false},
|
||||||
|
{"icmp_v4", buildICMPv4(), false},
|
||||||
|
{"tcp_v6", buildTCPv6(0, 5000, tcpAck, make([]byte, 800)), true},
|
||||||
|
{"udp_v6", buildUDPv6(40001, 53, []byte("v6dns")), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var fpUnified, fpBaseline firewall.Packet
|
||||||
|
var parsed RxParsed
|
||||||
|
|
||||||
|
if err := ParseInbound(tc.pkt, &parsed); err != nil {
|
||||||
|
t.Fatalf("ParseInbound: %v", err)
|
||||||
|
}
|
||||||
|
parsed.Key.Hydrate(&fpUnified)
|
||||||
|
var ok bool
|
||||||
|
if tc.v6 {
|
||||||
|
ok = parseV6InboundBaseline(tc.pkt, &fpBaseline)
|
||||||
|
} else {
|
||||||
|
ok = parseV4InboundBaseline(tc.pkt, &fpBaseline)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("baseline parse failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fpUnified != fpBaseline {
|
||||||
|
t.Errorf("firewall.Packet mismatch:\n unified: %+v\n baseline: %+v", fpUnified, fpBaseline)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseInboundFlowKey checks that the coalescer hint the unified parser
|
||||||
|
// produces matches what parseTCPBase/parseUDP would produce on the same
|
||||||
|
// packet — same flowKey, ipHdrLen, payLen, etc. The hint is only valid
|
||||||
|
// when Kind is RxKindTCP/RxKindUDP.
|
||||||
|
func TestParseInboundFlowKey(t *testing.T) {
|
||||||
|
t.Run("tcp_v4", func(t *testing.T) {
|
||||||
|
pkt := buildTCPv4Ports(1234, 443, 5000, tcpAck, make([]byte, 800))
|
||||||
|
var parsed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if parsed.Kind != RxKindTCP {
|
||||||
|
t.Fatalf("kind=%v want TCP", parsed.Kind)
|
||||||
|
}
|
||||||
|
ref, ok := parseTCPBase(pkt)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("parseTCPBase failed")
|
||||||
|
}
|
||||||
|
if parsed.tcp != ref {
|
||||||
|
t.Errorf("parsedTCP mismatch:\n unified: %+v\n ref: %+v", parsed.tcp, ref)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("udp_v4", func(t *testing.T) {
|
||||||
|
pkt := buildUDPv4(40000, 53, []byte("dnsquery"))
|
||||||
|
var parsed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if parsed.Kind != RxKindUDP {
|
||||||
|
t.Fatalf("kind=%v want UDP", parsed.Kind)
|
||||||
|
}
|
||||||
|
ref, ok := parseUDP(pkt)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("parseUDP failed")
|
||||||
|
}
|
||||||
|
if parsed.udp != ref {
|
||||||
|
t.Errorf("parsedUDP mismatch:\n unified: %+v\n ref: %+v", parsed.udp, ref)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tcp_v6", func(t *testing.T) {
|
||||||
|
pkt := buildTCPv6(0, 9000, tcpAck, make([]byte, 800))
|
||||||
|
var parsed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if parsed.Kind != RxKindTCP {
|
||||||
|
t.Fatalf("kind=%v want TCP", parsed.Kind)
|
||||||
|
}
|
||||||
|
ref, ok := parseTCPBase(pkt)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("parseTCPBase failed")
|
||||||
|
}
|
||||||
|
if parsed.tcp != ref {
|
||||||
|
t.Errorf("parsedTCP mismatch:\n unified: %+v\n ref: %+v", parsed.tcp, ref)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseInboundICMPPassthrough confirms ICMP packets populate the
|
||||||
|
// conntrack key (including the ICMP identifier in RemotePort) but stay
|
||||||
|
// RxKindPassthrough so the batcher writes them verbatim. After Hydrate
|
||||||
|
// the firewall.Packet form should match what the legacy parseV4 produced.
|
||||||
|
func TestParseInboundICMPPassthrough(t *testing.T) {
|
||||||
|
pkt := buildICMPv4()
|
||||||
|
// Stamp a non-zero identifier into the ICMP header so we can check
|
||||||
|
// RemotePort gets it.
|
||||||
|
pkt[20] = 8 // type=echo
|
||||||
|
pkt[24] = 0xab
|
||||||
|
pkt[25] = 0xcd
|
||||||
|
|
||||||
|
var parsed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if parsed.Kind != RxKindPassthrough {
|
||||||
|
t.Errorf("kind=%v want Passthrough", parsed.Kind)
|
||||||
|
}
|
||||||
|
var fp firewall.Packet
|
||||||
|
parsed.Key.Hydrate(&fp)
|
||||||
|
if fp.Protocol != firewall.ProtoICMP {
|
||||||
|
t.Errorf("Protocol=%d want %d", fp.Protocol, firewall.ProtoICMP)
|
||||||
|
}
|
||||||
|
if fp.RemotePort != 0xabcd {
|
||||||
|
t.Errorf("RemotePort=0x%x want 0xabcd", fp.RemotePort)
|
||||||
|
}
|
||||||
|
if fp.LocalPort != 0 {
|
||||||
|
t.Errorf("LocalPort=%d want 0", fp.LocalPort)
|
||||||
|
}
|
||||||
|
wantRemote := netip.MustParseAddr("10.0.0.1")
|
||||||
|
wantLocal := netip.MustParseAddr("10.0.0.2")
|
||||||
|
if fp.RemoteAddr != wantRemote || fp.LocalAddr != wantLocal {
|
||||||
|
t.Errorf("addrs: remote=%v local=%v want %v/%v", fp.RemoteAddr, fp.LocalAddr, wantRemote, wantLocal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseInboundV4Fragment confirms a fragmented v4 packet fills the
|
||||||
|
// conntrack key with Fragment=true and falls into Passthrough on the
|
||||||
|
// coalescer side.
|
||||||
|
func TestParseInboundV4Fragment(t *testing.T) {
|
||||||
|
// Build a TCP packet then twiddle the IP flags to make it look like a
|
||||||
|
// non-first fragment (offset != 0).
|
||||||
|
pkt := buildTCPv4Ports(1234, 443, 1000, tcpAck, []byte("payload"))
|
||||||
|
// Set a non-zero fragment offset (bytes 6-7, low 13 bits).
|
||||||
|
pkt[6] = 0x00
|
||||||
|
pkt[7] = 0x10 // offset = 16 (in 8-byte units)
|
||||||
|
|
||||||
|
var parsed RxParsed
|
||||||
|
if err := ParseInbound(pkt, &parsed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !parsed.Key.Fragment {
|
||||||
|
t.Error("Fragment=false, want true")
|
||||||
|
}
|
||||||
|
if parsed.Kind != RxKindPassthrough {
|
||||||
|
t.Errorf("kind=%v want Passthrough", parsed.Kind)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user