mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
refactor
This commit is contained in:
@@ -20,13 +20,13 @@ func Test_newPacket(t *testing.T) {
|
||||
p := &firewall.Packet{}
|
||||
|
||||
// length fails
|
||||
err := newPacket([]byte{}, true, p)
|
||||
err := firewall.NewPacket([]byte{}, true, p)
|
||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||
|
||||
err = newPacket([]byte{0x40}, true, p)
|
||||
err = firewall.NewPacket([]byte{0x40}, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||
|
||||
err = newPacket([]byte{0x60}, true, p)
|
||||
err = firewall.NewPacket([]byte{0x60}, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// length fail with ip options
|
||||
@@ -39,15 +39,15 @@ func Test_newPacket(t *testing.T) {
|
||||
}
|
||||
|
||||
b, _ := h.Marshal()
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||
|
||||
// not an ipv4 packet
|
||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
err = firewall.NewPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||
|
||||
// invalid ihl
|
||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
err = firewall.NewPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||
|
||||
// account for variable ip header length - incoming
|
||||
@@ -62,7 +62,7 @@ func Test_newPacket(t *testing.T) {
|
||||
|
||||
b, _ = h.Marshal()
|
||||
b = append(b, []byte{0, 3, 0, 4}...)
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
@@ -84,7 +84,7 @@ func Test_newPacket(t *testing.T) {
|
||||
|
||||
b, _ = h.Marshal()
|
||||
b = append(b, []byte{0, 5, 0, 6}...)
|
||||
err = newPacket(b, false, p)
|
||||
err = firewall.NewPacket(b, false, p)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(2), p.Protocol)
|
||||
@@ -114,7 +114,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = newPacket(buffer.Bytes(), true, p)
|
||||
err = firewall.NewPacket(buffer.Bytes(), true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A v6 packet with a hop-by-hop extension
|
||||
@@ -148,12 +148,12 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
|
||||
// A full IPv6 header and 1 byte in the first extension, but missing
|
||||
// the length byte.
|
||||
err = newPacket(buffer.Bytes()[:41], true, p)
|
||||
err = firewall.NewPacket(buffer.Bytes()[:41], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
||||
// next layer, missing length byte
|
||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
||||
err = firewall.NewPacket(buffer.Bytes()[:49], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A good ICMP packet
|
||||
@@ -173,7 +173,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = newPacket(buffer.Bytes(), true, p)
|
||||
err = firewall.NewPacket(buffer.Bytes(), true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -185,7 +185,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
// A good ESP packet
|
||||
b := buffer.Bytes()
|
||||
b[6] = byte(layers.IPProtocolESP)
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -197,7 +197,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
// A good None packet
|
||||
b = buffer.Bytes()
|
||||
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -209,7 +209,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
// An unknown protocol packet
|
||||
b = buffer.Bytes()
|
||||
b[6] = 255 // 255 is a reserved protocol number
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
|
||||
// A good UDP packet
|
||||
@@ -236,7 +236,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
b = buffer.Bytes()
|
||||
|
||||
// incoming
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -246,7 +246,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// outgoing
|
||||
err = newPacket(b, false, p)
|
||||
err = firewall.NewPacket(b, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
@@ -256,14 +256,14 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Too short UDP packet
|
||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
err = firewall.NewPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// A good TCP packet
|
||||
b[6] = byte(layers.IPProtocolTCP)
|
||||
|
||||
// incoming
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -273,7 +273,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// outgoing
|
||||
err = newPacket(b, false, p)
|
||||
err = firewall.NewPacket(b, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
@@ -283,7 +283,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Too short TCP packet
|
||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
err = firewall.NewPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// A good UDP packet with an AH header
|
||||
@@ -318,7 +318,7 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
b = append(b, ahb...)
|
||||
b = append(b, udpHeader...)
|
||||
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
@@ -328,12 +328,12 @@ func Test_newPacket_v6(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Ensure buffer bounds checking during processing
|
||||
err = newPacket(b[:41], true, p)
|
||||
err = firewall.NewPacket(b[:41], true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
|
||||
// Invalid AH header
|
||||
b = buffer.Bytes()
|
||||
err = newPacket(b, true, p)
|
||||
err = firewall.NewPacket(b, true, p)
|
||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||
}
|
||||
|
||||
@@ -381,7 +381,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
// Test first fragment incoming
|
||||
err = newPacket(firstFrag, true, p)
|
||||
err = firewall.NewPacket(firstFrag, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
@@ -391,7 +391,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
assert.False(t, p.Fragment)
|
||||
|
||||
// Test first fragment outgoing
|
||||
err = newPacket(firstFrag, false, p)
|
||||
err = firewall.NewPacket(firstFrag, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
@@ -420,7 +420,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||
|
||||
// Test second fragment incoming
|
||||
err = newPacket(secondFrag, true, p)
|
||||
err = firewall.NewPacket(secondFrag, true, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||
@@ -430,7 +430,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
assert.True(t, p.Fragment)
|
||||
|
||||
// Test second fragment outgoing
|
||||
err = newPacket(secondFrag, false, p)
|
||||
err = firewall.NewPacket(secondFrag, false, p)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||
@@ -440,7 +440,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||
assert.True(t, p.Fragment)
|
||||
|
||||
// Too short of a fragment packet
|
||||
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||
err = firewall.NewPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user