mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-16 01:34:22 +01:00
refactor
This commit is contained in:
@@ -1,9 +1,15 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
)
|
)
|
||||||
|
|
||||||
type m = map[string]any
|
type m = map[string]any
|
||||||
@@ -17,6 +23,17 @@ const (
|
|||||||
|
|
||||||
PortAny = 0 // Special value for matching `port: any`
|
PortAny = 0 // Special value for matching `port: any`
|
||||||
PortFragment = -1 // Special value for matching `port: fragment`
|
PortFragment = -1 // Special value for matching `port: fragment`
|
||||||
|
|
||||||
|
minFwPacketLen = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPacketTooShort = errors.New("packet is too short")
|
||||||
|
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
||||||
|
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
|
||||||
|
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
|
||||||
|
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
|
||||||
|
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Packet struct {
|
type Packet struct {
|
||||||
@@ -60,3 +77,172 @@ func (fp Packet) MarshalJSON() ([]byte, error) {
|
|||||||
"Fragment": fp.Fragment,
|
"Fragment": fp.Fragment,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseV6(data []byte, incoming bool, fp *Packet) error {
|
||||||
|
dataLen := len(data)
|
||||||
|
if dataLen < ipv6.HeaderLen {
|
||||||
|
return ErrIPv6PacketTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
if incoming {
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||||
|
} else {
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
||||||
|
}
|
||||||
|
|
||||||
|
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
|
||||||
|
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
||||||
|
next := 0
|
||||||
|
for {
|
||||||
|
if protoAt >= dataLen {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
proto := layers.IPProtocol(data[protoAt])
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
||||||
|
if dataLen < offset+4 {
|
||||||
|
return ErrIPv6PacketTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
fp.Protocol = uint8(proto)
|
||||||
|
if incoming {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
||||||
|
}
|
||||||
|
|
||||||
|
fp.Fragment = false
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case layers.IPProtocolIPv6Fragment:
|
||||||
|
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
|
||||||
|
if dataLen < offset+8 {
|
||||||
|
return ErrIPv6PacketTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is the first fragment
|
||||||
|
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
|
||||||
|
if fragmentOffset != 0 {
|
||||||
|
// Non-first fragment, use what we have now and stop processing
|
||||||
|
fp.Protocol = data[offset]
|
||||||
|
fp.Fragment = true
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The next loop should be the transport layer since we are the first fragment
|
||||||
|
next = 8 // Fragment headers are always 8 bytes
|
||||||
|
|
||||||
|
case layers.IPProtocolAH:
|
||||||
|
// Auth headers, used by IPSec, have a different meaning for header length
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next = int(data[offset+1]+2) << 2
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Normal ipv6 header length processing
|
||||||
|
if dataLen <= offset+1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
next = int(data[offset+1]+1) << 3
|
||||||
|
}
|
||||||
|
|
||||||
|
if next <= 0 {
|
||||||
|
// Safety check, each ipv6 header has to be at least 8 bytes
|
||||||
|
next = 8
|
||||||
|
}
|
||||||
|
|
||||||
|
protoAt = offset
|
||||||
|
offset = offset + next
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrIPv6CouldNotFindPayload
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseV4(data []byte, incoming bool, fp *Packet) error {
|
||||||
|
// Do we at least have an ipv4 header worth of data?
|
||||||
|
if len(data) < ipv4.HeaderLen {
|
||||||
|
return ErrIPv4PacketTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust our start position based on the advertised ip header length
|
||||||
|
ihl := int(data[0]&0x0f) << 2
|
||||||
|
|
||||||
|
// Well-formed ip header length?
|
||||||
|
if ihl < ipv4.HeaderLen {
|
||||||
|
return ErrIPv4InvalidHeaderLength
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is the second or further fragment of a fragmented packet.
|
||||||
|
flagsfrags := binary.BigEndian.Uint16(data[6:8])
|
||||||
|
fp.Fragment = (flagsfrags & 0x1FFF) != 0
|
||||||
|
|
||||||
|
// Firewall handles protocol checks
|
||||||
|
fp.Protocol = data[9]
|
||||||
|
|
||||||
|
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
||||||
|
minLen := ihl
|
||||||
|
if !fp.Fragment && fp.Protocol != ProtoICMP {
|
||||||
|
minLen += minFwPacketLen
|
||||||
|
}
|
||||||
|
if len(data) < minLen {
|
||||||
|
return ErrIPv4InvalidHeaderLength
|
||||||
|
}
|
||||||
|
|
||||||
|
// Firewall packets are locally oriented
|
||||||
|
if incoming {
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
|
if fp.Fragment || fp.Protocol == ProtoICMP {
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
} else {
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
||||||
|
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
||||||
|
if fp.Fragment || fp.Protocol == ProtoICMP {
|
||||||
|
fp.RemotePort = 0
|
||||||
|
fp.LocalPort = 0
|
||||||
|
} else {
|
||||||
|
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
||||||
|
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
||||||
|
func NewPacket(data []byte, incoming bool, fp *Packet) error {
|
||||||
|
if len(data) < 1 {
|
||||||
|
return ErrPacketTooShort
|
||||||
|
}
|
||||||
|
|
||||||
|
version := int((data[0] >> 4) & 0x0f)
|
||||||
|
switch version {
|
||||||
|
case ipv4.Version:
|
||||||
|
return parseV4(data, incoming, fp)
|
||||||
|
case ipv6.Version:
|
||||||
|
return parseV6(data, incoming, fp)
|
||||||
|
}
|
||||||
|
return ErrUnknownIPVersion
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
err := newPacket(packet, false, fwPacket)
|
err := firewall.NewPacket(packet, false, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if f.l.Level >= logrus.DebugLevel {
|
if f.l.Level >= logrus.DebugLevel {
|
||||||
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
f.l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
|
||||||
@@ -211,7 +211,7 @@ func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cac
|
|||||||
|
|
||||||
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
|
||||||
fp := &firewall.Packet{}
|
fp := &firewall.Packet{}
|
||||||
err := newPacket(p, false, fp)
|
err := firewall.NewPacket(p, false, fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
f.l.Warnf("error while parsing outgoing packet for firewall check; %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
189
outside.go
189
outside.go
@@ -1,22 +1,13 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/firewall"
|
"github.com/slackhq/nebula/firewall"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
minFwPacketLen = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
|
||||||
@@ -278,184 +269,6 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *heade
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrPacketTooShort = errors.New("packet is too short")
|
|
||||||
ErrUnknownIPVersion = errors.New("packet is an unknown ip version")
|
|
||||||
ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length")
|
|
||||||
ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short")
|
|
||||||
ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short")
|
|
||||||
ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet")
|
|
||||||
)
|
|
||||||
|
|
||||||
// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
|
|
||||||
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
|
|
||||||
if len(data) < 1 {
|
|
||||||
return ErrPacketTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
version := int((data[0] >> 4) & 0x0f)
|
|
||||||
switch version {
|
|
||||||
case ipv4.Version:
|
|
||||||
return parseV4(data, incoming, fp)
|
|
||||||
case ipv6.Version:
|
|
||||||
return parseV6(data, incoming, fp)
|
|
||||||
}
|
|
||||||
return ErrUnknownIPVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
|
|
||||||
dataLen := len(data)
|
|
||||||
if dataLen < ipv6.HeaderLen {
|
|
||||||
return ErrIPv6PacketTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
if incoming {
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24])
|
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40])
|
|
||||||
} else {
|
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24])
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40])
|
|
||||||
}
|
|
||||||
|
|
||||||
protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header
|
|
||||||
offset := ipv6.HeaderLen // Start at the end of the ipv6 header
|
|
||||||
next := 0
|
|
||||||
for {
|
|
||||||
if protoAt >= dataLen {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
proto := layers.IPProtocol(data[protoAt])
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader:
|
|
||||||
fp.Protocol = uint8(proto)
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
fp.Fragment = false
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case layers.IPProtocolTCP, layers.IPProtocolUDP:
|
|
||||||
if dataLen < offset+4 {
|
|
||||||
return ErrIPv6PacketTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
fp.Protocol = uint8(proto)
|
|
||||||
if incoming {
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
|
||||||
} else {
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2])
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
|
|
||||||
}
|
|
||||||
|
|
||||||
fp.Fragment = false
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case layers.IPProtocolIPv6Fragment:
|
|
||||||
// Fragment header is 8 bytes, need at least offset+4 to read the offset field
|
|
||||||
if dataLen < offset+8 {
|
|
||||||
return ErrIPv6PacketTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is the first fragment
|
|
||||||
fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits
|
|
||||||
if fragmentOffset != 0 {
|
|
||||||
// Non-first fragment, use what we have now and stop processing
|
|
||||||
fp.Protocol = data[offset]
|
|
||||||
fp.Fragment = true
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The next loop should be the transport layer since we are the first fragment
|
|
||||||
next = 8 // Fragment headers are always 8 bytes
|
|
||||||
|
|
||||||
case layers.IPProtocolAH:
|
|
||||||
// Auth headers, used by IPSec, have a different meaning for header length
|
|
||||||
if dataLen <= offset+1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
next = int(data[offset+1]+2) << 2
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Normal ipv6 header length processing
|
|
||||||
if dataLen <= offset+1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
next = int(data[offset+1]+1) << 3
|
|
||||||
}
|
|
||||||
|
|
||||||
if next <= 0 {
|
|
||||||
// Safety check, each ipv6 header has to be at least 8 bytes
|
|
||||||
next = 8
|
|
||||||
}
|
|
||||||
|
|
||||||
protoAt = offset
|
|
||||||
offset = offset + next
|
|
||||||
}
|
|
||||||
|
|
||||||
return ErrIPv6CouldNotFindPayload
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
|
|
||||||
// Do we at least have an ipv4 header worth of data?
|
|
||||||
if len(data) < ipv4.HeaderLen {
|
|
||||||
return ErrIPv4PacketTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust our start position based on the advertised ip header length
|
|
||||||
ihl := int(data[0]&0x0f) << 2
|
|
||||||
|
|
||||||
// Well-formed ip header length?
|
|
||||||
if ihl < ipv4.HeaderLen {
|
|
||||||
return ErrIPv4InvalidHeaderLength
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is the second or further fragment of a fragmented packet.
|
|
||||||
flagsfrags := binary.BigEndian.Uint16(data[6:8])
|
|
||||||
fp.Fragment = (flagsfrags & 0x1FFF) != 0
|
|
||||||
|
|
||||||
// Firewall handles protocol checks
|
|
||||||
fp.Protocol = data[9]
|
|
||||||
|
|
||||||
// Accounting for a variable header length, do we have enough data for our src/dst tuples?
|
|
||||||
minLen := ihl
|
|
||||||
if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
|
|
||||||
minLen += minFwPacketLen
|
|
||||||
}
|
|
||||||
if len(data) < minLen {
|
|
||||||
return ErrIPv4InvalidHeaderLength
|
|
||||||
}
|
|
||||||
|
|
||||||
// Firewall packets are locally oriented
|
|
||||||
if incoming {
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16])
|
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20])
|
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
} else {
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16])
|
|
||||||
fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20])
|
|
||||||
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
|
|
||||||
fp.RemotePort = 0
|
|
||||||
fp.LocalPort = 0
|
|
||||||
} else {
|
|
||||||
fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
|
|
||||||
fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
|
func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
|
||||||
var err error
|
var err error
|
||||||
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
|
out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
|
||||||
@@ -481,7 +294,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(out, true, fwPacket)
|
err = firewall.NewPacket(out, true, fwPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
hostinfo.logger(f.l).WithError(err).WithField("packet", out).
|
||||||
Warnf("Error while validating inbound packet")
|
Warnf("Error while validating inbound packet")
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ func Test_newPacket(t *testing.T) {
|
|||||||
p := &firewall.Packet{}
|
p := &firewall.Packet{}
|
||||||
|
|
||||||
// length fails
|
// length fails
|
||||||
err := newPacket([]byte{}, true, p)
|
err := firewall.NewPacket([]byte{}, true, p)
|
||||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x40}, true, p)
|
err = firewall.NewPacket([]byte{0x40}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv4PacketTooShort)
|
||||||
|
|
||||||
err = newPacket([]byte{0x60}, true, p)
|
err = firewall.NewPacket([]byte{0x60}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// length fail with ip options
|
// length fail with ip options
|
||||||
@@ -39,15 +39,15 @@ func Test_newPacket(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b, _ := h.Marshal()
|
b, _ := h.Marshal()
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// not an ipv4 packet
|
// not an ipv4 packet
|
||||||
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = firewall.NewPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
require.ErrorIs(t, err, ErrUnknownIPVersion)
|
||||||
|
|
||||||
// invalid ihl
|
// invalid ihl
|
||||||
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
err = firewall.NewPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
|
||||||
|
|
||||||
// account for variable ip header length - incoming
|
// account for variable ip header length - incoming
|
||||||
@@ -62,7 +62,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
b, _ = h.Marshal()
|
b, _ = h.Marshal()
|
||||||
b = append(b, []byte{0, 3, 0, 4}...)
|
b = append(b, []byte{0, 3, 0, 4}...)
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
@@ -84,7 +84,7 @@ func Test_newPacket(t *testing.T) {
|
|||||||
|
|
||||||
b, _ = h.Marshal()
|
b, _ = h.Marshal()
|
||||||
b = append(b, []byte{0, 5, 0, 6}...)
|
b = append(b, []byte{0, 5, 0, 6}...)
|
||||||
err = newPacket(b, false, p)
|
err = firewall.NewPacket(b, false, p)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(2), p.Protocol)
|
assert.Equal(t, uint8(2), p.Protocol)
|
||||||
@@ -114,7 +114,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
err := gopacket.SerializeLayers(buffer, opt, &ip)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = firewall.NewPacket(buffer.Bytes(), true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A v6 packet with a hop-by-hop extension
|
// A v6 packet with a hop-by-hop extension
|
||||||
@@ -148,12 +148,12 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
|
|
||||||
// A full IPv6 header and 1 byte in the first extension, but missing
|
// A full IPv6 header and 1 byte in the first extension, but missing
|
||||||
// the length byte.
|
// the length byte.
|
||||||
err = newPacket(buffer.Bytes()[:41], true, p)
|
err = firewall.NewPacket(buffer.Bytes()[:41], true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
// A full IPv6 header plus 1 full extension, but only 1 byte of the
|
||||||
// next layer, missing length byte
|
// next layer, missing length byte
|
||||||
err = newPacket(buffer.Bytes()[:49], true, p)
|
err = firewall.NewPacket(buffer.Bytes()[:49], true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good ICMP packet
|
// A good ICMP packet
|
||||||
@@ -173,7 +173,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = newPacket(buffer.Bytes(), true, p)
|
err = firewall.NewPacket(buffer.Bytes(), true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -185,7 +185,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// A good ESP packet
|
// A good ESP packet
|
||||||
b := buffer.Bytes()
|
b := buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolESP)
|
b[6] = byte(layers.IPProtocolESP)
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -197,7 +197,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// A good None packet
|
// A good None packet
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = byte(layers.IPProtocolNoNextHeader)
|
b[6] = byte(layers.IPProtocolNoNextHeader)
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -209,7 +209,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
// An unknown protocol packet
|
// An unknown protocol packet
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
b[6] = 255 // 255 is a reserved protocol number
|
b[6] = 255 // 255 is a reserved protocol number
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
|
|
||||||
// A good UDP packet
|
// A good UDP packet
|
||||||
@@ -236,7 +236,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -246,7 +246,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = firewall.NewPacket(b, false, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
@@ -256,14 +256,14 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// Too short UDP packet
|
// Too short UDP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = firewall.NewPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good TCP packet
|
// A good TCP packet
|
||||||
b[6] = byte(layers.IPProtocolTCP)
|
b[6] = byte(layers.IPProtocolTCP)
|
||||||
|
|
||||||
// incoming
|
// incoming
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -273,7 +273,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// outgoing
|
// outgoing
|
||||||
err = newPacket(b, false, p)
|
err = firewall.NewPacket(b, false, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
@@ -283,7 +283,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// Too short TCP packet
|
// Too short TCP packet
|
||||||
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
err = firewall.NewPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// A good UDP packet with an AH header
|
// A good UDP packet with an AH header
|
||||||
@@ -318,7 +318,7 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
b = append(b, ahb...)
|
b = append(b, ahb...)
|
||||||
b = append(b, udpHeader...)
|
b = append(b, udpHeader...)
|
||||||
|
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
@@ -328,12 +328,12 @@ func Test_newPacket_v6(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// Ensure buffer bounds checking during processing
|
// Ensure buffer bounds checking during processing
|
||||||
err = newPacket(b[:41], true, p)
|
err = firewall.NewPacket(b[:41], true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
|
|
||||||
// Invalid AH header
|
// Invalid AH header
|
||||||
b = buffer.Bytes()
|
b = buffer.Bytes()
|
||||||
err = newPacket(b, true, p)
|
err = firewall.NewPacket(b, true, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +381,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
// Test first fragment incoming
|
// Test first fragment incoming
|
||||||
err = newPacket(firstFrag, true, p)
|
err = firewall.NewPacket(firstFrag, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -391,7 +391,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
assert.False(t, p.Fragment)
|
assert.False(t, p.Fragment)
|
||||||
|
|
||||||
// Test first fragment outgoing
|
// Test first fragment outgoing
|
||||||
err = newPacket(firstFrag, false, p)
|
err = firewall.NewPacket(firstFrag, false, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -420,7 +420,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...)
|
||||||
|
|
||||||
// Test second fragment incoming
|
// Test second fragment incoming
|
||||||
err = newPacket(secondFrag, true, p)
|
err = firewall.NewPacket(secondFrag, true, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
|
||||||
@@ -430,7 +430,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
assert.True(t, p.Fragment)
|
assert.True(t, p.Fragment)
|
||||||
|
|
||||||
// Test second fragment outgoing
|
// Test second fragment outgoing
|
||||||
err = newPacket(secondFrag, false, p)
|
err = firewall.NewPacket(secondFrag, false, p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
|
||||||
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
|
||||||
@@ -440,7 +440,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
|
|||||||
assert.True(t, p.Fragment)
|
assert.True(t, p.Fragment)
|
||||||
|
|
||||||
// Too short of a fragment packet
|
// Too short of a fragment packet
|
||||||
err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
|
err = firewall.NewPacket(secondFrag[:len(secondFrag)-10], false, p)
|
||||||
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
require.ErrorIs(t, err, ErrIPv6PacketTooShort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user