This commit is contained in:
JackDoan
2026-05-07 13:18:41 -05:00
parent 01b31360df
commit 400cbc26a1
9 changed files with 145 additions and 281 deletions

View File

@@ -437,8 +437,8 @@ func (f *Firewall) Drop(key firewall.PacketKey, fp *firewall.Packet, incoming bo
// Conntrack miss → rule matching needs the rich Packet form. Hydrate // Conntrack miss → rule matching needs the rich Packet form. Hydrate
// from the key if the caller passed a zero-valued fp (the inbound path // from the key if the caller passed a zero-valued fp (the inbound path
// after ParseInbound). Outbound callers fill fp via newPacket and skip // after batch.ParsePacket). Outbound callers Hydrate themselves and
// this hop. // skip this hop.
if !fp.LocalAddr.IsValid() { if !fp.LocalAddr.IsValid() {
key.Hydrate(fp) key.Hydrate(fp)
} }

View File

@@ -50,10 +50,10 @@ type Packet struct {
Fragment bool Fragment bool
} }
// Key derives a PacketKey from a populated Packet. Used by the outgoing // Key derives a PacketKey from a populated Packet. Used by the few code
// path (inside.go) which still parses into a full Packet via newPacket // paths that have a Packet but no Key in hand (e.g. tests). Both inbound
// before the firewall check; the inbound path skips this hop entirely by // and outbound production parsers write straight into a PacketKey via
// having its parser write straight into the PacketKey. // batch.ParsePacket, so this function is rarely on the hot path.
func (fp *Packet) Key() PacketKey { func (fp *Packet) Key() PacketKey {
k := PacketKey{ k := PacketKey{
Protocol: fp.Protocol, Protocol: fp.Protocol,

View File

@@ -25,11 +25,11 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
// //
// pkt.Bytes is either one IP datagram (GSO zero) or a TSO/USO // pkt.Bytes is either one IP datagram (GSO zero) or a TSO/USO
// superpacket. In both cases the L3+L4 headers at the start describe // superpacket. In both cases the L3+L4 headers at the start describe
// the same 5-tuple every segment will share, so a single newPacket / // the same 5-tuple every segment will share, so a single parse +
// firewall check covers the whole superpacket. // firewall check covers the whole superpacket.
packet := pkt.Bytes packet := pkt.Bytes
key, err := newPacketKey(packet, false) var parsed batch.RxParsed
if err != nil { if err := batch.ParsePacket(packet, false, &parsed); err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet", f.l.Debug("Error while validating outbound packet",
"packet", packet, "packet", packet,
@@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
return return
} }
key.Hydrate(fwPacket) parsed.Key.Hydrate(fwPacket)
// Ignore local broadcast packets // Ignore local broadcast packets
if f.dropLocalBroadcast { if f.dropLocalBroadcast {
@@ -107,7 +107,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
return return
} }
dropReason := f.firewall.Drop(key, fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) dropReason := f.firewall.Drop(parsed.Key, fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil { if dropReason == nil {
f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q) f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
} else { } else {
@@ -394,16 +394,16 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
} }
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
key, err := newPacketKey(p, false) var parsed batch.RxParsed
if err != nil { if err := batch.ParsePacket(p, false, &parsed); err != nil {
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return return
} }
fp := &firewall.Packet{} fp := &firewall.Packet{}
key.Hydrate(fp) parsed.Key.Hydrate(fp)
// check if packet is in outbound fw rules // check if packet is in outbound fw rules
dropReason := f.firewall.Drop(key, fp, false, hostinfo, f.pki.GetCAPool(), nil) dropReason := f.firewall.Drop(parsed.Key, fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil { if dropReason != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) { if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping cached packet", f.l.Debug("dropping cached packet",

View File

@@ -2,24 +2,15 @@ package nebula
import ( import (
"context" "context"
"encoding/binary"
"errors" "errors"
"log/slog" "log/slog"
"net/netip" "net/netip"
"time" "time"
"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/overlay/batch" "github.com/slackhq/nebula/overlay/batch"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"golang.org/x/net/ipv4"
)
const (
minFwPacketLen = 4
) )
var ErrOutOfWindow = errors.New("out of window packet") var ErrOutOfWindow = errors.New("out of window packet")
@@ -318,181 +309,11 @@ var (
// parse logic. Callers that don't need the netip.Addr-rich form (e.g. // parse logic. Callers that don't need the netip.Addr-rich form (e.g.
// conntrack-only paths) should use newPacketKey directly. // conntrack-only paths) should use newPacketKey directly.
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
key, err := newPacketKey(data, incoming) var parsed batch.RxParsed
if err != nil { if err := batch.ParsePacket(data, incoming, &parsed); err != nil {
return err return err
} }
key.Hydrate(fp) parsed.Key.Hydrate(fp)
return nil
}
// newPacketKey parses data into a dense firewall.PacketKey. Hot path: no
// netip.Addr construction, no unique.Handle interning. Caller decides
// whether to also Hydrate to a Packet (for rule matching) or pass the key
// straight to conntrack.
func newPacketKey(data []byte, incoming bool) (firewall.PacketKey, error) {
var k firewall.PacketKey
if len(data) < 1 {
return k, ErrPacketTooShort
}
switch int((data[0] >> 4) & 0x0f) {
case ipv4.Version:
return k, parseV4Key(data, incoming, &k)
case ipv6.Version:
k.IsV6 = true
return k, parseV6Key(data, incoming, &k)
}
return k, ErrUnknownIPVersion
}
func parseV6Key(data []byte, incoming bool, k *firewall.PacketKey) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return ErrIPv6PacketTooShort
}
if incoming {
copy(k.RemoteAddr[:], data[8:24])
copy(k.LocalAddr[:], data[24:40])
} else {
copy(k.LocalAddr[:], data[8:24])
copy(k.RemoteAddr[:], data[24:40])
}
protoAt := 6
offset := ipv6.HeaderLen
next := 0
for {
if protoAt >= dataLen {
break
}
proto := layers.IPProtocol(data[protoAt])
switch proto {
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
k.Protocol = uint8(proto)
k.RemotePort = 0
k.LocalPort = 0
k.Fragment = false
return nil
case layers.IPProtocolICMPv6:
if dataLen < offset+6 {
return ErrIPv6PacketTooShort
}
k.Protocol = uint8(proto)
k.LocalPort = 0
switch data[offset+1] {
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
k.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6])
default:
k.RemotePort = 0
}
k.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
}
k.Protocol = uint8(proto)
if incoming {
k.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
k.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
k.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
k.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
k.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
if dataLen < offset+8 {
return ErrIPv6PacketTooShort
}
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7)
if fragmentOffset != 0 {
k.Protocol = data[offset]
k.Fragment = true
k.RemotePort = 0
k.LocalPort = 0
return nil
}
next = 8
case layers.IPProtocolAH:
if dataLen <= offset+1 {
break
}
next = int(data[offset+1]+2) << 2
default:
if dataLen <= offset+1 {
break
}
next = int(data[offset+1]+1) << 3
}
if next <= 0 {
next = 8
}
protoAt = offset
offset = offset + next
}
return ErrIPv6CouldNotFindPayload
}
func parseV4Key(data []byte, incoming bool, k *firewall.PacketKey) error {
if len(data) < ipv4.HeaderLen {
return ErrIPv4PacketTooShort
}
ihl := int(data[0]&0x0f) << 2
if ihl < ipv4.HeaderLen {
return ErrIPv4InvalidHeaderLength
}
flagsfrags := binary.BigEndian.Uint16(data[6:8])
k.Fragment = (flagsfrags & 0x1FFF) != 0
k.Protocol = data[9]
minLen := ihl
if !k.Fragment {
if k.Protocol == firewall.ProtoICMP {
minLen += minFwPacketLen + 2
} else {
minLen += minFwPacketLen
}
}
if len(data) < minLen {
return ErrIPv4InvalidHeaderLength
}
// Dense form: v4 in low 4 bytes, rest zero. Matches the coalescer's
// flowKey convention so the two stay byte-identical for the same flow.
if incoming {
copy(k.RemoteAddr[:4], data[12:16])
copy(k.LocalAddr[:4], data[16:20])
} else {
copy(k.LocalAddr[:4], data[12:16])
copy(k.RemoteAddr[:4], data[16:20])
}
switch {
case k.Fragment:
k.RemotePort = 0
k.LocalPort = 0
case k.Protocol == firewall.ProtoICMP:
k.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6])
k.LocalPort = 0
case incoming:
k.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
k.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
default:
k.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
k.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
return nil return nil
} }
@@ -571,7 +392,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p
// firewall.Drop's fast path uses Key alone and only hydrates fwPacket // firewall.Drop's fast path uses Key alone and only hydrates fwPacket
// from Key on the slow path. // from Key on the slow path.
*fwPacket = firewall.Packet{} *fwPacket = firewall.Packet{}
err := batch.ParseInbound(out, parsedRx) err := batch.ParsePacket(out, true, parsedRx)
if err != nil { if err != nil {
hostinfo.logger(f.l).Warn("Error while validating inbound packet", hostinfo.logger(f.l).Warn("Error while validating inbound packet",
"error", err, "error", err,

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/overlay/batch"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@@ -21,13 +22,13 @@ func Test_newPacket(t *testing.T) {
// length fails // length fails
err := newPacket([]byte{}, true, p) err := newPacket([]byte{}, true, p)
require.ErrorIs(t, err, ErrPacketTooShort) require.ErrorIs(t, err, batch.ErrPacketTooShort)
err = newPacket([]byte{0x40}, true, p) err = newPacket([]byte{0x40}, true, p)
require.ErrorIs(t, err, ErrIPv4PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv4PacketTooShort)
err = newPacket([]byte{0x60}, true, p) err = newPacket([]byte{0x60}, true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort)
// length fail with ip options // length fail with ip options
h := ipv4.Header{ h := ipv4.Header{
@@ -40,15 +41,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal() b, _ := h.Marshal()
err = newPacket(b, true, p) err = newPacket(b, true, p)
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, batch.ErrIPv4InvalidHeaderLength)
// not an ipv4 packet // not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
require.ErrorIs(t, err, ErrUnknownIPVersion) require.ErrorIs(t, err, batch.ErrUnknownIPVersion)
// invalid ihl // invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, batch.ErrIPv4InvalidHeaderLength)
// account for variable ip header length - incoming // account for variable ip header length - incoming
h = ipv4.Header{ h = ipv4.Header{
@@ -115,7 +116,7 @@ func Test_newPacket_v6(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = newPacket(buffer.Bytes(), true, p) err = newPacket(buffer.Bytes(), true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload)
// A v6 packet with a hop-by-hop extension // A v6 packet with a hop-by-hop extension
// ICMPv6 Payload (Echo Request) // ICMPv6 Payload (Echo Request)
@@ -149,12 +150,12 @@ func Test_newPacket_v6(t *testing.T) {
// A full IPv6 header and 1 byte in the first extension, but missing // A full IPv6 header and 1 byte in the first extension, but missing
// the length byte. // the length byte.
err = newPacket(buffer.Bytes()[:41], true, p) err = newPacket(buffer.Bytes()[:41], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload)
// A full IPv6 header plus 1 full extension, but only 1 byte of the // A full IPv6 header plus 1 full extension, but only 1 byte of the
// next layer, missing length byte // next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p) err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload)
err = nil err = nil
// A good ICMP packet // A good ICMP packet
@@ -217,7 +218,7 @@ func Test_newPacket_v6(t *testing.T) {
b = buffer.Bytes() b = buffer.Bytes()
b[6] = 255 // 255 is a reserved protocol number b[6] = 255 // 255 is a reserved protocol number
err = newPacket(b, true, p) err = newPacket(b, true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload)
// A good UDP packet // A good UDP packet
ip = layers.IPv6{ ip = layers.IPv6{
@@ -264,7 +265,7 @@ func Test_newPacket_v6(t *testing.T) {
// Too short UDP packet // Too short UDP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
require.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort)
// A good TCP packet // A good TCP packet
b[6] = byte(layers.IPProtocolTCP) b[6] = byte(layers.IPProtocolTCP)
@@ -291,7 +292,7 @@ func Test_newPacket_v6(t *testing.T) {
// Too short TCP packet // Too short TCP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
require.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort)
// A good UDP packet with an AH header // A good UDP packet with an AH header
ip = layers.IPv6{ ip = layers.IPv6{
@@ -336,12 +337,12 @@ func Test_newPacket_v6(t *testing.T) {
// Ensure buffer bounds checking during processing // Ensure buffer bounds checking during processing
err = newPacket(b[:41], true, p) err = newPacket(b[:41], true, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort)
// Invalid AH header // Invalid AH header
b = buffer.Bytes() b = buffer.Bytes()
err = newPacket(b, true, p) err = newPacket(b, true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload)
} }
func Test_newPacket_ipv6Fragment(t *testing.T) { func Test_newPacket_ipv6Fragment(t *testing.T) {
@@ -448,7 +449,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Too short of a fragment packet // Too short of a fragment packet
err = newPacket(secondFrag[:len(secondFrag)-10], false, p) err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
require.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort)
} }
func BenchmarkParseV6(b *testing.B) { func BenchmarkParseV6(b *testing.B) {

View File

@@ -7,9 +7,9 @@ type RxBatcher interface {
Reserve(sz int) []byte Reserve(sz int) []byte
// Commit borrows pkt. The caller must keep pkt valid until the next Flush. // Commit borrows pkt. The caller must keep pkt valid until the next Flush.
// Walks IP+L4 headers itself; prefer CommitInbound when the caller already // Walks IP+L4 headers itself; prefer CommitInbound when the caller already
// has an RxParsed in hand from ParseInbound. // has an RxParsed in hand from ParsePacket.
Commit(pkt []byte) error Commit(pkt []byte) error
// CommitInbound is Commit with a hint produced by ParseInbound, so the // CommitInbound is Commit with a hint produced by ParsePacket, so the
// batcher can skip the IP+L4 re-parse. Borrowed slice contract is the // batcher can skip the IP+L4 re-parse. Borrowed slice contract is the
// same as Commit. Implementations that don't coalesce may delegate to // same as Commit. Implementations that don't coalesce may delegate to
// Commit. // Commit.

View File

@@ -22,16 +22,16 @@ const (
icmpv6TypeEchoReply = 129 icmpv6TypeEchoReply = 129
) )
// Inbound parse errors. Match outside.go's sentinel set so the unified // Packet parse errors — the canonical sentinel set for IP+L4 parsing.
// parser can drop in as a replacement for newPacket without callers // Both inbound and outbound callers share this surface, so any code path
// noticing a behavior change. // that ends up at firewall.PacketKey reports drops with the same errors.
var ( var (
ErrInboundPacketTooShort = errors.New("packet is too short") ErrPacketTooShort = errors.New("packet is too short")
ErrInboundUnknownIPVersion = errors.New("packet is an unknown ip version") ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
ErrInboundIPv4InvalidHdrLen = errors.New("invalid ipv4 header length") ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
ErrInboundIPv4TooShort = errors.New("ipv4 packet is too short") ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
ErrInboundIPv6TooShort = errors.New("ipv6 packet is too short") ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
ErrInboundIPv6NoPayload = errors.New("could not find payload in ipv6 packet") ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
) )
// RxKind discriminates how an inbound plaintext packet should be committed // RxKind discriminates how an inbound plaintext packet should be committed
@@ -61,19 +61,27 @@ type RxParsed struct {
udp parsedUDP udp parsedUDP
} }
// ParseInbound walks an inbound plaintext packet once and fills: // ParsePacket walks an IP packet once and fills parsed.Key. When incoming
// - parsed.Key with the dense, Local/Remote-oriented conntrack key the // is true and the L4 shape is coalesce-eligible, also fills parsed.tcp /
// firewall uses (replaces the netip.Addr-rich path through newPacket). // parsed.udp so CommitInbound can dispatch into the coalescer without
// - parsed.{tcp,udp} with the coalescer hint, when the shape is // re-walking the headers.
// coalesce-eligible.
// //
// Eligibility rules match the coalescer's own parseTCPBase/parseUDP: // Direction selects the Key orientation:
//
// incoming=true → wire src → Key.RemoteAddr/Port, wire dst → Key.LocalAddr/Port
// incoming=false → wire src → Key.LocalAddr/Port, wire dst → Key.RemoteAddr/Port
//
// ICMP always lands the identifier in Key.RemotePort, regardless of direction.
//
// Eligibility rules for the coalescer hint match the coalescer's own
// parseTCPBase/parseUDP:
// - IPv4 strict: IHL == 20, no fragmentation (MF or offset), proto TCP/UDP. // - IPv4 strict: IHL == 20, no fragmentation (MF or offset), proto TCP/UDP.
// - IPv6 strict: NextHeader is directly TCP or UDP (no extension headers). // - IPv6 strict: NextHeader is directly TCP or UDP (no extension headers).
// //
// Returns the same set of errors newPacket returns for malformed input — // The hint is only filled for incoming packets, since the outbound path
// callers can treat those as drop. // does not feed an inbound coalescer. Outbound callers see Kind stay at
func ParseInbound(pkt []byte, parsed *RxParsed) error { // RxKindPassthrough and parsed.tcp/udp stay zero.
func ParsePacket(pkt []byte, incoming bool, parsed *RxParsed) error {
parsed.Kind = RxKindPassthrough parsed.Kind = RxKindPassthrough
// Reset Key in full: v4 only writes the low 4 bytes of each address // Reset Key in full: v4 only writes the low 4 bytes of each address
// field, so without this a v6 call followed by a v4 reusing the same // field, so without this a v6 call followed by a v4 reusing the same
@@ -81,26 +89,27 @@ func ParseInbound(pkt []byte, parsed *RxParsed) error {
// map equality for v4 flows. // map equality for v4 flows.
parsed.Key = firewall.PacketKey{} parsed.Key = firewall.PacketKey{}
if len(pkt) < 1 { if len(pkt) < 1 {
return ErrInboundPacketTooShort return ErrPacketTooShort
} }
switch pkt[0] >> 4 { switch pkt[0] >> 4 {
case 4: case 4:
return parseInboundV4(pkt, parsed) return parsePacketV4(pkt, incoming, parsed)
case 6: case 6:
return parseInboundV6(pkt, parsed) return parsePacketV6(pkt, incoming, parsed)
} }
return ErrInboundUnknownIPVersion return ErrUnknownIPVersion
} }
// parseInboundV4 mirrors parseV4(incoming=true) for the firewall side and // parsePacketV4 fills parsed.Key from an IPv4 packet. Direction selects
// also fills the coalescer hint when the shape is strict. // Local/Remote orientation. When incoming and the shape is strict, also
func parseInboundV4(pkt []byte, parsed *RxParsed) error { // fills the coalescer hint.
func parsePacketV4(pkt []byte, incoming bool, parsed *RxParsed) error {
if len(pkt) < 20 { if len(pkt) < 20 {
return ErrInboundIPv4TooShort return ErrIPv4PacketTooShort
} }
ihl := int(pkt[0]&0x0f) << 2 ihl := int(pkt[0]&0x0f) << 2
if ihl < 20 { if ihl < 20 {
return ErrInboundIPv4InvalidHdrLen return ErrIPv4InvalidHeaderLength
} }
flagsfrags := binary.BigEndian.Uint16(pkt[6:8]) flagsfrags := binary.BigEndian.Uint16(pkt[6:8])
parsed.Key.Fragment = (flagsfrags & 0x1FFF) != 0 parsed.Key.Fragment = (flagsfrags & 0x1FFF) != 0
@@ -118,12 +127,16 @@ func parseInboundV4(pkt []byte, parsed *RxParsed) error {
} }
} }
if len(pkt) < minLen { if len(pkt) < minLen {
return ErrInboundIPv4InvalidHdrLen return ErrIPv4InvalidHeaderLength
} }
// Inbound orientation: wire src → Remote, wire dst → Local. if incoming {
copy(parsed.Key.RemoteAddr[:4], pkt[12:16]) copy(parsed.Key.RemoteAddr[:4], pkt[12:16])
copy(parsed.Key.LocalAddr[:4], pkt[16:20]) copy(parsed.Key.LocalAddr[:4], pkt[16:20])
} else {
copy(parsed.Key.LocalAddr[:4], pkt[12:16])
copy(parsed.Key.RemoteAddr[:4], pkt[16:20])
}
switch { switch {
case parsed.Key.Fragment: case parsed.Key.Fragment:
@@ -132,11 +145,18 @@ func parseInboundV4(pkt []byte, parsed *RxParsed) error {
case parsed.Key.Protocol == firewall.ProtoICMP: case parsed.Key.Protocol == firewall.ProtoICMP:
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6])
parsed.Key.LocalPort = 0 parsed.Key.LocalPort = 0
default: case incoming:
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
default:
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2])
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4])
} }
// Coalescer hint is inbound-only: no inbound coalescer fires on outgoing.
if !incoming {
return nil
}
// Coalescer-eligible? Strict shape: IHL==20, no MF/offset, TCP or UDP. // Coalescer-eligible? Strict shape: IHL==20, no MF/offset, TCP or UDP.
if ihl != 20 || (flagsfrags&0x3FFF) != 0 { if ihl != 20 || (flagsfrags&0x3FFF) != 0 {
return nil return nil
@@ -208,28 +228,43 @@ func fillParsedUDPv4(pkt []byte, parsed *RxParsed) {
parsed.Kind = RxKindUDP parsed.Kind = RxKindUDP
} }
// parseInboundV6 mirrors parseV6(incoming=true). The coalescer-eligible // parsePacketV6 fills parsed.Key from an IPv6 packet. Direction selects
// fast path triggers only when NextHeader is directly TCP or UDP — any // Local/Remote orientation. The coalescer hint fast path only triggers
// extension header chain falls into the lenient walk below. // when NextHeader is directly TCP or UDP — any extension header chain
func parseInboundV6(pkt []byte, parsed *RxParsed) error { // falls into the lenient walk below, and the hint stays unfilled.
func parsePacketV6(pkt []byte, incoming bool, parsed *RxParsed) error {
if len(pkt) < 40 { if len(pkt) < 40 {
return ErrInboundIPv6TooShort return ErrIPv6PacketTooShort
} }
parsed.Key.IsV6 = true parsed.Key.IsV6 = true
if incoming {
copy(parsed.Key.RemoteAddr[:], pkt[8:24]) copy(parsed.Key.RemoteAddr[:], pkt[8:24])
copy(parsed.Key.LocalAddr[:], pkt[24:40]) copy(parsed.Key.LocalAddr[:], pkt[24:40])
} else {
copy(parsed.Key.LocalAddr[:], pkt[8:24])
copy(parsed.Key.RemoteAddr[:], pkt[24:40])
}
if proto := pkt[6]; proto == ipProtoTCP || proto == ipProtoUDP { if proto := pkt[6]; proto == ipProtoTCP || proto == ipProtoUDP {
// Strict v6: ports are at the IP header end. Always fill key; only // Strict v6: ports are at the IP header end. Always fill key; only
// fill the coalescer hint if the L4 shape passes. // fill the coalescer hint if the L4 shape passes.
if len(pkt) < 44 { if len(pkt) < 44 {
return ErrInboundIPv6TooShort return ErrIPv6PacketTooShort
} }
parsed.Key.Protocol = proto parsed.Key.Protocol = proto
parsed.Key.Fragment = false parsed.Key.Fragment = false
if incoming {
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42]) parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42])
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44]) parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44])
} else {
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[40:42])
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[42:44])
}
// Coalescer hint is inbound-only.
if !incoming {
return nil
}
payloadLen := int(binary.BigEndian.Uint16(pkt[4:6])) payloadLen := int(binary.BigEndian.Uint16(pkt[4:6]))
if 40+payloadLen > len(pkt) { if 40+payloadLen > len(pkt) {
return nil return nil
@@ -245,8 +280,9 @@ func parseInboundV6(pkt []byte, parsed *RxParsed) error {
return nil return nil
} }
// Slow path: walk extension header chain just like parseV6 does. // Slow path: walk extension header chain. Coalescer hint never fires
return walkInboundV6Headers(pkt, parsed) // here, so direction only matters for L4 port orientation.
return walkV6Headers(pkt, incoming, parsed)
} }
func fillParsedTCPv6(pkt []byte, parsed *RxParsed) { func fillParsedTCPv6(pkt []byte, parsed *RxParsed) {
@@ -295,12 +331,13 @@ func fillParsedUDPv6(pkt []byte, parsed *RxParsed) {
parsed.Kind = RxKindUDP parsed.Kind = RxKindUDP
} }
// walkInboundV6Headers handles every IPv6 case parseV6 handles that isn't // walkV6Headers handles every IPv6 case the strict "NextHeader == TCP/UDP"
// the strict "NextHeader == TCP/UDP" fast path: ESP, NoNextHeader, ICMPv6, // fast path doesn't: ESP, NoNextHeader, ICMPv6, fragment headers (first vs
// fragment headers (first vs later), AH, generic extension headers. // later), AH, generic extension headers. Coalescer eligibility is always
// Coalescer eligibility is always RxKindPassthrough on this path (parsed // RxKindPassthrough on this path (parsed already initialised that way).
// already initialised that way). // Direction matters only for the L4 port orientation when the chain
func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error { // terminates at TCP/UDP.
func walkV6Headers(pkt []byte, incoming bool, parsed *RxParsed) error {
dataLen := len(pkt) dataLen := len(pkt)
protoAt := 6 protoAt := 6
offset := 40 offset := 40
@@ -320,7 +357,7 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error {
case ipProtoICMPv6: case ipProtoICMPv6:
if dataLen < offset+6 { if dataLen < offset+6 {
return ErrInboundIPv6TooShort return ErrIPv6PacketTooShort
} }
parsed.Key.Protocol = proto parsed.Key.Protocol = proto
parsed.Key.LocalPort = 0 parsed.Key.LocalPort = 0
@@ -338,17 +375,22 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error {
// strict-eligible fast path above already handled the no-extension // strict-eligible fast path above already handled the no-extension
// case; here we only fill firewall ports and stay passthrough. // case; here we only fill firewall ports and stay passthrough.
if dataLen < offset+4 { if dataLen < offset+4 {
return ErrInboundIPv6TooShort return ErrIPv6PacketTooShort
} }
parsed.Key.Protocol = proto parsed.Key.Protocol = proto
if incoming {
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2]) parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2])
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4]) parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
} else {
parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset : offset+2])
parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4])
}
parsed.Key.Fragment = false parsed.Key.Fragment = false
return nil return nil
case ipProtoIPv6Fragment: case ipProtoIPv6Fragment:
if dataLen < offset+8 { if dataLen < offset+8 {
return ErrInboundIPv6TooShort return ErrIPv6PacketTooShort
} }
fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7) fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7)
if fragmentOffset != 0 { if fragmentOffset != 0 {
@@ -380,7 +422,7 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error {
protoAt = offset protoAt = offset
offset = offset + next offset = offset + next
} }
return ErrInboundIPv6NoPayload return ErrIPv6CouldNotFindPayload
} }
// CommitInbound dispatches pkt to the appropriate lane using parsed.Kind, // CommitInbound dispatches pkt to the appropriate lane using parsed.Kind,

View File

@@ -176,7 +176,7 @@ func runRxUnified(b *testing.B, pkts [][]byte, batchSize int) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pkt := pkts[i%len(pkts)] pkt := pkts[i%len(pkts)]
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
b.Fatal(err) b.Fatal(err)
} }
if err := m.CommitInbound(pkt, &parsed); err != nil { if err := m.CommitInbound(pkt, &parsed); err != nil {
@@ -344,7 +344,7 @@ func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) {
cache := make(firewall.ConntrackCache, len(pkts)) cache := make(firewall.ConntrackCache, len(pkts))
for _, pkt := range pkts { for _, pkt := range pkts {
var seed RxParsed var seed RxParsed
if err := ParseInbound(pkt, &seed); err != nil { if err := ParsePacket(pkt, true, &seed); err != nil {
b.Fatal(err) b.Fatal(err)
} }
cache[seed.Key] = struct{}{} cache[seed.Key] = struct{}{}
@@ -355,7 +355,7 @@ func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pkt := pkts[i%len(pkts)] pkt := pkts[i%len(pkts)]
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
b.Fatal(err) b.Fatal(err)
} }
if _, ok := cache[parsed.Key]; !ok { if _, ok := cache[parsed.Key]; !ok {

View File

@@ -32,8 +32,8 @@ func TestParseInboundParity(t *testing.T) {
var fpUnified, fpBaseline firewall.Packet var fpUnified, fpBaseline firewall.Packet
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(tc.pkt, &parsed); err != nil { if err := ParsePacket(tc.pkt, true, &parsed); err != nil {
t.Fatalf("ParseInbound: %v", err) t.Fatalf("ParsePacket: %v", err)
} }
parsed.Key.Hydrate(&fpUnified) parsed.Key.Hydrate(&fpUnified)
var ok bool var ok bool
@@ -61,7 +61,7 @@ func TestParseInboundFlowKey(t *testing.T) {
t.Run("tcp_v4", func(t *testing.T) { t.Run("tcp_v4", func(t *testing.T) {
pkt := buildTCPv4Ports(1234, 443, 5000, tcpAck, make([]byte, 800)) pkt := buildTCPv4Ports(1234, 443, 5000, tcpAck, make([]byte, 800))
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if parsed.Kind != RxKindTCP { if parsed.Kind != RxKindTCP {
@@ -79,7 +79,7 @@ func TestParseInboundFlowKey(t *testing.T) {
t.Run("udp_v4", func(t *testing.T) { t.Run("udp_v4", func(t *testing.T) {
pkt := buildUDPv4(40000, 53, []byte("dnsquery")) pkt := buildUDPv4(40000, 53, []byte("dnsquery"))
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if parsed.Kind != RxKindUDP { if parsed.Kind != RxKindUDP {
@@ -97,7 +97,7 @@ func TestParseInboundFlowKey(t *testing.T) {
t.Run("tcp_v6", func(t *testing.T) { t.Run("tcp_v6", func(t *testing.T) {
pkt := buildTCPv6(0, 9000, tcpAck, make([]byte, 800)) pkt := buildTCPv6(0, 9000, tcpAck, make([]byte, 800))
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if parsed.Kind != RxKindTCP { if parsed.Kind != RxKindTCP {
@@ -126,7 +126,7 @@ func TestParseInboundICMPPassthrough(t *testing.T) {
pkt[25] = 0xcd pkt[25] = 0xcd
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if parsed.Kind != RxKindPassthrough { if parsed.Kind != RxKindPassthrough {
@@ -162,7 +162,7 @@ func TestParseInboundV4Fragment(t *testing.T) {
pkt[7] = 0x10 // offset = 16 (in 8-byte units) pkt[7] = 0x10 // offset = 16 (in 8-byte units)
var parsed RxParsed var parsed RxParsed
if err := ParseInbound(pkt, &parsed); err != nil { if err := ParsePacket(pkt, true, &parsed); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !parsed.Key.Fragment { if !parsed.Key.Fragment {