diff --git a/firewall.go b/firewall.go index 948dadcc..45eb499d 100644 --- a/firewall.go +++ b/firewall.go @@ -437,8 +437,8 @@ func (f *Firewall) Drop(key firewall.PacketKey, fp *firewall.Packet, incoming bo // Conntrack miss → rule matching needs the rich Packet form. Hydrate // from the key if the caller passed a zero-valued fp (the inbound path - // after ParseInbound). Outbound callers fill fp via newPacket and skip - // this hop. + // after batch.ParsePacket). Outbound callers Hydrate themselves and + // skip this hop. if !fp.LocalAddr.IsValid() { key.Hydrate(fp) } diff --git a/firewall/packet.go b/firewall/packet.go index 02f0952d..c2e27128 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -50,10 +50,10 @@ type Packet struct { Fragment bool } -// Key derives a PacketKey from a populated Packet. Used by the outgoing -// path (inside.go) which still parses into a full Packet via newPacket -// before the firewall check; the inbound path skips this hop entirely by -// having its parser write straight into the PacketKey. +// Key derives a PacketKey from a populated Packet. Used by the few code +// paths that have a Packet but no Key in hand (e.g. tests). Both inbound +// and outbound production parsers write straight into a PacketKey via +// batch.ParsePacket, so this function is rarely on the hot path. func (fp *Packet) Key() PacketKey { k := PacketKey{ Protocol: fp.Protocol, diff --git a/inside.go b/inside.go index 3c537fc6..b56570e0 100644 --- a/inside.go +++ b/inside.go @@ -25,11 +25,11 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe // // pkt.Bytes is either one IP datagram (GSO zero) or a TSO/USO // superpacket. In both cases the L3+L4 headers at the start describe - // the same 5-tuple every segment will share, so a single newPacket / + // the same 5-tuple every segment will share, so a single parse + // firewall check covers the whole superpacket. packet := pkt.Bytes - key, err := newPacketKey(packet, false) - if err != nil { + var parsed batch.RxParsed + if err := batch.ParsePacket(packet, false, &parsed); err != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Error while validating outbound packet", "packet", packet, @@ -39,7 +39,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe return } - key.Hydrate(fwPacket) + parsed.Key.Hydrate(fwPacket) // Ignore local broadcast packets if f.dropLocalBroadcast { @@ -107,7 +107,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe return } - dropReason := f.firewall.Drop(key, fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) + dropReason := f.firewall.Drop(parsed.Key, fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache) if dropReason == nil { f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q) } else { @@ -394,16 +394,16 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { - key, err := newPacketKey(p, false) - if err != nil { + var parsed batch.RxParsed + if err := batch.ParsePacket(p, false, &parsed); err != nil { f.l.Warn("error while parsing outgoing packet for firewall check", "error", err) return } fp := &firewall.Packet{} - key.Hydrate(fp) + parsed.Key.Hydrate(fp) // check if packet is in outbound fw rules - dropReason := f.firewall.Drop(key, fp, false, hostinfo, f.pki.GetCAPool(), nil) + dropReason := f.firewall.Drop(parsed.Key, fp, false, hostinfo, f.pki.GetCAPool(), nil) if dropReason != nil { if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("dropping cached packet", diff --git a/outside.go b/outside.go index 58ddb9bd..725d7e64 100644 --- a/outside.go +++ b/outside.go @@ -2,24 +2,15 @@ package nebula import ( "context" - "encoding/binary" "errors" "log/slog" "net/netip" "time" - "github.com/google/gopacket/layers" - "golang.org/x/net/ipv6" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay/batch" "github.com/slackhq/nebula/udp" - "golang.org/x/net/ipv4" -) - -const ( - minFwPacketLen = 4 ) var ErrOutOfWindow = errors.New("out of window packet") @@ -318,181 +309,11 @@ var ( // parse logic. Callers that don't need the netip.Addr-rich form (e.g. // conntrack-only paths) should use newPacketKey directly. func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - key, err := newPacketKey(data, incoming) - if err != nil { + var parsed batch.RxParsed + if err := batch.ParsePacket(data, incoming, &parsed); err != nil { return err } - key.Hydrate(fp) - return nil -} - -// newPacketKey parses data into a dense firewall.PacketKey. Hot path: no -// netip.Addr construction, no unique.Handle interning. Caller decides -// whether to also Hydrate to a Packet (for rule matching) or pass the key -// straight to conntrack. -func newPacketKey(data []byte, incoming bool) (firewall.PacketKey, error) { - var k firewall.PacketKey - if len(data) < 1 { - return k, ErrPacketTooShort - } - switch int((data[0] >> 4) & 0x0f) { - case ipv4.Version: - return k, parseV4Key(data, incoming, &k) - case ipv6.Version: - k.IsV6 = true - return k, parseV6Key(data, incoming, &k) - } - return k, ErrUnknownIPVersion -} - -func parseV6Key(data []byte, incoming bool, k *firewall.PacketKey) error { - dataLen := len(data) - if dataLen < ipv6.HeaderLen { - return ErrIPv6PacketTooShort - } - - if incoming { - copy(k.RemoteAddr[:], data[8:24]) - copy(k.LocalAddr[:], data[24:40]) - } else { - copy(k.LocalAddr[:], data[8:24]) - copy(k.RemoteAddr[:], data[24:40]) - } - - protoAt := 6 - offset := ipv6.HeaderLen - next := 0 - for { - if protoAt >= dataLen { - break - } - proto := layers.IPProtocol(data[protoAt]) - - switch proto { - case layers.IPProtocolESP, layers.IPProtocolNoNextHeader: - k.Protocol = uint8(proto) - k.RemotePort = 0 - k.LocalPort = 0 - k.Fragment = false - return nil - - case layers.IPProtocolICMPv6: - if dataLen < offset+6 { - return ErrIPv6PacketTooShort - } - k.Protocol = uint8(proto) - k.LocalPort = 0 - switch data[offset+1] { - case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply: - k.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) - default: - k.RemotePort = 0 - } - k.Fragment = false - return nil - - case layers.IPProtocolTCP, layers.IPProtocolUDP: - if dataLen < offset+4 { - return ErrIPv6PacketTooShort - } - k.Protocol = uint8(proto) - if incoming { - k.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) - k.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) - } else { - k.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) - k.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) - } - k.Fragment = false - return nil - - case layers.IPProtocolIPv6Fragment: - if dataLen < offset+8 { - return ErrIPv6PacketTooShort - } - fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) - if fragmentOffset != 0 { - k.Protocol = data[offset] - k.Fragment = true - k.RemotePort = 0 - k.LocalPort = 0 - return nil - } - next = 8 - - case layers.IPProtocolAH: - if dataLen <= offset+1 { - break - } - next = int(data[offset+1]+2) << 2 - - default: - if dataLen <= offset+1 { - break - } - next = int(data[offset+1]+1) << 3 - } - - if next <= 0 { - next = 8 - } - protoAt = offset - offset = offset + next - } - - return ErrIPv6CouldNotFindPayload -} - -func parseV4Key(data []byte, incoming bool, k *firewall.PacketKey) error { - if len(data) < ipv4.HeaderLen { - return ErrIPv4PacketTooShort - } - ihl := int(data[0]&0x0f) << 2 - if ihl < ipv4.HeaderLen { - return ErrIPv4InvalidHeaderLength - } - - flagsfrags := binary.BigEndian.Uint16(data[6:8]) - k.Fragment = (flagsfrags & 0x1FFF) != 0 - k.Protocol = data[9] - - minLen := ihl - if !k.Fragment { - if k.Protocol == firewall.ProtoICMP { - minLen += minFwPacketLen + 2 - } else { - minLen += minFwPacketLen - } - } - if len(data) < minLen { - return ErrIPv4InvalidHeaderLength - } - - // Dense form: v4 in low 4 bytes, rest zero. Matches the coalescer's - // flowKey convention so the two stay byte-identical for the same flow. - if incoming { - copy(k.RemoteAddr[:4], data[12:16]) - copy(k.LocalAddr[:4], data[16:20]) - } else { - copy(k.LocalAddr[:4], data[12:16]) - copy(k.RemoteAddr[:4], data[16:20]) - } - - switch { - case k.Fragment: - k.RemotePort = 0 - k.LocalPort = 0 - case k.Protocol == firewall.ProtoICMP: - k.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) - k.LocalPort = 0 - case incoming: - k.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - k.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - default: - k.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) - k.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) - } - + parsed.Key.Hydrate(fp) return nil } @@ -571,7 +392,7 @@ func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, p // firewall.Drop's fast path uses Key alone and only hydrates fwPacket // from Key on the slow path. *fwPacket = firewall.Packet{} - err := batch.ParseInbound(out, parsedRx) + err := batch.ParsePacket(out, true, parsedRx) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, diff --git a/outside_test.go b/outside_test.go index 49eeec39..22ccf0c0 100644 --- a/outside_test.go +++ b/outside_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket/layers" "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/overlay/batch" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" @@ -21,13 +22,13 @@ func Test_newPacket(t *testing.T) { // length fails err := newPacket([]byte{}, true, p) - require.ErrorIs(t, err, ErrPacketTooShort) + require.ErrorIs(t, err, batch.ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) - require.ErrorIs(t, err, ErrIPv4PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) - require.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -40,15 +41,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, batch.ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - require.ErrorIs(t, err, ErrUnknownIPVersion) + require.ErrorIs(t, err, batch.ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, batch.ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -115,7 +116,7 @@ func Test_newPacket_v6(t *testing.T) { require.NoError(t, err) err = newPacket(buffer.Bytes(), true, p) - require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload) // A v6 packet with a hop-by-hop extension // ICMPv6 Payload (Echo Request) @@ -149,12 +150,12 @@ func Test_newPacket_v6(t *testing.T) { // A full IPv6 header and 1 byte in the first extension, but missing // the length byte. err = newPacket(buffer.Bytes()[:41], true, p) - require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload) // A full IPv6 header plus 1 full extension, but only 1 byte of the // next layer, missing length byte err = newPacket(buffer.Bytes()[:49], true, p) - require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload) err = nil // A good ICMP packet @@ -217,7 +218,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = 255 // 255 is a reserved protocol number err = newPacket(b, true, p) - require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload) // A good UDP packet ip = layers.IPv6{ @@ -264,7 +265,7 @@ func Test_newPacket_v6(t *testing.T) { // Too short UDP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - require.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort) // A good TCP packet b[6] = byte(layers.IPProtocolTCP) @@ -291,7 +292,7 @@ func Test_newPacket_v6(t *testing.T) { // Too short TCP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - require.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort) // A good UDP packet with an AH header ip = layers.IPv6{ @@ -336,12 +337,12 @@ func Test_newPacket_v6(t *testing.T) { // Ensure buffer bounds checking during processing err = newPacket(b[:41], true, p) - require.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort) // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) - require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, batch.ErrIPv6CouldNotFindPayload) } func Test_newPacket_ipv6Fragment(t *testing.T) { @@ -448,7 +449,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Too short of a fragment packet err = newPacket(secondFrag[:len(secondFrag)-10], false, p) - require.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, batch.ErrIPv6PacketTooShort) } func BenchmarkParseV6(b *testing.B) { diff --git a/overlay/batch/batch.go b/overlay/batch/batch.go index ac7102c0..57d94b25 100644 --- a/overlay/batch/batch.go +++ b/overlay/batch/batch.go @@ -7,9 +7,9 @@ type RxBatcher interface { Reserve(sz int) []byte // Commit borrows pkt. The caller must keep pkt valid until the next Flush. // Walks IP+L4 headers itself; prefer CommitInbound when the caller already - // has an RxParsed in hand from ParseInbound. + // has an RxParsed in hand from ParsePacket. Commit(pkt []byte) error - // CommitInbound is Commit with a hint produced by ParseInbound, so the + // CommitInbound is Commit with a hint produced by ParsePacket, so the // batcher can skip the IP+L4 re-parse. Borrowed slice contract is the // same as Commit. Implementations that don't coalesce may delegate to // Commit. diff --git a/overlay/batch/inbound.go b/overlay/batch/inbound.go index 67a528fa..6771f478 100644 --- a/overlay/batch/inbound.go +++ b/overlay/batch/inbound.go @@ -22,16 +22,16 @@ const ( icmpv6TypeEchoReply = 129 ) -// Inbound parse errors. Match outside.go's sentinel set so the unified -// parser can drop in as a replacement for newPacket without callers -// noticing a behavior change. +// Packet parse errors — the canonical sentinel set for IP+L4 parsing. +// Both inbound and outbound callers share this surface, so any code path +// that ends up at firewall.PacketKey reports drops with the same errors. var ( - ErrInboundPacketTooShort = errors.New("packet is too short") - ErrInboundUnknownIPVersion = errors.New("packet is an unknown ip version") - ErrInboundIPv4InvalidHdrLen = errors.New("invalid ipv4 header length") - ErrInboundIPv4TooShort = errors.New("ipv4 packet is too short") - ErrInboundIPv6TooShort = errors.New("ipv6 packet is too short") - ErrInboundIPv6NoPayload = errors.New("could not find payload in ipv6 packet") + ErrPacketTooShort = errors.New("packet is too short") + ErrUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") + ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") + ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") + ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") ) // RxKind discriminates how an inbound plaintext packet should be committed @@ -61,19 +61,27 @@ type RxParsed struct { udp parsedUDP } -// ParseInbound walks an inbound plaintext packet once and fills: -// - parsed.Key with the dense, Local/Remote-oriented conntrack key the -// firewall uses (replaces the netip.Addr-rich path through newPacket). -// - parsed.{tcp,udp} with the coalescer hint, when the shape is -// coalesce-eligible. +// ParsePacket walks an IP packet once and fills parsed.Key. When incoming +// is true and the L4 shape is coalesce-eligible, also fills parsed.tcp / +// parsed.udp so CommitInbound can dispatch into the coalescer without +// re-walking the headers. // -// Eligibility rules match the coalescer's own parseTCPBase/parseUDP: +// Direction selects the Key orientation: +// +// incoming=true → wire src → Key.RemoteAddr/Port, wire dst → Key.LocalAddr/Port +// incoming=false → wire src → Key.LocalAddr/Port, wire dst → Key.RemoteAddr/Port +// +// ICMP always lands the identifier in Key.RemotePort, regardless of direction. +// +// Eligibility rules for the coalescer hint match the coalescer's own +// parseTCPBase/parseUDP: // - IPv4 strict: IHL == 20, no fragmentation (MF or offset), proto TCP/UDP. // - IPv6 strict: NextHeader is directly TCP or UDP (no extension headers). // -// Returns the same set of errors newPacket returns for malformed input — -// callers can treat those as drop. -func ParseInbound(pkt []byte, parsed *RxParsed) error { +// The hint is only filled for incoming packets, since the outbound path +// does not feed an inbound coalescer. Outbound callers see Kind stay at +// RxKindPassthrough and parsed.tcp/udp stay zero. +func ParsePacket(pkt []byte, incoming bool, parsed *RxParsed) error { parsed.Kind = RxKindPassthrough // Reset Key in full: v4 only writes the low 4 bytes of each address // field, so without this a v6 call followed by a v4 reusing the same @@ -81,26 +89,27 @@ func ParseInbound(pkt []byte, parsed *RxParsed) error { // map equality for v4 flows. parsed.Key = firewall.PacketKey{} if len(pkt) < 1 { - return ErrInboundPacketTooShort + return ErrPacketTooShort } switch pkt[0] >> 4 { case 4: - return parseInboundV4(pkt, parsed) + return parsePacketV4(pkt, incoming, parsed) case 6: - return parseInboundV6(pkt, parsed) + return parsePacketV6(pkt, incoming, parsed) } - return ErrInboundUnknownIPVersion + return ErrUnknownIPVersion } -// parseInboundV4 mirrors parseV4(incoming=true) for the firewall side and -// also fills the coalescer hint when the shape is strict. -func parseInboundV4(pkt []byte, parsed *RxParsed) error { +// parsePacketV4 fills parsed.Key from an IPv4 packet. Direction selects +// Local/Remote orientation. When incoming and the shape is strict, also +// fills the coalescer hint. +func parsePacketV4(pkt []byte, incoming bool, parsed *RxParsed) error { if len(pkt) < 20 { - return ErrInboundIPv4TooShort + return ErrIPv4PacketTooShort } ihl := int(pkt[0]&0x0f) << 2 if ihl < 20 { - return ErrInboundIPv4InvalidHdrLen + return ErrIPv4InvalidHeaderLength } flagsfrags := binary.BigEndian.Uint16(pkt[6:8]) parsed.Key.Fragment = (flagsfrags & 0x1FFF) != 0 @@ -118,12 +127,16 @@ func parseInboundV4(pkt []byte, parsed *RxParsed) error { } } if len(pkt) < minLen { - return ErrInboundIPv4InvalidHdrLen + return ErrIPv4InvalidHeaderLength } - // Inbound orientation: wire src → Remote, wire dst → Local. - copy(parsed.Key.RemoteAddr[:4], pkt[12:16]) - copy(parsed.Key.LocalAddr[:4], pkt[16:20]) + if incoming { + copy(parsed.Key.RemoteAddr[:4], pkt[12:16]) + copy(parsed.Key.LocalAddr[:4], pkt[16:20]) + } else { + copy(parsed.Key.LocalAddr[:4], pkt[12:16]) + copy(parsed.Key.RemoteAddr[:4], pkt[16:20]) + } switch { case parsed.Key.Fragment: @@ -132,11 +145,18 @@ func parseInboundV4(pkt []byte, parsed *RxParsed) error { case parsed.Key.Protocol == firewall.ProtoICMP: parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+4 : ihl+6]) parsed.Key.LocalPort = 0 - default: + case incoming: parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) + default: + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[ihl : ihl+2]) + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]) } + // Coalescer hint is inbound-only: no inbound coalescer fires on outgoing. + if !incoming { + return nil + } // Coalescer-eligible? Strict shape: IHL==20, no MF/offset, TCP or UDP. if ihl != 20 || (flagsfrags&0x3FFF) != 0 { return nil @@ -208,28 +228,43 @@ func fillParsedUDPv4(pkt []byte, parsed *RxParsed) { parsed.Kind = RxKindUDP } -// parseInboundV6 mirrors parseV6(incoming=true). The coalescer-eligible -// fast path triggers only when NextHeader is directly TCP or UDP — any -// extension header chain falls into the lenient walk below. -func parseInboundV6(pkt []byte, parsed *RxParsed) error { +// parsePacketV6 fills parsed.Key from an IPv6 packet. Direction selects +// Local/Remote orientation. The coalescer hint fast path only triggers +// when NextHeader is directly TCP or UDP — any extension header chain +// falls into the lenient walk below, and the hint stays unfilled. +func parsePacketV6(pkt []byte, incoming bool, parsed *RxParsed) error { if len(pkt) < 40 { - return ErrInboundIPv6TooShort + return ErrIPv6PacketTooShort } parsed.Key.IsV6 = true - copy(parsed.Key.RemoteAddr[:], pkt[8:24]) - copy(parsed.Key.LocalAddr[:], pkt[24:40]) + if incoming { + copy(parsed.Key.RemoteAddr[:], pkt[8:24]) + copy(parsed.Key.LocalAddr[:], pkt[24:40]) + } else { + copy(parsed.Key.LocalAddr[:], pkt[8:24]) + copy(parsed.Key.RemoteAddr[:], pkt[24:40]) + } if proto := pkt[6]; proto == ipProtoTCP || proto == ipProtoUDP { // Strict v6: ports are at the IP header end. Always fill key; only // fill the coalescer hint if the L4 shape passes. if len(pkt) < 44 { - return ErrInboundIPv6TooShort + return ErrIPv6PacketTooShort } parsed.Key.Protocol = proto parsed.Key.Fragment = false - parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42]) - parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44]) + if incoming { + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[40:42]) + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[42:44]) + } else { + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[40:42]) + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[42:44]) + } + // Coalescer hint is inbound-only. + if !incoming { + return nil + } payloadLen := int(binary.BigEndian.Uint16(pkt[4:6])) if 40+payloadLen > len(pkt) { return nil @@ -245,8 +280,9 @@ func parseInboundV6(pkt []byte, parsed *RxParsed) error { return nil } - // Slow path: walk extension header chain just like parseV6 does. - return walkInboundV6Headers(pkt, parsed) + // Slow path: walk extension header chain. Coalescer hint never fires + // here, so direction only matters for L4 port orientation. + return walkV6Headers(pkt, incoming, parsed) } func fillParsedTCPv6(pkt []byte, parsed *RxParsed) { @@ -295,12 +331,13 @@ func fillParsedUDPv6(pkt []byte, parsed *RxParsed) { parsed.Kind = RxKindUDP } -// walkInboundV6Headers handles every IPv6 case parseV6 handles that isn't -// the strict "NextHeader == TCP/UDP" fast path: ESP, NoNextHeader, ICMPv6, -// fragment headers (first vs later), AH, generic extension headers. -// Coalescer eligibility is always RxKindPassthrough on this path (parsed -// already initialised that way). -func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error { +// walkV6Headers handles every IPv6 case the strict "NextHeader == TCP/UDP" +// fast path doesn't: ESP, NoNextHeader, ICMPv6, fragment headers (first vs +// later), AH, generic extension headers. Coalescer eligibility is always +// RxKindPassthrough on this path (parsed already initialised that way). +// Direction matters only for the L4 port orientation when the chain +// terminates at TCP/UDP. +func walkV6Headers(pkt []byte, incoming bool, parsed *RxParsed) error { dataLen := len(pkt) protoAt := 6 offset := 40 @@ -320,7 +357,7 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error { case ipProtoICMPv6: if dataLen < offset+6 { - return ErrInboundIPv6TooShort + return ErrIPv6PacketTooShort } parsed.Key.Protocol = proto parsed.Key.LocalPort = 0 @@ -338,17 +375,22 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error { // strict-eligible fast path above already handled the no-extension // case; here we only fill firewall ports and stay passthrough. if dataLen < offset+4 { - return ErrInboundIPv6TooShort + return ErrIPv6PacketTooShort } parsed.Key.Protocol = proto - parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2]) - parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4]) + if incoming { + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset : offset+2]) + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4]) + } else { + parsed.Key.LocalPort = binary.BigEndian.Uint16(pkt[offset : offset+2]) + parsed.Key.RemotePort = binary.BigEndian.Uint16(pkt[offset+2 : offset+4]) + } parsed.Key.Fragment = false return nil case ipProtoIPv6Fragment: if dataLen < offset+8 { - return ErrInboundIPv6TooShort + return ErrIPv6PacketTooShort } fragmentOffset := binary.BigEndian.Uint16(pkt[offset+2:offset+4]) &^ uint16(0x7) if fragmentOffset != 0 { @@ -380,7 +422,7 @@ func walkInboundV6Headers(pkt []byte, parsed *RxParsed) error { protoAt = offset offset = offset + next } - return ErrInboundIPv6NoPayload + return ErrIPv6CouldNotFindPayload } // CommitInbound dispatches pkt to the appropriate lane using parsed.Kind, diff --git a/overlay/batch/inbound_bench_test.go b/overlay/batch/inbound_bench_test.go index dbbd479d..eaf263d8 100644 --- a/overlay/batch/inbound_bench_test.go +++ b/overlay/batch/inbound_bench_test.go @@ -176,7 +176,7 @@ func runRxUnified(b *testing.B, pkts [][]byte, batchSize int) { b.ResetTimer() for i := 0; i < b.N; i++ { pkt := pkts[i%len(pkts)] - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { b.Fatal(err) } if err := m.CommitInbound(pkt, &parsed); err != nil { @@ -344,7 +344,7 @@ func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) { cache := make(firewall.ConntrackCache, len(pkts)) for _, pkt := range pkts { var seed RxParsed - if err := ParseInbound(pkt, &seed); err != nil { + if err := ParsePacket(pkt, true, &seed); err != nil { b.Fatal(err) } cache[seed.Key] = struct{}{} @@ -355,7 +355,7 @@ func runRxUnifiedWithCache(b *testing.B, pkts [][]byte, batchSize int) { b.ResetTimer() for i := 0; i < b.N; i++ { pkt := pkts[i%len(pkts)] - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { b.Fatal(err) } if _, ok := cache[parsed.Key]; !ok { diff --git a/overlay/batch/inbound_test.go b/overlay/batch/inbound_test.go index a918a4e4..3ecb2e60 100644 --- a/overlay/batch/inbound_test.go +++ b/overlay/batch/inbound_test.go @@ -32,8 +32,8 @@ func TestParseInboundParity(t *testing.T) { var fpUnified, fpBaseline firewall.Packet var parsed RxParsed - if err := ParseInbound(tc.pkt, &parsed); err != nil { - t.Fatalf("ParseInbound: %v", err) + if err := ParsePacket(tc.pkt, true, &parsed); err != nil { + t.Fatalf("ParsePacket: %v", err) } parsed.Key.Hydrate(&fpUnified) var ok bool @@ -61,7 +61,7 @@ func TestParseInboundFlowKey(t *testing.T) { t.Run("tcp_v4", func(t *testing.T) { pkt := buildTCPv4Ports(1234, 443, 5000, tcpAck, make([]byte, 800)) var parsed RxParsed - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { t.Fatal(err) } if parsed.Kind != RxKindTCP { @@ -79,7 +79,7 @@ func TestParseInboundFlowKey(t *testing.T) { t.Run("udp_v4", func(t *testing.T) { pkt := buildUDPv4(40000, 53, []byte("dnsquery")) var parsed RxParsed - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { t.Fatal(err) } if parsed.Kind != RxKindUDP { @@ -97,7 +97,7 @@ func TestParseInboundFlowKey(t *testing.T) { t.Run("tcp_v6", func(t *testing.T) { pkt := buildTCPv6(0, 9000, tcpAck, make([]byte, 800)) var parsed RxParsed - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { t.Fatal(err) } if parsed.Kind != RxKindTCP { @@ -126,7 +126,7 @@ func TestParseInboundICMPPassthrough(t *testing.T) { pkt[25] = 0xcd var parsed RxParsed - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { t.Fatal(err) } if parsed.Kind != RxKindPassthrough { @@ -162,7 +162,7 @@ func TestParseInboundV4Fragment(t *testing.T) { pkt[7] = 0x10 // offset = 16 (in 8-byte units) var parsed RxParsed - if err := ParseInbound(pkt, &parsed); err != nil { + if err := ParsePacket(pkt, true, &parsed); err != nil { t.Fatal(err) } if !parsed.Key.Fragment {