mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-11 16:33:58 +01:00
More correct ipv6 header parsing (#1323)
This commit is contained in:
parent
e4daed3563
commit
fbff6a1487
105
outside.go
105
outside.go
@ -3,7 +3,6 @@ package nebula
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -271,10 +270,19 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPacketTooShort = errors.New("packet is too short")
|
||||||
|
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
||||||
|
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
|
||||||
|
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
|
||||||
|
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
|
||||||
|
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
|
||||||
|
)
|
||||||
|
|
||||||
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
||||||
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
if len(data) < 1 {
|
if len(data) < 1 {
|
||||||
return errors.New("packet too short")
|
return ErrPacketTooShort
|
||||||
}
|
}
|
||||||
|
|
||||||
version := int((data[0] >> 4) & 0x0f)
|
version := int((data[0] >> 4) & 0x0f)
|
||||||
@ -284,13 +292,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
case ipv6.Version:
|
case ipv6.Version:
|
||||||
return parseV6(data, incoming, fp)
|
return parseV6(data, incoming, fp)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("packet is an unknown ip version: %v", version)
|
return ErrUnknownIPVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
dataLen := len(data)
|
dataLen := len(data)
|
||||||
if dataLen < ipv6.HeaderLen {
|
if dataLen < ipv6.HeaderLen {
|
||||||
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
|
return ErrIPv6PacketTooShort
|
||||||
}
|
}
|
||||||
|
|
||||||
if incoming {
|
if incoming {
|
||||||
@ -301,11 +309,10 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: CERT-V2 whats a reasonable number of extension headers to attempt to parse?
|
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
|
||||||
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
|
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||||
protoAt := 6
|
next := 0
|
||||||
offset := 40
|
for {
|
||||||
for i := 0; i < 24; i++ {
|
|
||||||
if dataLen < offset {
|
if dataLen < offset {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -313,32 +320,18 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
proto := layers.IPProtocol(data[protoAt])
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
//fmt.Println(proto, protoAt)
|
//fmt.Println(proto, protoAt)
|
||||||
switch proto {
|
switch proto {
|
||||||
case layers.IPProtocolICMPv6:
|
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
fp.RemotePort = 0
|
fp.RemotePort = 0
|
||||||
fp.LocalPort = 0
|
fp.LocalPort = 0
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case layers.IPProtocolTCP:
|
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
||||||
if dataLen < offset+4 {
|
if dataLen < offset+4 {
|
||||||
return fmt.Errorf("ipv6 packet was too small")
|
return ErrIPv6PacketTooShort
|
||||||
}
|
}
|
||||||
fp.Protocol = uint8(proto)
|
|
||||||
if incoming {
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
|
||||||
} else {
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
|
||||||
}
|
|
||||||
fp.Fragment = false
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case layers.IPProtocolUDP:
|
|
||||||
if dataLen < offset+4 {
|
|
||||||
return fmt.Errorf("ipv6 packet was too small")
|
|
||||||
}
|
|
||||||
fp.Protocol = uint8(proto)
|
fp.Protocol = uint8(proto)
|
||||||
if incoming {
|
if incoming {
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
@ -347,47 +340,71 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
}
|
}
|
||||||
|
|
||||||
fp.Fragment = false
|
fp.Fragment = false
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case layers.IPProtocolIPv6Fragment:
|
case layers.IPProtocolIPv6Fragment:
|
||||||
//TODO: CERT-V2 can we determine the protocol?
|
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
|
||||||
fp.RemotePort = 0
|
if dataLen < offset+8 {
|
||||||
fp.LocalPort = 0
|
return ErrIPv6PacketTooShort
|
||||||
fp.Fragment = true
|
}
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
// Check if this is the first fragment
|
||||||
|
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
|
||||||
|
if fragmentOffset != 0 {
|
||||||
|
// Non-first fragment, use what we have now and stop processing
|
||||||
|
fp.Protocol = data[offset]
|
||||||
|
fp.Fragment = true
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next loop should be the transport layer since we are the first fragment
|
||||||
|
next = 8 // Fragment headers are always 8 bytes
|
||||||
|
|
||||||
|
case layers.IPProtocolAH:
|
||||||
|
// Auth headers, used by IPSec, have a different meaning for header length
|
||||||
if dataLen < offset+1 {
|
if dataLen < offset+1 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
next := int(data[offset+1]) * 8
|
next = int(data[offset+1]+2) << 2
|
||||||
if next == 0 {
|
|
||||||
// each extension is at least 8 bytes
|
default:
|
||||||
next = 8
|
// Normal ipv6 header length processing
|
||||||
|
if dataLen < offset+1 {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
protoAt = offset
|
next = int(data[offset+1]+1) << 3
|
||||||
offset = offset + next
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if next <= 0 {
|
||||||
|
// Safety check, each ipv6 header has to be at least 8 bytes
|
||||||
|
next = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
protoAt = offset
|
||||||
|
offset = offset + next
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("could not find payload in ipv6 packet")
|
return ErrIPv6CouldNotFindPayload
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
||||||
// Do we at least have an ipv4 header worth of data?
|
// Do we at least have an ipv4 header worth of data?
|
||||||
if len(data) < ipv4.HeaderLen {
|
if len(data) < ipv4.HeaderLen {
|
||||||
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
|
return ErrIPv4PacketTooShort
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjust our start position based on the advertised ip header length
|
// Adjust our start position based on the advertised ip header length
|
||||||
ihl := int(data[0]&0x0f) << 2
|
ihl := int(data[0]&0x0f) << 2
|
||||||
|
|
||||||
// Well formed ip header length?
|
// Well-formed ip header length?
|
||||||
if ihl < ipv4.HeaderLen {
|
if ihl < ipv4.HeaderLen {
|
||||||
return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
|
return ErrIPv4InvalidHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if this is the second or further fragment of a fragmented packet.
|
// Check if this is the second or further fragment of a fragmented packet.
|
||||||
@ -403,7 +420,7 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
|||||||
minLen += minFwPacketLen
|
minLen += minFwPacketLen
|
||||||
}
|
}
|
||||||
if len(data) < minLen {
|
if len(data) < minLen {
|
||||||
return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
|
return ErrIPv4InvalidHeaderLength
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
// Firewall packets are locally oriented
|
||||||
@ -501,7 +518,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) {
|
|||||||
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
f.messageMetrics.Tx(header.RecvError, 0, 1)
|
||||||
|
|
||||||
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
|
||||||
f.outside.WriteTo(b, endpoint)
|
_ = f.outside.WriteTo(b, endpoint)
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("index", index).
|
f.l.WithField("index", index).
|
||||||
WithField("udpAddr", endpoint).
|
WithField("udpAddr", endpoint).
|
||||||
|
|||||||
529
outside_test.go
529
outside_test.go
@ -1,6 +1,8 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
@ -18,13 +20,13 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
// length fails
|
// length fails
|
||||||
err := newPacket([]byte{}, true, p)
|
err := newPacket([]byte{}, true, p)
|
||||||
assert.EqualError(t, err, "packet too short")
|
assert.ErrorIs(t, err, ErrPacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x40}, true, p)
|
err = newPacket([]byte{0x40}, true, p)
|
||||||
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")
|
assert.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x60}, true, p)
|
err = newPacket([]byte{0x60}, true, p)
|
||||||
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// length fail with ip options
|
// length fail with ip options
|
||||||
h := ipv4.Header{
|
h := ipv4.Header{
|
||||||
@ -37,16 +39,15 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
b, _ := h.Marshal()
|
b, _ := h.Marshal()
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")
|
|
||||||
|
|
||||||
// not an ipv4 packet
|
// not an ipv4 packet
|
||||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.EqualError(t, err, "packet is an unknown ip version: 0")
|
assert.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||||
|
|
||||||
// invalid ihl
|
// invalid ihl
|
||||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")
|
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// account for variable ip header length - incoming
|
// account for variable ip header length - incoming
|
||||||
h = ipv4.Header{
|
h = ipv4.Header{
|
||||||
@ -63,11 +64,12 @@ func Test_newPacket(t *testing.T) {
|
|||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP))
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2"))
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
|
||||||
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1"))
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
|
||||||
assert.Equal(t, p.RemotePort, uint16(3))
|
assert.Equal(t, uint16(3), p.RemotePort)
|
||||||
assert.Equal(t, p.LocalPort, uint16(4))
|
assert.Equal(t, uint16(4), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// account for variable ip header length - outgoing
|
// account for variable ip header length - outgoing
|
||||||
h = ipv4.Header{
|
h = ipv4.Header{
|
||||||
@ -84,17 +86,94 @@ func Test_newPacket(t *testing.T) {
|
|||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(2))
|
assert.Equal(t, uint8(2), p.Protocol)
|
||||||
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1"))
|
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
|
||||||
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2"))
|
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
|
||||||
assert.Equal(t, p.RemotePort, uint16(6))
|
assert.Equal(t, uint16(6), p.RemotePort)
|
||||||
assert.Equal(t, p.LocalPort, uint16(5))
|
assert.Equal(t, uint16(5), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_newPacket_v6(t *testing.T) {
|
func Test_newPacket_v6(t *testing.T) {
|
||||||
p := &firewall.Packet{}
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
|
// invalid ipv6
|
||||||
ip := layers.IPv6{
|
ip := layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
HopLimit: 128,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
opt := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: false,
|
||||||
|
FixLengths: false,
|
||||||
|
}
|
||||||
|
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
|
// A good ICMP packet
|
||||||
|
ip = layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolICMPv6,
|
||||||
|
HopLimit: 128,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
icmp := layers.ICMPv6{}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = newPacket(buffer.Bytes(), true, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint16(0), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// A good ESP packet
|
||||||
|
b := buffer.Bytes()
|
||||||
|
b[6] = byte(layers.IPProtocolESP)
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint16(0), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// A good None packet
|
||||||
|
b = buffer.Bytes()
|
||||||
|
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint16(0), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// An unknown protocol packet
|
||||||
|
b = buffer.Bytes()
|
||||||
|
b[6] = 255 // 255 is a reserved protocol number
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
|
// A good UDP packet
|
||||||
|
ip = layers.IPv6{
|
||||||
Version: 6,
|
Version: 6,
|
||||||
NextHeader: firewall.ProtoUDP,
|
NextHeader: firewall.ProtoUDP,
|
||||||
HopLimit: 128,
|
HopLimit: 128,
|
||||||
@ -106,39 +185,407 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
SrcPort: layers.UDPPort(36123),
|
SrcPort: layers.UDPPort(36123),
|
||||||
DstPort: layers.UDPPort(22),
|
DstPort: layers.UDPPort(22),
|
||||||
}
|
}
|
||||||
err := udp.SetNetworkLayerForChecksum(&ip)
|
err = udp.SetNetworkLayerForChecksum(&ip)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer := gopacket.NewSerializeBuffer()
|
buffer.Clear()
|
||||||
opt := gopacket.SerializeOptions{
|
|
||||||
ComputeChecksums: true,
|
|
||||||
FixLengths: true,
|
|
||||||
}
|
|
||||||
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
b := buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
|
|
||||||
//test incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = newPacket(b, true, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2"))
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1"))
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
assert.Equal(t, p.RemotePort, uint16(36123))
|
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||||
assert.Equal(t, p.LocalPort, uint16(22))
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
//test outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = newPacket(b, false, p)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP))
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2"))
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1"))
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
assert.Equal(t, p.LocalPort, uint16(36123))
|
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||||
assert.Equal(t, p.RemotePort, uint16(22))
|
assert.Equal(t, uint16(22), p.RemotePort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Too short UDP packet
|
||||||
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
|
// A good TCP packet
|
||||||
|
b[6] = byte(layers.IPProtocolTCP)
|
||||||
|
|
||||||
|
// incoming
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// outgoing
|
||||||
|
err = newPacket(b, false, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||||
|
assert.Equal(t, uint16(22), p.RemotePort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Too short TCP packet
|
||||||
|
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
|
// A good UDP packet with an AH header
|
||||||
|
ip = layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolAH,
|
||||||
|
HopLimit: 128,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
ah := layers.IPSecAH{
|
||||||
|
AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef},
|
||||||
|
}
|
||||||
|
ah.NextHeader = layers.IPProtocolUDP
|
||||||
|
|
||||||
|
udpHeader := []byte{
|
||||||
|
0x8d, 0x1b, // Source port 36123
|
||||||
|
0x00, 0x16, // Destination port 22
|
||||||
|
0x00, 0x00, // Length
|
||||||
|
0x00, 0x00, // Checksum
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = ip.SerializeTo(buffer, opt)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b = buffer.Bytes()
|
||||||
|
ahb := serializeAH(&ah)
|
||||||
|
b = append(b, ahb...)
|
||||||
|
b = append(b, udpHeader...)
|
||||||
|
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Invalid AH header
|
||||||
|
b = buffer.Bytes()
|
||||||
|
err = newPacket(b, true, p)
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newPacket_ipv6Fragment(t *testing.T) {
|
||||||
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
|
ip := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolIPv6Fragment,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// First fragment
|
||||||
|
fragHeader1 := []byte{
|
||||||
|
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||||
|
0x00, // Reserved
|
||||||
|
0x00, // Fragment Offset high byte (0)
|
||||||
|
0x01, // Fragment Offset low byte & flags (M=1)
|
||||||
|
0x00, 0x00, 0x00, 0x01, // Identification
|
||||||
|
}
|
||||||
|
|
||||||
|
udpHeader := []byte{
|
||||||
|
0x8d, 0x1b, // Source port 36123
|
||||||
|
0x00, 0x16, // Destination port 22
|
||||||
|
0x00, 0x00, // Length
|
||||||
|
0x00, 0x00, // Checksum
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ip.SerializeTo(buffer, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstFrag := buffer.Bytes()
|
||||||
|
firstFrag = append(firstFrag, fragHeader1...)
|
||||||
|
firstFrag = append(firstFrag, udpHeader...)
|
||||||
|
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
|
// Test first fragment incoming
|
||||||
|
err = newPacket(firstFrag, true, p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
|
assert.Equal(t, uint16(36123), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(22), p.LocalPort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Test first fragment outgoing
|
||||||
|
err = newPacket(firstFrag, false, p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
|
assert.Equal(t, uint16(36123), p.LocalPort)
|
||||||
|
assert.Equal(t, uint16(22), p.RemotePort)
|
||||||
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
|
// Second fragment
|
||||||
|
fragHeader2 := []byte{
|
||||||
|
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||||
|
0x00, // Reserved
|
||||||
|
0xb9, // Fragment Offset high byte (185)
|
||||||
|
0x01, // Fragment Offset low byte & flags (M=1)
|
||||||
|
0x00, 0x00, 0x00, 0x01, // Identification
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = ip.SerializeTo(buffer, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondFrag := buffer.Bytes()
|
||||||
|
secondFrag = append(secondFrag, fragHeader2...)
|
||||||
|
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
|
// Test second fragment incoming
|
||||||
|
err = newPacket(secondFrag, true, p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
|
assert.Equal(t, uint16(0), p.RemotePort)
|
||||||
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
|
assert.True(t, p.Fragment)
|
||||||
|
|
||||||
|
// Test second fragment outgoing
|
||||||
|
err = newPacket(secondFrag, false, p)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
|
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
|
||||||
|
assert.Equal(t, uint16(0), p.LocalPort)
|
||||||
|
assert.Equal(t, uint16(0), p.RemotePort)
|
||||||
|
assert.True(t, p.Fragment)
|
||||||
|
|
||||||
|
// Too short of a fragment packet
|
||||||
|
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||||
|
assert.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseV6(b *testing.B) {
|
||||||
|
// Regular UDP packet
|
||||||
|
ip := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolUDP,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(36123),
|
||||||
|
DstPort: layers.UDPPort(22),
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: false,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := gopacket.SerializeLayers(buffer, opts, ip, udp)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
normalPacket := buffer.Bytes()
|
||||||
|
|
||||||
|
// First Fragment packet
|
||||||
|
ipFrag := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolIPv6Fragment,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
fragHeader := []byte{
|
||||||
|
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||||
|
0x00, // Reserved
|
||||||
|
0x00, // Fragment Offset high byte (0)
|
||||||
|
0x01, // Fragment Offset low byte & flags (M=1)
|
||||||
|
0x00, 0x00, 0x00, 0x01, // Identification
|
||||||
|
}
|
||||||
|
|
||||||
|
udpHeader := []byte{
|
||||||
|
0x8d, 0x7b, // Source port 36123
|
||||||
|
0x00, 0x16, // Destination port 22
|
||||||
|
0x00, 0x00, // Length
|
||||||
|
0x00, 0x00, // Checksum
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = ipFrag.SerializeTo(buffer, opts)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
firstFrag := buffer.Bytes()
|
||||||
|
firstFrag = append(firstFrag, fragHeader...)
|
||||||
|
firstFrag = append(firstFrag, udpHeader...)
|
||||||
|
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
|
// Second Fragment packet
|
||||||
|
fragHeader[2] = 0xb9 // offset 185
|
||||||
|
buffer.Clear()
|
||||||
|
err = ipFrag.SerializeTo(buffer, opts)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondFrag := buffer.Bytes()
|
||||||
|
secondFrag = append(secondFrag, fragHeader...)
|
||||||
|
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
|
fp := &firewall.Packet{}
|
||||||
|
|
||||||
|
b.Run("Normal", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err = parseV6(normalPacket, true, fp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("FirstFragment", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err = parseV6(firstFrag, true, fp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("SecondFragment", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err = parseV6(secondFrag, true, fp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Evil packet
|
||||||
|
evilPacket := &layers.IPv6{
|
||||||
|
Version: 6,
|
||||||
|
NextHeader: layers.IPProtocolIPv6HopByHop,
|
||||||
|
HopLimit: 64,
|
||||||
|
SrcIP: net.IPv6linklocalallrouters,
|
||||||
|
DstIP: net.IPv6linklocalallnodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
hopHeader := []byte{
|
||||||
|
uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop)
|
||||||
|
0x00, // Length
|
||||||
|
0x00, 0x00, // Options and padding
|
||||||
|
0x00, 0x00, 0x00, 0x00, // More options and padding
|
||||||
|
}
|
||||||
|
|
||||||
|
lastHopHeader := []byte{
|
||||||
|
uint8(layers.IPProtocolUDP), // Next Header (UDP)
|
||||||
|
0x00, // Length
|
||||||
|
0x00, 0x00, // Options and padding
|
||||||
|
0x00, 0x00, 0x00, 0x00, // More options and padding
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer.Clear()
|
||||||
|
err = evilPacket.SerializeTo(buffer, opts)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
evilBytes := buffer.Bytes()
|
||||||
|
for i := 0; i < 200; i++ {
|
||||||
|
evilBytes = append(evilBytes, hopHeader...)
|
||||||
|
}
|
||||||
|
evilBytes = append(evilBytes, lastHopHeader...)
|
||||||
|
evilBytes = append(evilBytes, udpHeader...)
|
||||||
|
evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
|
b.Run("200 HopByHop headers", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err = parseV6(evilBytes, false, fp); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure authentication data is a multiple of 8 bytes by padding if necessary
|
||||||
|
func padAuthData(authData []byte) []byte {
|
||||||
|
// Length of Authentication Data must be a multiple of 8 bytes
|
||||||
|
paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary
|
||||||
|
if paddingLength > 0 {
|
||||||
|
authData = append(authData, make([]byte, paddingLength)...)
|
||||||
|
}
|
||||||
|
return authData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom function to manually serialize IPSecAH for both IPv4 and IPv6
|
||||||
|
func serializeAH(ah *layers.IPSecAH) []byte {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
|
// Ensure Authentication Data is a multiple of 8 bytes
|
||||||
|
ah.AuthenticationData = padAuthData(ah.AuthenticationData)
|
||||||
|
// Calculate Payload Length (in 32-bit words, minus 2)
|
||||||
|
payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2
|
||||||
|
|
||||||
|
// Serialize fields
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if len(ah.AuthenticationData) > 0 {
|
||||||
|
if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user