This commit is contained in:
JackDoan
2026-05-07 12:00:41 -05:00
parent 5bdf645b0b
commit 01b31360df
5 changed files with 112 additions and 113 deletions

View File

@@ -23,7 +23,7 @@ func newFixedTicker(t *testing.T, l *slog.Logger, cacheLen int) *ConntrackCacheT
cache: make(ConntrackCache, cacheLen),
}
for i := 0; i < cacheLen; i++ {
c.cache[PacketKey{TransportTuple: TransportTuple{LocalPort: uint16(i) + 1}}] = struct{}{}
c.cache[PacketKey{LocalPort: uint16(i) + 1}] = struct{}{}
}
c.cacheTick.Store(1) // cacheV starts at 0, so Get() takes the reset path
return c

View File

@@ -19,19 +19,6 @@ const (
PortFragment = -1 // Special value for matching `port: fragment`
)
// TransportTuple is the dense 5-tuple shape shared between the coalescer's
// flowKey-equivalent and the firewall's PacketKey. Stored in Local/Remote
// orientation so a flow's incoming and outgoing packets share the same
// tuple identity. v4 addresses occupy the low 4 bytes of LocalAddr/
// RemoteAddr (NOT v4-mapped form) so v4 vs v6 tuples never collide.
type TransportTuple struct {
LocalAddr [16]byte
RemoteAddr [16]byte
LocalPort uint16
RemotePort uint16
IsV6 bool
}
// PacketKey is the firewall's conntrack and ConntrackCache map key — the
// dense form of the 5-tuple plus the protocol and fragment flag the
// firewall actually discriminates flows on. Kept separate from Packet so
@@ -42,7 +29,11 @@ type TransportTuple struct {
// Superset of the coalescer's flowKey shape (same 5-tuple, just in
// Local/Remote orientation rather than wire src/dst).
type PacketKey struct {
TransportTuple
LocalAddr [16]byte
RemoteAddr [16]byte
LocalPort uint16
RemotePort uint16
IsV6 bool
Protocol uint8
Fragment bool
}
@@ -104,6 +95,16 @@ func (k *PacketKey) Hydrate(fp *Packet) {
}
}
func (k *PacketKey) GetRemoteAddr() netip.Addr {
if k.IsV6 {
return netip.AddrFrom16(k.RemoteAddr)
} else {
var v4 [4]byte
copy(v4[:], k.RemoteAddr[:4])
return netip.AddrFrom4(v4)
}
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalAddr: fp.LocalAddr,

View File

@@ -28,7 +28,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
// the same 5-tuple every segment will share, so a single newPacket /
// firewall check covers the whole superpacket.
packet := pkt.Bytes
err := newPacket(packet, false, fwPacket)
key, err := newPacketKey(packet, false)
if err != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Error while validating outbound packet",
@@ -39,6 +39,8 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
return
}
key.Hydrate(fwPacket)
// Ignore local broadcast packets
if f.dropLocalBroadcast {
if f.myBroadcastAddrsTable.Contains(fwPacket.RemoteAddr) {
@@ -105,7 +107,7 @@ func (f *Interface) consumeInsidePacket(pkt tio.Packet, fwPacket *firewall.Packe
return
}
dropReason := f.firewall.Drop(fwPacket.Key(), fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(key, fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendInsideMessage(hostinfo, pkt, nb, sendBatch, rejectBuf, q)
} else {
@@ -392,15 +394,16 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
fp := &firewall.Packet{}
err := newPacket(p, false, fp)
key, err := newPacketKey(p, false)
if err != nil {
f.l.Warn("error while parsing outgoing packet for firewall check", "error", err)
return
}
fp := &firewall.Packet{}
key.Hydrate(fp)
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(fp.Key(), fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(key, fp, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("dropping cached packet",

View File

@@ -313,37 +313,54 @@ var (
)
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
// newPacket parses data into a fully-hydrated firewall.Packet — kept as a
// thin wrapper around newPacketKey + Hydrate so there's one source of
// parse logic. Callers that don't need the netip.Addr-rich form (e.g.
// conntrack-only paths) should use newPacketKey directly.
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
key, err := newPacketKey(data, incoming)
if err != nil {
return err
}
key.Hydrate(fp)
return nil
}
// newPacketKey parses data into a dense firewall.PacketKey. Hot path: no
// netip.Addr construction, no unique.Handle interning. Caller decides
// whether to also Hydrate to a Packet (for rule matching) or pass the key
// straight to conntrack.
func newPacketKey(data []byte, incoming bool) (firewall.PacketKey, error) {
var k firewall.PacketKey
if len(data) < 1 {
return ErrPacketTooShort
return k, ErrPacketTooShort
}
version := int((data[0] >> 4) & 0x0f)
switch version {
switch int((data[0] >> 4) & 0x0f) {
case ipv4.Version:
return parseV4(data, incoming, fp)
return k, parseV4Key(data, incoming, &k)
case ipv6.Version:
return parseV6(data, incoming, fp)
k.IsV6 = true
return k, parseV6Key(data, incoming, &k)
}
return ErrUnknownIPVersion
return k, ErrUnknownIPVersion
}
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
func parseV6Key(data []byte, incoming bool, k *firewall.PacketKey) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return ErrIPv6PacketTooShort
}
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
copy(k.RemoteAddr[:], data[8:24])
copy(k.LocalAddr[:], data[24:40])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
copy(k.LocalAddr[:], data[8:24])
copy(k.RemoteAddr[:], data[24:40])
}
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
protoAt := 6
offset := ipv6.HeaderLen
next := 0
for {
if protoAt >= dataLen {
@@ -353,87 +370,72 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
switch proto {
case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
k.Protocol = uint8(proto)
k.RemotePort = 0
k.LocalPort = 0
k.Fragment = false
return nil
case layers.IPProtocolICMPv6:
if dataLen < offset+6 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
fp.LocalPort = 0 //incoming vs outgoing doesn't matter for icmpv6
icmptype := data[offset+1]
switch icmptype {
k.Protocol = uint8(proto)
k.LocalPort = 0
switch data[offset+1] {
case layers.ICMPv6TypeEchoRequest, layers.ICMPv6TypeEchoReply:
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
k.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6])
default:
fp.RemotePort = 0
k.RemotePort = 0
}
fp.Fragment = false
k.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
k.Protocol = uint8(proto)
if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
k.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
k.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
k.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
k.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
}
fp.Fragment = false
k.Fragment = false
return nil
case layers.IPProtocolIPv6Fragment:
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
if dataLen < offset+8 {
return ErrIPv6PacketTooShort
}
// Check if this is the first fragment
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7)
if fragmentOffset != 0 {
// Non-first fragment, use what we have now and stop processing
fp.Protocol = data[offset]
fp.Fragment = true
fp.RemotePort = 0
fp.LocalPort = 0
k.Protocol = data[offset]
k.Fragment = true
k.RemotePort = 0
k.LocalPort = 0
return nil
}
// The next loop should be the transport layer since we are the first fragment
next = 8 // Fragment headers are always 8 bytes
next = 8
case layers.IPProtocolAH:
// Auth headers, used by IPSec, have a different meaning for header length
if dataLen <= offset+1 {
break
}
next = int(data[offset+1]+2) << 2
default:
// Normal ipv6 header length processing
if dataLen <= offset+1 {
break
}
next = int(data[offset+1]+1) << 3
}
if next <= 0 {
// Safety check, each ipv6 header has to be at least 8 bytes
next = 8
}
protoAt = offset
offset = offset + next
}
@@ -441,61 +443,54 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
return ErrIPv6CouldNotFindPayload
}
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
func parseV4Key(data []byte, incoming bool, k *firewall.PacketKey) error {
if len(data) < ipv4.HeaderLen {
return ErrIPv4PacketTooShort
}
// Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2
// Well-formed ip header length?
if ihl < ipv4.HeaderLen {
return ErrIPv4InvalidHeaderLength
}
// Check if this is the second or further fragment of a fragmented packet.
flagsfrags := binary.BigEndian.Uint16(data[6:8])
fp.Fragment = (flagsfrags & 0x1FFF) != 0
k.Fragment = (flagsfrags & 0x1FFF) != 0
k.Protocol = data[9]
// Firewall handles protocol checks
fp.Protocol = data[9]
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl
if !fp.Fragment {
if fp.Protocol == firewall.ProtoICMP {
if !k.Fragment {
if k.Protocol == firewall.ProtoICMP {
minLen += minFwPacketLen + 2
} else {
minLen += minFwPacketLen
}
}
if len(data) < minLen {
return ErrIPv4InvalidHeaderLength
}
if incoming { // Firewall packets are locally oriented
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
// Dense form: v4 in low 4 bytes, rest zero. Matches the coalescer's
// flowKey convention so the two stay byte-identical for the same flow.
if incoming {
copy(k.RemoteAddr[:4], data[12:16])
copy(k.LocalAddr[:4], data[16:20])
} else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
copy(k.LocalAddr[:4], data[12:16])
copy(k.RemoteAddr[:4], data[16:20])
}
if fp.Fragment {
fp.RemotePort = 0
fp.LocalPort = 0
} else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
} else if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
switch {
case k.Fragment:
k.RemotePort = 0
k.LocalPort = 0
case k.Protocol == firewall.ProtoICMP:
k.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6])
k.LocalPort = 0
case incoming:
k.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
k.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
default:
k.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
k.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
return nil

View File

@@ -529,7 +529,7 @@ func BenchmarkParseV6(b *testing.B) {
b.Run("Normal", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(normalPacket, true, fp); err != nil {
if err = newPacket(normalPacket, true, fp); err != nil {
b.Fatal(err)
}
}
@@ -537,7 +537,7 @@ func BenchmarkParseV6(b *testing.B) {
b.Run("FirstFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(firstFrag, true, fp); err != nil {
if err = newPacket(firstFrag, true, fp); err != nil {
b.Fatal(err)
}
}
@@ -545,7 +545,7 @@ func BenchmarkParseV6(b *testing.B) {
b.Run("SecondFragment", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(secondFrag, true, fp); err != nil {
if err = newPacket(secondFrag, true, fp); err != nil {
b.Fatal(err)
}
}
@@ -590,7 +590,7 @@ func BenchmarkParseV6(b *testing.B) {
b.Run("200 HopByHop headers", func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err = parseV6(evilBytes, false, fp); err != nil {
if err = newPacket(evilBytes, false, fp); err != nil {
b.Fatal(err)
}
}