connection-track ICMP and ICMPv6 traffic

This commit is contained in:
JackDoan
2026-01-14 12:36:55 -06:00
parent e8bb874e14
commit 39452b5eec
6 changed files with 211 additions and 39 deletions

View File

@@ -37,17 +37,18 @@ docker run --name host4 --device /dev/net/tun:/dev/net/tun --cap-add NET_ADMIN -
sleep 1 sleep 1
# grab tcpdump pcaps for debugging # grab tcpdump pcaps for debugging
docker exec lighthouse1 tcpdump -i nebula1 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap & docker exec lighthouse1 tcpdump -i tun0 -q -w - -U 2>logs/lighthouse1.inside.log >logs/lighthouse1.inside.pcap &
docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap & docker exec lighthouse1 tcpdump -i eth0 -q -w - -U 2>logs/lighthouse1.outside.log >logs/lighthouse1.outside.pcap &
docker exec host2 tcpdump -i nebula1 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap & docker exec host2 tcpdump -i tun0 -q -w - -U 2>logs/host2.inside.log >logs/host2.inside.pcap &
docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap & docker exec host2 tcpdump -i eth0 -q -w - -U 2>logs/host2.outside.log >logs/host2.outside.pcap &
docker exec host3 tcpdump -i nebula1 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap & docker exec host3 tcpdump -i tun0 -q -w - -U 2>logs/host3.inside.log >logs/host3.inside.pcap &
docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap & docker exec host3 tcpdump -i eth0 -q -w - -U 2>logs/host3.outside.log >logs/host3.outside.pcap &
docker exec host4 tcpdump -i nebula1 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap & docker exec host4 tcpdump -i tun0 -q -w - -U 2>logs/host4.inside.log >logs/host4.inside.pcap &
docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap & docker exec host4 tcpdump -i eth0 -q -w - -U 2>logs/host4.outside.log >logs/host4.outside.pcap &
docker exec host2 ncat -nklv 0.0.0.0 2000 & docker exec host2 ncat -nklv 0.0.0.0 2000 &
docker exec host3 ncat -nklv 0.0.0.0 2000 & docker exec host3 ncat -nklv 0.0.0.0 2000 &
docker exec host4 ncat -nkluv 0.0.0.0 4000 &
docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 & docker exec host2 ncat -e '/usr/bin/echo host2' -nkluv 0.0.0.0 3000 &
docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 & docker exec host3 ncat -e '/usr/bin/echo host3' -nkluv 0.0.0.0 3000 &
@@ -119,11 +120,11 @@ echo
echo " *** Testing conntrack" echo " *** Testing conntrack"
echo echo
set -x set -x
# host2 can ping host3 now that host3 pinged it first
docker exec host2 ping -c1 192.168.100.3 # host2 speaking to host4 on UDP 4000 should allow it to reply, when firewall rules would normally not permit this
# host4 can ping host2 once conntrack established docker exec host2 sh -c "/usr/bin/echo host2 | ncat -nuv 192.168.100.4 4000"
docker exec host2 ping -c1 192.168.100.4 docker exec host2 ncat -e '/usr/bin/echo helloagainfromhost2' -nkluv 0.0.0.0 4000 &
docker exec host4 ping -c1 192.168.100.2 docker exec host4 sh -c "/usr/bin/echo host4 | ncat -nuv 192.168.100.2 4000"
docker exec host4 sh -c 'kill 1' docker exec host4 sh -c 'kill 1'
docker exec host3 sh -c 'kill 1' docker exec host3 sh -c 'kill 1'

View File

@@ -480,7 +480,7 @@ func (f *Firewall) metrics(incoming bool) firewallMetrics {
} }
} }
// Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new // Destroy cleans up any known cyclical references so the object can be freed by GC. This should be called if a new
// firewall object is created // firewall object is created
func (f *Firewall) Destroy() { func (f *Firewall) Destroy() {
//TODO: clean references if/when needed //TODO: clean references if/when needed

View File

@@ -22,7 +22,10 @@ const (
type Packet struct { type Packet struct {
LocalAddr netip.Addr LocalAddr netip.Addr
RemoteAddr netip.Addr RemoteAddr netip.Addr
LocalPort uint16 // LocalPort is the destination port for incoming traffic, or the source port for outgoing. Zero for ICMP.
LocalPort uint16
// RemotePort is the source port for incoming traffic, or the destination port for outgoing.
// For ICMP, it's the "identifier". This is only used for connection tracking, actual firewall rules will not filter on ICMP identifier
RemotePort uint16 RemotePort uint16
Protocol uint8 Protocol uint8
Fragment bool Fragment bool
@@ -46,6 +49,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
proto = "tcp" proto = "tcp"
case ProtoICMP: case ProtoICMP:
proto = "icmp" proto = "icmp"
case ProtoICMPv6:
proto = "icmpv6"
case ProtoUDP: case ProtoUDP:
proto = "udp" proto = "udp"
default: default:

View File

@@ -735,6 +735,150 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
} }
func TestFirewall_ICMPPortBehavior(t *testing.T) {
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
myVpnNetworksTable := new(bart.Lite)
myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
network := netip.MustParsePrefix("1.2.3.4/24")
c := cert.CachedCertificate{
Certificate: &dummyCert{
name: "host1",
networks: []netip.Prefix{network},
groups: []string{"default-group"},
issuer: "signer-shasum",
},
InvertedGroups: map[string]struct{}{"default-group": {}},
}
h := HostInfo{
ConnectionState: &ConnectionState{
peerCert: &c,
},
vpnAddrs: []netip.Addr{network.Addr()},
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
cp := cert.NewCAPool()
templ := firewall.Packet{
LocalAddr: netip.MustParseAddr("1.2.3.4"),
RemoteAddr: netip.MustParseAddr("1.2.3.4"),
Protocol: firewall.ProtoICMP,
Fragment: false,
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
p := templ.Copy()
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
})
})
}
func TestFirewall_DropIPSpoofing(t *testing.T) { func TestFirewall_DropIPSpoofing(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
ob := &bytes.Buffer{} ob := &bytes.Buffer{}

View File

@@ -327,13 +327,24 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
proto := layers.IPProtocol(data[protoAt]) proto := layers.IPProtocol(data[protoAt])
switch proto { switch proto {
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: case layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
fp.Protocol = uint8(proto) fp.Protocol = uint8(proto)
fp.RemotePort = 0 fp.RemotePort = 0
fp.LocalPort = 0 fp.LocalPort = 0
fp.Fragment = false fp.Fragment = false
return nil return nil
case layers.IPProtocolICMPv6:
if dataLen < offset+6 {
return ErrIPv6PacketTooShort
}
fp.Protocol = uint8(proto)
//incoming vs outgoing doesn't matter for icmpv6
fp.RemotePort = binary.BigEndian.Uint16(data[offset+4 : offset+6]) //identifier
fp.LocalPort = 0 //code would be uint16(data[offset+1])
fp.Fragment = false
return nil
case layers.IPProtocolTCP, layers.IPProtocolUDP: case layers.IPProtocolTCP, layers.IPProtocolUDP:
if dataLen < offset+4 { if dataLen < offset+4 {
return ErrIPv6PacketTooShort return ErrIPv6PacketTooShort
@@ -423,34 +434,38 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Accounting for a variable header length, do we have enough data for our src/dst tuples? // Accounting for a variable header length, do we have enough data for our src/dst tuples?
minLen := ihl minLen := ihl
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP { if !fp.Fragment {
minLen += minFwPacketLen if fp.Protocol == firewall.ProtoICMP {
minLen += minFwPacketLen + 2
} else {
minLen += minFwPacketLen
}
} }
if len(data) < minLen { if len(data) < minLen {
return ErrIPv4InvalidHeaderLength return ErrIPv4InvalidHeaderLength
} }
// Firewall packets are locally oriented if incoming { // Firewall packets are locally oriented
if incoming {
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
fp.RemotePort = 0
fp.LocalPort = 0
} else {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else { } else {
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP { }
fp.RemotePort = 0
fp.LocalPort = 0 if fp.Fragment {
} else { fp.RemotePort = 0
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) fp.LocalPort = 0
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } else if fp.Protocol == firewall.ProtoICMP { //note that orientation doesn't matter on ICMP
} fp.RemotePort = binary.BigEndian.Uint16(data[ihl+4 : ihl+6]) //identifier
fp.LocalPort = 0 //code would be uint16(data[ihl+1])
} else if incoming {
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
} else {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2]) //src port
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) //dst port
} }
return nil return nil

View File

@@ -155,6 +155,7 @@ func Test_newPacket_v6(t *testing.T) {
// next layer, missing length byte // next layer, missing length byte
err = newPacket(buffer.Bytes()[:49], true, p) err = newPacket(buffer.Bytes()[:49], true, p)
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
err = nil
// A good ICMP packet // A good ICMP packet
ip = layers.IPv6{ ip = layers.IPv6{
@@ -165,20 +166,26 @@ func Test_newPacket_v6(t *testing.T) {
DstIP: net.IPv6linklocalallnodes, DstIP: net.IPv6linklocalallnodes,
} }
icmp := layers.ICMPv6{} icmp := layers.ICMPv6{
TypeCode: 0x55,
buffer.Clear() Checksum: 0x1234,
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
if err != nil {
panic(err)
} }
err = newPacket(buffer.Bytes(), true, p) buffer.Clear()
require.NoError(t, err) require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp))
require.Error(t, newPacket(buffer.Bytes(), true, p))
buffer.Clear()
echo := layers.ICMPv6Echo{
Identifier: 0xabcd,
SeqNumber: 1234,
}
require.NoError(t, gopacket.SerializeLayers(buffer, opt, &ip, &icmp, &echo))
require.NoError(t, newPacket(buffer.Bytes(), true, p))
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint16(0), p.RemotePort) assert.Equal(t, uint16(0xabcd), p.RemotePort)
assert.Equal(t, uint16(0), p.LocalPort) assert.Equal(t, uint16(0), p.LocalPort)
assert.False(t, p.Fragment) assert.False(t, p.Fragment)