use in-Nebula SNAT to send IPv4 UnsafeNetworks traffic over an IPv6 overlay

This commit is contained in:
JackDoan
2026-01-14 12:36:55 -06:00
parent 39452b5eec
commit c2a63499ac
22 changed files with 770 additions and 210 deletions

102
SNAT.md Normal file
View File

@@ -0,0 +1,102 @@
# Don't merge me
# Accessing IPv4-only UnsafeNetworks via an IPv6-only overlay
## Background
Nebula is an VPN-like connectivity solution. It provides what we call an "overlay network",
an additional IP-addressed network on top of one-or-more traditional "underlay networks".
As long as two devices with Nebula certificates (credentials, essentially signed keypairs with metadata) can find each other
and exchange traffic via a common underlay network (often this is the Internet),
they will also be able to exchange traffic securely via a point-to-point, encrypted, authenticated tunnel.
Typically, all Nebula traffic is strongly associated with the Nebula certificate of the sender
(that is, the source IP of all packets matches the IP listed in the sender's certificate).
However, it is useful to be able to bend this rule. That is why there is another field in the Nebula certificate, named UnsafeNetworks,
which lists the network prefixes that the host bearing this certificate is allowed to "speak for".
## Problem Statement
We want IPv6-only overlay networks to be able to carry IPv4 traffic to reach off-overlay hosts via UnsafeNetworks
### Scenario
To illustrate this scenario, we will define 3 hosts:
* a Phone, running Nebula, assigned the overlay IP fd01::AA/64. It has an undefined underlay, but we assert that it always has working IPv4 OR IPv6 connectivity to Router.
* a Router, running Nebula, assigned the overlay IP fd01::55/64. It has a stable underlay link that Phone can always reach.
* a Printer, which cannot run Nebula, and is only capable of IPv4 communication. It has a direct link to Router, but the Phone cannot reach it directly.
You, the User, wish to use your Phone to print out something on the Printer while you're away from home. How can we make this possible with your IPv6-only Nebula overlay?
In particular, your Phone may connect to any cellular or public WiFi network, and we cannot control the IP address it will be assigned. If you MUST print, an IP conflict is not acceptable.
Therefore, we cannot simply dismiss this problem by suggesting that you assign a small IPv4 network within your overlay. Sure, it probably works, and in this toy scenario, the odds of a conflict are pretty small. But it scales very poorly. What if a whole company needs to use this printer (or perhaps a less contrived need?)
We can do better.
## Solution
* Even though Phone and Router lack IPv4 assignments, we can still put V4 addresses on their tun devices.
* Each overlay host who wishes to use this feature shall select (or configure?) an assignment within 169.254.0.0/16, the IPv4 link-local range
* this is a pretty small space, but it confines the region of IP conflict to a much smaller domain. And, because overlay hosts will never dial one another with this address, cryptographic verification of it via the certificate is less important.
* On Phone, Nebula will configure an unsafe_route to the Printer using this address. Because it is a device route, we do not need to tell the operating system the address of the next hop (no `via`)
* On Router, Nebula will use this address to masquerade packets from Phone. You'll see!
* Let's walk through setting up a TCP session between Phone and Printer in this scheme:
* Phone sends SYN to the printer's underlay IPv4 address
* This packet lands on Phone's Nebula tun device
* Nebula locates Router as the destination for this packet, as defined in `tun.unsafe_routes`
* Nebula checks the packet against the outbound chain:
* the destination IP of Printer is listed in Router's UnsafeNetworks, so that check will pass
* Phone's source IP is not listed in any certificate, but because the destination address is of `NetworkTypeUnsafe` and this is an outgoing flow, we keep going
* Actual outbound firewall rules get checked, assume they pass
* conntrack entry created to permit replies
* Phone encrypts the packet and sends it to Router
* Router gets the packet from Phone, and decrypts it. It is passed to the Nebula firewall for evaluation:
* `firewall.Drop()` on the Router's Nebula inbound rules:
* Because Router is configured to allow SNAT, and this packet is an IPv4 packet from a IPv6-only host, the firewall module enters "snat mode" (`TODO explain?`)
* This is a new flow, so the conntrack lookup for it fails
* `firewall.identifyNetworkType()`
* identify what "kind" of remote address (this is the inbound firewall, so the remote address is the packet's src IP) we've been given
* `NetworkTypeVPN`, for example is a remote address that matches the peer's certificate
* In this case, because the traffic is IPv4 traffic flowing from an IPv6-only host, and we've opted into supporting SNAT, this traffic is marked as `NetworkTypeUncheckedSNATPeer`
* `firewall.allowNetworkType()` will allow `NetworkTypeUncheckedSNATPeer` traffic to proceed because we have opted into SNAT
* `firewall.willingToHandleLocalAddr()` now needs to check if we're willing to accept the destination address
* Because this traffic is addressed to a destination listed in our UnsafeNetworks, it's considered "routable" and passes this check
* Nebula's firewall rules are evaluated as normal. In particular, the `cidr` parameter will be checked against the IPv4 address, NOT the IPv6 address in the Phone's certificate
* @Nate I think this is "correct", but could be a source of footgun
* Let's assume the Nebula rules accept the traffic
* We create a conntrack entry for this flow
* We do not want to transmit with the IPv4 source IP we got from Phone. We don't want the Phone's IP assignments (in this scheme) to enter the network-space on Router at all.
* To this end, we rewrite the source address (and port, if needed) to our own pre-selected IPv4 link-local address. This address will never actually leave the machine, but we need it so return traffic can be routed back to the nebula tun on Router
* Replace source IP with "Router's SNAT address"
* Look in our conntrack table, and ensure we do not already have a flow that matches this srcip/srcport/dstip/dstport/proto tuple
* if we do, increment srcport until we find an empty slot. Only use ephemeral ports. This gives 0x7ff flows per dstip/dstport/proto tuple, which ought to be plenty.
* Record the original srcip/srcport as part of the conntrack data for later
* Fix checksums
* Nebula writes the rewritten packet to Router's tun
* netfilter picks up the packet. In this example, Router is using `iptables`. A rule in the `nat` table similar to `-A POSTROUTING -d PRINTER_UNDERLAY_IP_HERE/32 -j MASQUERADE` is hit
* This ensures that "Router's SNAT address" never actually leaves Router.
* The packet leaves Router, and hits Printer
* Printer gleefully accepts the SYN from Router, and replies with an ACK
* iptables on Router de-masquerades the packet, and delivers it to the Nebula tun
* Nebula reads the packet off the tun. Because it came from the tun, and not UDP, remember that this is considered "inside" traffic and will be evaluated as "outbound" traffic by Nebula.
* Because this is inside traffic, it needs to be associated with a HostInfo before we can pass it to the firewall.
* Check that the packet is addressed to the "Router's SNAT address". If so, attempt to un-SNAT by "peeking" into conntrack
* If a Router needs to speak to _another_ Router with v4-in-v6 unsafe_routes like this, it _must_ use a distinct address from the "Router's SNAT address"
* the easy way on Linux to assure this is to set a route for the "Router SNAT address" to the Nebula tun, but not actually assign the address
* The "peek" into conntrack succeeds, and we find everything we need to rewrite the packet for transmission to Phone, as well as Phone's overlay IP, which lets us locate Phone's HostInfo
* The packet is rewritten, replacing the destination address/port to match the ones Phone expects
* checksums corrected
* Check the Nebula firewall, and see that we have a valid conntrack entry (wow!)
* we could _technically_ skip this check, but I dislike not passing all traffic we intend to accept into `firewall.Drop()`. The second conntrack locking-lookup does suck. There's room for improvement here.
* The traffic is accepted, encrypted, and sent back to Phone
* Phone gets the packet from Router, decrypts it, checks the firewall
* we have a conntrack entry for this flow, so the firewall accepts it, and delivers it to the tun
* Both sides now have a nice conntrack entry, and traffic should continue to flow uninterrupted until it expires
This conntrack entry technically creates a risk though. Let's examine that.
The Phone will accept inbound traffic matching the conntrack spec from any Router-like host authorized to speak for that UnsafeRoute, not just Router. In theory, this is desireable, and the risk is mitigated by accepting/trusting Nebula's certificate model.
There's a good chance that if you "switch" from one Router to another, you'll lose your session on your Printer-like host. Such is life under NAT!
Can the Router be exploited somehow?
* an attacker that shares a network with Printer would be able to spoof traffic as if they are Printer. This is the same risk as UnsafeNetworks today.
* an attacker on the overlay would have their traffic evaluated as inbound
* if they try to tx on the same source IP as Phone, SNAT will assign a different port
* if they try to send inbound traffic that matches the un-masqueraded traffic iptables would have delivered
* conntrack will accept the packet, but before we finish firewalling and return, is the applySnat step
* this will fail because the hostinfo that sent the packet does not contain the vpnip that is associated with the snat entry

View File

@@ -441,7 +441,7 @@ func (c *certificateV2) validate() error {
}
} else if network.Addr().Is4() {
if !hasV4Networks {
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
//return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
}
}
}

View File

@@ -2,6 +2,7 @@ package nebula
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
@@ -22,10 +23,29 @@ import (
"github.com/slackhq/nebula/firewall"
)
var ErrCannotSNAT = errors.New("cannot snat this packet")
var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host")
type FirewallInterface interface {
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
}
type snatInfo struct {
//Src is the source IP+port to write into unsafe-route-bound packet
Src netip.AddrPort
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
SrcVpnIp netip.Addr
//SnatPort is the port to rewrite into an overlay-bound packet
SnatPort uint16
}
func (s *snatInfo) Valid() bool {
if s == nil {
return false
}
return s.Src.IsValid()
}
type conn struct {
Expires time.Time // Time when this conntrack entry will expire
@@ -34,6 +54,9 @@ type conn struct {
// fields pack for free after the uint32 above
incoming bool
rulesVersion uint16
//for SNAT support
snat *snatInfo
}
// TODO: need conntrack max tracked connections handling
@@ -66,6 +89,7 @@ type Firewall struct {
defaultLocalCIDRAny bool
incomingMetrics firewallMetrics
outgoingMetrics firewallMetrics
snatAddr netip.Addr
l *logrus.Logger
}
@@ -83,6 +107,15 @@ type FirewallConntrack struct {
TimerWheel *TimerWheel[firewall.Packet]
}
func (ct *FirewallConntrack) dupeConnUnlocked(fp firewall.Packet, c *conn, timeout time.Duration) {
if _, ok := ct.Conns[fp]; !ok {
ct.TimerWheel.Advance(time.Now())
ct.TimerWheel.Add(fp, timeout)
}
ct.Conns[fp] = c
}
// FirewallTable is the entry point for a rule, the evaluation order is:
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
type FirewallTable struct {
@@ -131,7 +164,7 @@ type firewallLocalCIDR struct {
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
// The certificate provided should be the highest version loaded in memory.
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall {
//TODO: error on 0 duration
var tmin, tmax time.Duration
@@ -149,12 +182,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
tmax = defaultTimeout
}
hasV4Networks := false
routableNetworks := new(bart.Lite)
var assignedNetworks []netip.Prefix
for _, network := range c.Networks() {
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
routableNetworks.Insert(nprefix)
assignedNetworks = append(assignedNetworks, network)
hasV4Networks = hasV4Networks || network.Addr().Is4()
}
hasUnsafeNetworks := false
@@ -163,6 +198,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
hasUnsafeNetworks = true
}
if !hasUnsafeNetworks || hasV4Networks {
snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it
}
return &Firewall{
Conntrack: &FirewallConntrack{
Conns: make(map[firewall.Packet]*conn),
@@ -176,6 +215,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
routableNetworks: routableNetworks,
assignedNetworks: assignedNetworks,
hasUnsafeNetworks: hasUnsafeNetworks,
snatAddr: snatAddr,
l: l,
incomingMetrics: firewallMetrics{
@@ -191,7 +231,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
}
}
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) {
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
@@ -201,14 +241,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
panic("No certificate available to reconfigure the firewall")
}
fw := NewFirewall(
l,
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
certificate,
//TODO: max_connections
)
fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr)
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
@@ -314,6 +347,11 @@ func (f *Firewall) GetRuleHashes() string {
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
}
func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool {
// f.snatAddr is only valid if we're a snat-capable router
return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr
}
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
var table string
if inbound {
@@ -414,50 +452,204 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr {
c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way.
if c == nil {
return netip.Addr{}
}
if !c.snat.Valid() {
return netip.Addr{}
}
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2)
return c.snat.SrcVpnIp
}
func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) {
//change address
copy(data[ipOffset:], newIP.Addr().AsSlice())
recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr())
ipHeaderLen := int(data[0]&0x0F) * 4
switch fp.Protocol {
case firewall.ProtoICMP:
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would)
recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port())
case firewall.ProtoUDP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcUDPv4Checksum(data, oldIP, newIP)
case firewall.ProtoTCP:
dstport := ipHeaderLen + portOffset
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
recalcTCPv4Checksum(data, oldIP, newIP)
}
}
func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error {
oldPort := fp.RemotePort
conntrack := f.Conntrack
conntrack.Lock()
defer conntrack.Unlock()
for numPortsChecked := 0; numPortsChecked < 0x7ff; numPortsChecked++ {
_, ok := conntrack.Conns[*fp]
if !ok {
//yay, we can use this port
//track the snatted flow with the same expiration as the unsnatted version
conntrack.dupeConnUnlocked(*fp, c, f.packetTimeout(*fp))
return nil
}
//increment and retry. There's probably better strategies out there
fp.RemotePort++
if fp.RemotePort < 0x7ff {
fp.RemotePort += 0x7ff // keep it ephemeral for now
}
}
//if we made it here, we failed
fp.RemotePort = oldPort
return ErrCannotSNAT
}
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error {
if !f.snatAddr.IsValid() {
return ErrCannotSNAT
}
if c.snat.Valid() {
//old flow: make sure it came from the right place
if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) {
return ErrSNATIdentityMismatch
}
fp.RemoteAddr = f.snatAddr
fp.RemotePort = c.snat.SnatPort
} else if hostinfo.vpnAddrs[0].Is6() {
//we got a new flow
c.snat = &snatInfo{
Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort),
SrcVpnIp: hostinfo.vpnAddrs[0],
}
fp.RemoteAddr = f.snatAddr
//find a new port to use, if needed
err := f.findUsableSNATPort(fp, c)
if err != nil {
c.snat = nil
return err
}
c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort
} else {
return ErrCannotSNAT
}
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
rewritePacket(data, fp, c.snat.Src, newIP, 12, 0)
return nil
}
func (f *Firewall) identifyNetworkType(h *HostInfo, fp firewall.Packet) NetworkType {
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] == fp.RemoteAddr {
return NetworkTypeVPN
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() {
return NetworkTypeUncheckedSNATPeer
} else {
return NetworkTypeInvalidPeer
}
} else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok {
//todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case?
return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() { //todo surely I'm smart enough to avoid writing these branches twice
return NetworkTypeUncheckedSNATPeer
} else {
return NetworkTypeInvalidPeer
}
}
func (f *Firewall) allowNetworkType(nwType NetworkType) error {
switch nwType {
case NetworkTypeVPN:
return nil
case NetworkTypeInvalidPeer:
return ErrInvalidRemoteIP
case NetworkTypeVPNPeer:
//todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
return nil // nothing special, one day this may have different FW rules
case NetworkTypeUncheckedSNATPeer:
if f.snatAddr.IsValid() {
return nil //todo is this enough?
} else {
return ErrInvalidRemoteIP
}
default:
return ErrUnknownNetworkType //should never happen
}
}
func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error {
if f.routableNetworks.Contains(fp.LocalAddr) {
return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side
}
//watch out, when incoming, this function decides if we will deliver a packet locally
//when outgoing, much less important, it just decides if we're willing to tx
switch remoteNwType {
// we never want to accept unconntracked inbound traffic from these network types, but outbound is okay.
// It's the recipient's job to validate and accept or deny the packet.
case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe:
//NetworkTypeUnsafe needed here to allow inbound from an unsafe-router
if incoming {
return ErrInvalidLocalIP
}
return nil
default:
return ErrInvalidLocalIP
}
}
// Drop returns an error if the packet should be dropped, explaining why. It
// returns nil if the packet should not be dropped.
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
table := f.OutRules
if incoming {
table = f.InRules
}
snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid()
if snatmode {
//if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted,
//so it can never be in the localcache, which lacks SNAT data
//nil out the pointer to avoid ever using it
localCache = nil
}
// Check if we spoke to this tuple, if we did then allow this packet
if f.inConns(fp, h, caPool, localCache) {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return nil //packet matched the cache, we're not snatting, we can return early
}
}
c := f.inConns(fp, h, caPool, localCache)
if c != nil {
if incoming && snatmode {
return f.applySnat(pkt, &fp, c, h)
}
return nil
}
// Make sure remote address matches nebula certificate, and determine how to treat it
if h.networks == nil {
// Simple case: Certificate has one address and no unsafe networks
if h.vpnAddrs[0] != fp.RemoteAddr {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
} else {
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
if !ok {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrInvalidRemoteIP
}
switch nwType {
case NetworkTypeVPN:
break // nothing special
case NetworkTypeVPNPeer:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrPeerRejected // reject for now, one day this may have different FW rules
case NetworkTypeUnsafe:
break // nothing special, one day this may have different FW rules
default:
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return ErrUnknownNetworkType //should never happen
}
remoteNetworkType := f.identifyNetworkType(h, fp)
if err := f.allowNetworkType(remoteNetworkType); err != nil {
f.metrics(incoming).droppedRemoteAddr.Inc(1)
return err
}
// Make sure we are supposed to be handling this local ip address
if !f.routableNetworks.Contains(fp.LocalAddr) {
if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil {
f.metrics(incoming).droppedLocalAddr.Inc(1)
return ErrInvalidLocalIP
}
table := f.OutRules
if incoming {
table = f.InRules
return err
}
// We now know which firewall table to check against
@@ -467,9 +659,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
}
// We always want to conntrack since it is a faster operation
f.addConn(fp, incoming)
c = f.addConn(fp, incoming)
return nil
if incoming && remoteNetworkType == NetworkTypeUncheckedSNATPeer {
return f.applySnat(pkt, &fp, c, h)
} else {
//outgoing snat is handled before this function is called
return nil
}
}
func (f *Firewall) metrics(incoming bool) firewallMetrics {
@@ -496,12 +693,23 @@ func (f *Firewall) EmitStats() {
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
if localCache != nil {
if _, ok := localCache[fp]; ok {
return true
}
func (f *Firewall) peek(fp firewall.Packet) *conn {
conntrack := f.Conntrack
conntrack.Lock()
// Purge every time we test
ep, has := conntrack.TimerWheel.Purge()
if has {
f.evict(ep)
}
c := conntrack.Conns[fp]
conntrack.Unlock()
return c
}
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
conntrack := f.Conntrack
conntrack.Lock()
@@ -515,7 +723,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
if !ok {
conntrack.Unlock()
return false
return nil
}
if c.rulesVersion != f.rulesVersion {
@@ -538,7 +746,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
}
delete(conntrack.Conns, fp)
conntrack.Unlock()
return false
return nil
}
if f.l.Level >= logrus.DebugLevel {
@@ -568,12 +776,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
localCache[fp] = struct{}{}
}
return true
return c
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
var timeout time.Duration
c := &conn{}
switch fp.Protocol {
case firewall.ProtoTCP:
@@ -583,7 +790,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
default:
timeout = f.DefaultTimeout
}
return timeout
}
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
c := &conn{}
timeout := f.packetTimeout(fp)
conntrack := f.Conntrack
conntrack.Lock()
if _, ok := conntrack.Conns[fp]; !ok {
@@ -597,7 +810,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
c.rulesVersion = f.rulesVersion
c.Expires = time.Now().Add(timeout)
conntrack.Conns[fp] = c
conntrack.Unlock()
return c
}
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -682,6 +897,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
var port int32
if p.Protocol == firewall.ProtoICMP {
// port numbers are re-used for connection tracking and SNAT,
// but we don't want to actually filter on them for ICMP
// ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6
return fp[firewall.PortAny].match(p, c, caPool)
}
if p.Fragment {
port = firewall.PortFragment
} else if incoming {

View File

@@ -31,6 +31,10 @@ type Packet struct {
Fragment bool
}
func (fp *Packet) IsIPv4() bool {
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
}
func (fp *Packet) Copy() *Packet {
return &Packet{
LocalAddr: fp.LocalAddr,

View File

@@ -21,7 +21,7 @@ import (
func TestNewFirewall(t *testing.T) {
l := test.NewLogger()
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
conntrack := fw.Conntrack
assert.NotNil(t, conntrack)
assert.NotNil(t, conntrack.Conns)
@@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) {
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{})
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
}
@@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) {
l.SetOutput(ob)
c := &dummyCert{}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
assert.NotNil(t, fw.InRules)
assert.NotNil(t, fw.OutRules)
@@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) {
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
//no matter what port is given for icmp, it should end up as "any"
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
require.NoError(t, err)
@@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
anyIp6, err := netip.ParsePrefix("::/0")
require.NoError(t, err)
@@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) {
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
assert.False(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
}
@@ -208,49 +208,49 @@ func TestFirewall_Drop(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropV6(t *testing.T) {
@@ -287,49 +287,49 @@ func TestFirewall_DropV6(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
// test remote mismatch
oldRemote := p.RemoteAddr
p.RemoteAddr = netip.MustParseAddr("fd12::56")
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
p.RemoteAddr = oldRemote
// ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func BenchmarkFirewallTable_match(b *testing.B) {
@@ -532,15 +532,15 @@ func TestFirewall_Drop2(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// h1/c1 lacks the proper groups
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_Drop3(t *testing.T) {
@@ -612,24 +612,24 @@ func TestFirewall_Drop3(t *testing.T) {
}
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
cp := cert.NewCAPool()
// c1 should pass because host match
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
// c2 should pass because ca sha match
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
// c3 should fail because no match
resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
}
func TestFirewall_Drop3V6(t *testing.T) {
@@ -664,10 +664,10 @@ func TestFirewall_Drop3V6(t *testing.T) {
h.buildNetworks(myVpnNetworksTable, c.Certificate)
// Test a remote address match
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
cp := cert.NewCAPool()
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
}
func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -704,35 +704,35 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
}
h.buildNetworks(myVpnNetworksTable, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
cp := cert.NewCAPool()
// Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
// Allow outbound because conntrack
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1
// Drop outbound because conntrack doesn't match new ruleset
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
}
func TestFirewall_ICMPPortBehavior(t *testing.T) {
@@ -771,19 +771,19 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
}
t.Run("ICMP allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports", func(t *testing.T) {
@@ -791,29 +791,29 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
})
t.Run("Any proto, some ports allowed", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, still blocked", func(t *testing.T) {
p := templ.Copy()
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero ports, still blocked", func(t *testing.T) {
@@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
@@ -834,16 +834,16 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 80
p.RemotePort = 80
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
//now also allow outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
t.Run("Any proto, any port", func(t *testing.T) {
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
t.Run("zero ports, allowed", func(t *testing.T) {
resetConntrack(fw)
@@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0
p.RemotePort = 0
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
})
t.Run("nonzero ports, allowed", func(t *testing.T) {
@@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
p.LocalPort = 0xabcd
p.RemotePort = 0x1234
// Drop outbound
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound
resetConntrack(fw)
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
//now also allow outbound
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
//different ID is blocked
p.RemotePort++
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
})
})
@@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
}
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
cp := cert.NewCAPool()
@@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
Protocol: firewall.ProtoUDP,
Fragment: false,
}
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
}
func BenchmarkLookup(b *testing.B) {
@@ -1047,53 +1047,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf := config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups
conf = config.NewC(l)
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf)
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
}
@@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
t.Helper()
cp := cert.NewCAPool()
resetConntrack(fw)
err := fw.Drop(c.p, true, c.h, cp, nil)
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
if c.err == nil {
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
} else {
@@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
}
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
c1 := dummyCert{
name: "host1",
networks: theirPrefixes,
@@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
h.vpnAddrs[i] = theirPrefixes[i].Addr()
}
h.buildNetworks(setup.myVpnNetworksTable, &c1)
return &h
}
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
h := buildHostinfo(setup, theirPrefixes...)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: theirPrefixes[0].Addr(),
@@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
Fragment: false,
}
return testcase{
h: &h,
h: h,
p: p,
c: &c1,
c: h.ConnectionState.peerCert.Certificate,
err: err,
}
}
@@ -1397,12 +1402,25 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
return newSetupFromCert(t, l, c)
}
func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
c := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
out := newSetupFromCert(t, l, c)
out.fw.snatAddr = snatAddr
return out
}
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
myVpnNetworksTable := new(bart.Lite)
for _, prefix := range c.Networks() {
myVpnNetworksTable.Insert(prefix)
}
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
return testsetup{
@@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) {
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
fw.Conntrack.Unlock()
}
func TestFirewall_SNAT(t *testing.T) {
t.Parallel()
l := test.NewLogger()
ob := &bytes.Buffer{}
l.SetOutput(ob)
cp := cert.NewCAPool()
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
MyCert := dummyCert{
name: "me",
networks: []netip.Prefix{myPrefix},
groups: []string{"default-group"},
issuer: "signer-shasum",
}
theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
snatAddr := netip.MustParseAddr("169.254.55.96")
t.Run("allow inbound all matching", func(t *testing.T) {
t.Parallel()
myCert := MyCert.Copy()
setup := newSnatSetup(t, l, myPrefix, snatAddr)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr)
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
resetConntrack(setup.fw)
h := buildHostinfo(setup, theirPrefix)
p := firewall.Packet{
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
RemoteAddr: h.vpnAddrs[0],
LocalPort: 10,
RemotePort: 90,
Protocol: firewall.ProtoUDP,
Fragment: false,
}
require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
})
//t.Run("allow inbound unsafe route", func(t *testing.T) {
// t.Parallel()
// unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
// c := dummyCert{
// name: "me",
// networks: []netip.Prefix{myPrefix},
// unsafeNetworks: []netip.Prefix{unsafePrefix},
// groups: []string{"default-group"},
// issuer: "signer-shasum",
// }
// unsafeSetup := newSetupFromCert(t, l, c)
// tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
// tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
// tc.err = ErrNoMatchingRule
// tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
// require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
// tc.err = nil
// tc.Test(t, unsafeSetup.fw) //should pass
//})
}

View File

@@ -224,6 +224,9 @@ const (
NetworkTypeVPNPeer
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
NetworkTypeUnsafe
// NetworkTypeUncheckedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN
NetworkTypeUncheckedSNATPeer
NetworkTypeInvalidPeer
)
type HostInfo struct {
@@ -277,6 +280,15 @@ type HostInfo struct {
lastUsed time.Time
}
func (i *HostInfo) HasOnlyV6Addresses() bool {
for _, vpnIp := range i.vpnAddrs {
if !vpnIp.Is6() {
return false
}
}
return true
}
type ViaSender struct {
UdpAddr netip.AddrPort
relayHI *HostInfo // relayHI is the host info object of the relay

View File

@@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
hostinfo, ready := f.getHostinfo(packet, fwPacket)
if hostinfo == nil {
f.rejectInside(packet, out, q)
@@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason == nil {
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
} else {
f.rejectInside(packet, out, q)
if f.l.Level >= logrus.DebugLevel {
@@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
}
}
func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) {
if f.firewall.ShouldUnSNAT(fwPacket) {
//unsnat packet re-writing also happens here, would be nice to not,
//but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks
destVpnAddr := f.firewall.unSnat(packet, fwPacket)
if destVpnAddr.IsValid() {
//because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required.
return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
} else {
return nil, false
}
} else { //if we didn't need to unsnat
return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
}
}
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
if !f.firewall.InSendReject {
return
@@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
}
// check if packet is in outbound fw rules
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil)
if dropReason != nil {
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("fwPacket", fp).

View File

@@ -56,6 +56,7 @@ type Interface struct {
inside overlay.Device
pki *PKI
firewall *Firewall
snatAddr netip.Addr
connectionManager *connectionManager
handshakeManager *HandshakeManager
serveDns bool
@@ -339,7 +340,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return

View File

@@ -66,7 +66,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun!
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
}
@@ -135,7 +136,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
deviceFactory = overlay.NewDeviceFromConfig
}
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
cs := pki.getCertState()
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
}

View File

@@ -509,7 +509,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
return false
}
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
if dropReason != nil {
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
// This gives us a buffer to build the reject packet in

View File

@@ -22,22 +22,22 @@ func (e *NameError) Error() string {
}
// TODO: We may be able to remove routines
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
switch {
case c.GetBool("tun.disabled", false):
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
return tun, nil
default:
return newTun(c, l, vpnNetworks, routines > 1)
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
}
}
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks)
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
}
}

View File

@@ -26,7 +26,7 @@ type tun struct {
l *logrus.Logger
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in Android")
}

View File

@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
Lifetime addrLifetime
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
name := c.GetString("tun.dev", "")
ifIndex := -1
if name != "" && name != "utun" {
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
return
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
}

View File

@@ -199,11 +199,11 @@ func (t *tun) Close() error {
return nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open existing tun device
var fd int
var err error

View File

@@ -28,11 +28,11 @@ type tun struct {
l *logrus.Logger
}
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
return nil, fmt.Errorf("newTun not supported in iOS")
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
t := &tun{
vpnNetworks: vpnNetworks,

View File

@@ -26,14 +26,15 @@ import (
type tun struct {
io.ReadWriteCloser
fd int
Device string
vpnNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
fd int
Device string
vpnNetworks []netip.Prefix
unsafeNetworks []netip.Prefix
MaxMTU int
DefaultMTU int
TXQueueLen int
deviceIndex int
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
@@ -71,10 +72,10 @@ type ifreqQLEN struct {
pad [8]byte
}
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
return t, nil
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
if err != nil {
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
@@ -123,7 +124,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
name := strings.Trim(string(req.Name[:]), "\x00")
file := os.NewFile(uintptr(fd), "/dev/net/tun")
t, err := newTunGeneric(c, l, file, vpnNetworks)
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
if err != nil {
return nil, err
}
@@ -133,11 +134,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
return t, nil
}
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
t := &tun{
ReadWriteCloser: file,
fd: int(file.Fd()),
vpnNetworks: vpnNetworks,
unsafeNetworks: unsafeNetworks,
TXQueueLen: c.GetInt("tun.tx_queue", 500),
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
@@ -427,6 +429,27 @@ func (t *tun) setMTU() {
}
}
func (t *tun) setSnatRoute() error {
snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall?
dr := &net.IPNet{
IP: snataddr.Masked().Addr().AsSlice(),
Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()),
}
nr := netlink.Route{
LinkIndex: t.deviceIndex,
Dst: dr,
//todo do we need these other options?
//MTU: t.DefaultMTU,
//AdvMSS: t.advMSS(Route{}),
Scope: unix.RT_SCOPE_LINK,
//Protocol: unix.RTPROT_KERNEL,
Table: unix.RT_TABLE_MAIN,
Type: unix.RTN_UNICAST,
}
return netlink.RouteReplace(&nr)
}
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
dr := &net.IPNet{
IP: cidr.Masked().Addr().AsSlice(),
@@ -503,6 +526,18 @@ func (t *tun) addRoutes(logErrors bool) error {
}
}
onlyV6Addresses := false
for _, n := range t.vpnNetworks {
if n.Addr().Is6() {
onlyV6Addresses = true
break
}
}
if len(t.unsafeNetworks) != 0 && onlyV6Addresses {
return t.setSnatRoute()
}
return nil
}

View File

@@ -70,11 +70,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")

View File

@@ -63,11 +63,11 @@ type tun struct {
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
// Try to open tun device
var err error
deviceName := c.GetString("tun.dev", "")

View File

@@ -28,7 +28,7 @@ type TestTun struct {
TxPackets chan []byte // Packets transmitted outside by nebula
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
if err != nil {
return nil, err
@@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
}, nil
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
return nil, fmt.Errorf("newTunFromFd not supported")
}

View File

@@ -38,11 +38,11 @@ type winTun struct {
tun *wintun.NativeTun
}
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
}
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
err := checkWinTunExists()
if err != nil {
return nil, fmt.Errorf("can not load the wintun driver: %w", err)

View File

@@ -9,7 +9,7 @@ import (
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
return NewUserDevice(vpnNetworks)
}

91
snat.go Normal file
View File

@@ -0,0 +1,91 @@
package nebula
import (
"encoding/binary"
"net/netip"
)
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
oldChecksum := binary.BigEndian.Uint16(data[10:12])
//because of how checksums work, we can re-use this function
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
binary.BigEndian.PutUint16(data[10:12], checksum)
}
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old IP (as two 16-bit words)
checksum += uint32(^uint16(oldIP >> 16))
checksum += uint32(^uint16(oldIP & 0xFFFF))
// Subtract old port
checksum += uint32(^oldSrcPort)
// Add new IP (as two 16-bit words)
checksum += uint32(newIP >> 16)
checksum += uint32(newIP & 0xFFFF)
// Add new port
checksum += uint32(newSrcPort)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 6
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
const offsetInsideHeader = 16
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
}
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
// Start with inverted checksum
checksum := uint32(^oldChecksum)
// Subtract old stuff
checksum += uint32(^oldCode)
checksum += uint32(^oldID)
// Add new stuff
checksum += uint32(newCode)
checksum += uint32(newID)
// Fold carries
for checksum > 0xFFFF {
checksum = (checksum & 0xFFFF) + (checksum >> 16)
}
// Return ones' complement
return ^uint16(checksum)
}
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
const offsetInsideHeader = 2
ipHeaderOffset := int(data[0]&0x0F) * 4
offset := ipHeaderOffset + offsetInsideHeader
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
}