use in-Nebula SNAT to send IPv4 UnsafeNetworks traffic over an IPv6 overlay

This commit is contained in:
JackDoan
2026-01-14 12:36:55 -06:00
parent 39452b5eec
commit c2a63499ac
22 changed files with 770 additions and 210 deletions

View File

@@ -2,6 +2,7 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -22,10 +23,29 @@ import (
"github.com/slackhq/nebula/firewall"
)
var ErrCannotSNAT = errors.New("cannot snat this packet")
var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host")
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type snatInfo struct {
//Src is the source IP+port to write into unsafe-route-bound packet
Src netip.AddrPort
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
SrcVpnIp netip.Addr
//SnatPort is the port to rewrite into an overlay-bound packet
SnatPort uint16
}
func (s *snatInfo) Valid() bool {
if s == nil {
return false
}
return s.Src.IsValid()
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
@@ -34,6 +54,9 @@ type conn struct {
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
//for SNAT support
snat *snatInfo
}
// TODO: need conntrack max tracked connections handling
@@ -66,6 +89,7 @@ type Firewall struct {
defaultLocalCIDRAny bool
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
snatAddr netip.Addr
l *logrus.Logger
}
@@ -83,6 +107,15 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet]
}
func (ct *FirewallConntrack) dupeConnUnlocked(fp firewall.Packet, c *conn, timeout time.Duration) {
if _, ok := ct.Conns[fp]; !ok {
ct.TimerWheel.Advance(time.Now())
ct.TimerWheel.Add(fp, timeout)
}
ct.Conns[fp] = c
}
// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct {
@@ -131,7 +164,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@@ -149,12 +182,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
hasV4Networks := false
routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network)
hasV4Networks = hasV4Networks || network.Addr().Is4()
}
hasUnsafeNetworks := false
@@ -163,6 +198,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasUnsafeNetworks = true
}
if !hasUnsafeNetworks || hasV4Networks {
snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it
}
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
@@ -176,6 +215,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
snatAddr: snatAddr,
l: l,
incomingMetrics: firewallMetrics{
@@ -191,7 +231,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@@ -201,14 +241,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
certificate,
//TODO: max_connections
)
fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
@@ -314,6 +347,11 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool {
// f.snatAddr is only valid if we're a snat-capable router
return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
@@ -414,50 +452,204 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr {
c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way.
if c == nil {
return netip.Addr{}
}
if !c.snat.Valid() {
return netip.Addr{}
}
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2)
return c.snat.SrcVpnIp
}
func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) {
//change address
copy(data[ipOffset:], newIP.Addr().AsSlice())
recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr())
ipHeaderLen := int(data[0]&0x0F) * 4
switch fp.Protocol {
case firewall.ProtoICMP:
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would)
recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port())
case firewall.ProtoUDP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcUDPv4Checksum(data, oldIP, newIP)
case firewall.ProtoTCP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcTCPv4Checksum(data, oldIP, newIP)
}
}
func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error {
oldPort := fp.RemotePort
conntrack := f.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
for numPortsChecked := 0; numPortsChecked < 0x7ff; numPortsChecked++ {
_, ok := conntrack.Conns[*fp]
if !ok {
//yay, we can use this port
//track the snatted flow with the same expiration as the unsnatted version
conntrack.dupeConnUnlocked(*fp, c, f.packetTimeout(*fp))
return nil
}
//increment and retry. There's probably better strategies out there
fp.RemotePort++
if fp.RemotePort < 0x7ff {
fp.RemotePort += 0x7ff // keep it ephemeral for now
}
}
//if we made it here, we failed
fp.RemotePort = oldPort
return ErrCannotSNAT
}
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error {
if !f.snatAddr.IsValid() {
return ErrCannotSNAT
}
if c.snat.Valid() {
//old flow: make sure it came from the right place
if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) {
return ErrSNATIdentityMismatch
}
fp.RemoteAddr = f.snatAddr
fp.RemotePort = c.snat.SnatPort
} else if hostinfo.vpnAddrs[0].Is6() {
//we got a new flow
c.snat = &snatInfo{
Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort),
SrcVpnIp: hostinfo.vpnAddrs[0],
}
fp.RemoteAddr = f.snatAddr
//find a new port to use, if needed
err := f.findUsableSNATPort(fp, c)
if err != nil {
c.snat = nil
return err
}
c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort
} else {
return ErrCannotSNAT
}
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
rewritePacket(data, fp, c.snat.Src, newIP, 12, 0)
return nil
}
func (f *Firewall) identifyNetworkType(h *HostInfo, fp firewall.Packet) NetworkType {
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] == fp.RemoteAddr {
return NetworkTypeVPN
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() {
return NetworkTypeUncheckedSNATPeer
} else {
return NetworkTypeInvalidPeer
}
} else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok {
//todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case?
return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() { //todo surely I'm smart enough to avoid writing these branches twice
return NetworkTypeUncheckedSNATPeer
} else {
return NetworkTypeInvalidPeer
}
}
func (f *Firewall) allowNetworkType(nwType NetworkType) error {
switch nwType {
case NetworkTypeVPN:
return nil
case NetworkTypeInvalidPeer:
return ErrInvalidRemoteIP
case NetworkTypeVPNPeer:
//todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
return nil // nothing special, one day this may have different FW rules
case NetworkTypeUncheckedSNATPeer:
if f.snatAddr.IsValid() {
return nil //todo is this enough?
} else {
return ErrInvalidRemoteIP
}
default:
return ErrUnknownNetworkType //should never happen
}
}
func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error {
if f.routableNetworks.Contains(fp.LocalAddr) {
return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side
}
//watch out, when incoming, this function decides if we will deliver a packet locally
//when outgoing, much less important, it just decides if we're willing to tx
switch remoteNwType {
// we never want to accept unconntracked inbound traffic from these network types, but outbound is okay.
// It's the recipient's job to validate and accept or deny the packet.
case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe:
//NetworkTypeUnsafe needed here to allow inbound from an unsafe-router
if incoming {
return ErrInvalidLocalIP
}
return nil
default:
return ErrInvalidLocalIP
}
}
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
table := f.OutRules
if incoming {
table = f.InRules
}
snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid()
if snatmode {
//if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted,
//so it can never be in the localcache, which lacks SNAT data
//nil out the pointer to avoid ever using it
localCache = nil
}
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return nil //packet matched the cache, we're not snatting, we can return early
}
}
c := f.inConns(fp, h, caPool, localCache)
if c != nil {
if incoming && snatmode {
return f.applySnat(pkt, &fp, c, h)
}
return nil
}
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
remoteNetworkType := f.identifyNetworkType(h, fp)
if err := f.allowNetworkType(remoteNetworkType); err != nil {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return err
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
table := f.OutRules
if incoming {
table = f.InRules
return err
}
// We now know which firewall table to check against
@@ -467,9 +659,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming)
c = f.addConn(fp, incoming)
return nil
if incoming && remoteNetworkType == NetworkTypeUncheckedSNATPeer {
return f.applySnat(pkt, &fp, c, h)
} else {
//outgoing snat is handled before this function is called
return nil
}
}
func (f *Firewall) metrics(incoming bool) firewallMetrics {
@@ -496,12 +693,23 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
func (f *Firewall) peek(fp firewall.Packet) *conn {
conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep)
}
c := conntrack.Conns[fp]
conntrack.Unlock()
return c
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
conntrack := f.Conntrack
conntrack.Lock()
@@ -515,7 +723,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if !ok {
conntrack.Unlock()
return false
return nil
}
if c.rulesVersion != f.rulesVersion {
@@ -538,7 +746,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
return nil
}
if f.l.Level >= logrus.DebugLevel {
@@ -568,12 +776,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
localCache[fp] = struct{}{}
}
return true
return c
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
@@ -583,7 +790,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
default:
timeout = f.DefaultTimeout
}
return timeout
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
c := &conn{}
timeout := f.packetTimeout(fp)
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
@@ -597,7 +810,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
conntrack.Unlock()
return c
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -682,6 +897,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
var port int32
if p.Protocol == firewall.ProtoICMP {
// port numbers are re-used for connection tracking and SNAT,
// but we don't want to actually filter on them for ICMP
// ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6
return fp[firewall.PortAny].match(p, c, caPool)
}
if p.Fragment {
port = firewall.PortFragment
} else if incoming {