mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
use in-Nebula SNAT to send IPv4 UnsafeNetworks traffic over an IPv6 overlay
This commit is contained in:
102
SNAT.md
Normal file
102
SNAT.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Don't merge me
|
||||
|
||||
# Accessing IPv4-only UnsafeNetworks via an IPv6-only overlay
|
||||
|
||||
## Background
|
||||
Nebula is an VPN-like connectivity solution. It provides what we call an "overlay network",
|
||||
an additional IP-addressed network on top of one-or-more traditional "underlay networks".
|
||||
As long as two devices with Nebula certificates (credentials, essentially signed keypairs with metadata) can find each other
|
||||
and exchange traffic via a common underlay network (often this is the Internet),
|
||||
they will also be able to exchange traffic securely via a point-to-point, encrypted, authenticated tunnel.
|
||||
|
||||
Typically, all Nebula traffic is strongly associated with the Nebula certificate of the sender
|
||||
(that is, the source IP of all packets matches the IP listed in the sender's certificate).
|
||||
However, it is useful to be able to bend this rule. That is why there is another field in the Nebula certificate, named UnsafeNetworks,
|
||||
which lists the network prefixes that the host bearing this certificate is allowed to "speak for".
|
||||
|
||||
## Problem Statement
|
||||
We want IPv6-only overlay networks to be able to carry IPv4 traffic to reach off-overlay hosts via UnsafeNetworks
|
||||
|
||||
### Scenario
|
||||
|
||||
To illustrate this scenario, we will define 3 hosts:
|
||||
* a Phone, running Nebula, assigned the overlay IP fd01::AA/64. It has an undefined underlay, but we assert that it always has working IPv4 OR IPv6 connectivity to Router.
|
||||
* a Router, running Nebula, assigned the overlay IP fd01::55/64. It has a stable underlay link that Phone can always reach.
|
||||
* a Printer, which cannot run Nebula, and is only capable of IPv4 communication. It has a direct link to Router, but the Phone cannot reach it directly.
|
||||
|
||||
You, the User, wish to use your Phone to print out something on the Printer while you're away from home. How can we make this possible with your IPv6-only Nebula overlay?
|
||||
In particular, your Phone may connect to any cellular or public WiFi network, and we cannot control the IP address it will be assigned. If you MUST print, an IP conflict is not acceptable.
|
||||
Therefore, we cannot simply dismiss this problem by suggesting that you assign a small IPv4 network within your overlay. Sure, it probably works, and in this toy scenario, the odds of a conflict are pretty small. But it scales very poorly. What if a whole company needs to use this printer (or perhaps a less contrived need?)
|
||||
We can do better.
|
||||
|
||||
## Solution
|
||||
|
||||
* Even though Phone and Router lack IPv4 assignments, we can still put V4 addresses on their tun devices.
|
||||
* Each overlay host who wishes to use this feature shall select (or configure?) an assignment within 169.254.0.0/16, the IPv4 link-local range
|
||||
* this is a pretty small space, but it confines the region of IP conflict to a much smaller domain. And, because overlay hosts will never dial one another with this address, cryptographic verification of it via the certificate is less important.
|
||||
* On Phone, Nebula will configure an unsafe_route to the Printer using this address. Because it is a device route, we do not need to tell the operating system the address of the next hop (no `via`)
|
||||
* On Router, Nebula will use this address to masquerade packets from Phone. You'll see!
|
||||
* Let's walk through setting up a TCP session between Phone and Printer in this scheme:
|
||||
* Phone sends SYN to the printer's underlay IPv4 address
|
||||
* This packet lands on Phone's Nebula tun device
|
||||
* Nebula locates Router as the destination for this packet, as defined in `tun.unsafe_routes`
|
||||
* Nebula checks the packet against the outbound chain:
|
||||
* the destination IP of Printer is listed in Router's UnsafeNetworks, so that check will pass
|
||||
* Phone's source IP is not listed in any certificate, but because the destination address is of `NetworkTypeUnsafe` and this is an outgoing flow, we keep going
|
||||
* Actual outbound firewall rules get checked, assume they pass
|
||||
* conntrack entry created to permit replies
|
||||
* Phone encrypts the packet and sends it to Router
|
||||
* Router gets the packet from Phone, and decrypts it. It is passed to the Nebula firewall for evaluation:
|
||||
* `firewall.Drop()` on the Router's Nebula inbound rules:
|
||||
* Because Router is configured to allow SNAT, and this packet is an IPv4 packet from a IPv6-only host, the firewall module enters "snat mode" (`TODO explain?`)
|
||||
* This is a new flow, so the conntrack lookup for it fails
|
||||
* `firewall.identifyNetworkType()`
|
||||
* identify what "kind" of remote address (this is the inbound firewall, so the remote address is the packet's src IP) we've been given
|
||||
* `NetworkTypeVPN`, for example is a remote address that matches the peer's certificate
|
||||
* In this case, because the traffic is IPv4 traffic flowing from an IPv6-only host, and we've opted into supporting SNAT, this traffic is marked as `NetworkTypeUncheckedSNATPeer`
|
||||
* `firewall.allowNetworkType()` will allow `NetworkTypeUncheckedSNATPeer` traffic to proceed because we have opted into SNAT
|
||||
* `firewall.willingToHandleLocalAddr()` now needs to check if we're willing to accept the destination address
|
||||
* Because this traffic is addressed to a destination listed in our UnsafeNetworks, it's considered "routable" and passes this check
|
||||
* Nebula's firewall rules are evaluated as normal. In particular, the `cidr` parameter will be checked against the IPv4 address, NOT the IPv6 address in the Phone's certificate
|
||||
* @Nate I think this is "correct", but could be a source of footgun
|
||||
* Let's assume the Nebula rules accept the traffic
|
||||
* We create a conntrack entry for this flow
|
||||
* We do not want to transmit with the IPv4 source IP we got from Phone. We don't want the Phone's IP assignments (in this scheme) to enter the network-space on Router at all.
|
||||
* To this end, we rewrite the source address (and port, if needed) to our own pre-selected IPv4 link-local address. This address will never actually leave the machine, but we need it so return traffic can be routed back to the nebula tun on Router
|
||||
* Replace source IP with "Router's SNAT address"
|
||||
* Look in our conntrack table, and ensure we do not already have a flow that matches this srcip/srcport/dstip/dstport/proto tuple
|
||||
* if we do, increment srcport until we find an empty slot. Only use ephemeral ports. This gives 0x7ff flows per dstip/dstport/proto tuple, which ought to be plenty.
|
||||
* Record the original srcip/srcport as part of the conntrack data for later
|
||||
* Fix checksums
|
||||
* Nebula writes the rewritten packet to Router's tun
|
||||
* netfilter picks up the packet. In this example, Router is using `iptables`. A rule in the `nat` table similar to `-A POSTROUTING -d PRINTER_UNDERLAY_IP_HERE/32 -j MASQUERADE` is hit
|
||||
* This ensures that "Router's SNAT address" never actually leaves Router.
|
||||
* The packet leaves Router, and hits Printer
|
||||
* Printer gleefully accepts the SYN from Router, and replies with an ACK
|
||||
* iptables on Router de-masquerades the packet, and delivers it to the Nebula tun
|
||||
* Nebula reads the packet off the tun. Because it came from the tun, and not UDP, remember that this is considered "inside" traffic and will be evaluated as "outbound" traffic by Nebula.
|
||||
* Because this is inside traffic, it needs to be associated with a HostInfo before we can pass it to the firewall.
|
||||
* Check that the packet is addressed to the "Router's SNAT address". If so, attempt to un-SNAT by "peeking" into conntrack
|
||||
* If a Router needs to speak to _another_ Router with v4-in-v6 unsafe_routes like this, it _must_ use a distinct address from the "Router's SNAT address"
|
||||
* the easy way on Linux to assure this is to set a route for the "Router SNAT address" to the Nebula tun, but not actually assign the address
|
||||
* The "peek" into conntrack succeeds, and we find everything we need to rewrite the packet for transmission to Phone, as well as Phone's overlay IP, which lets us locate Phone's HostInfo
|
||||
* The packet is rewritten, replacing the destination address/port to match the ones Phone expects
|
||||
* checksums corrected
|
||||
* Check the Nebula firewall, and see that we have a valid conntrack entry (wow!)
|
||||
* we could _technically_ skip this check, but I dislike not passing all traffic we intend to accept into `firewall.Drop()`. The second conntrack locking-lookup does suck. There's room for improvement here.
|
||||
* The traffic is accepted, encrypted, and sent back to Phone
|
||||
* Phone gets the packet from Router, decrypts it, checks the firewall
|
||||
* we have a conntrack entry for this flow, so the firewall accepts it, and delivers it to the tun
|
||||
* Both sides now have a nice conntrack entry, and traffic should continue to flow uninterrupted until it expires
|
||||
|
||||
This conntrack entry technically creates a risk though. Let's examine that.
|
||||
The Phone will accept inbound traffic matching the conntrack spec from any Router-like host authorized to speak for that UnsafeRoute, not just Router. In theory, this is desireable, and the risk is mitigated by accepting/trusting Nebula's certificate model.
|
||||
There's a good chance that if you "switch" from one Router to another, you'll lose your session on your Printer-like host. Such is life under NAT!
|
||||
|
||||
Can the Router be exploited somehow?
|
||||
* an attacker that shares a network with Printer would be able to spoof traffic as if they are Printer. This is the same risk as UnsafeNetworks today.
|
||||
* an attacker on the overlay would have their traffic evaluated as inbound
|
||||
* if they try to tx on the same source IP as Phone, SNAT will assign a different port
|
||||
* if they try to send inbound traffic that matches the un-masqueraded traffic iptables would have delivered
|
||||
* conntrack will accept the packet, but before we finish firewalling and return, is the applySnat step
|
||||
* this will fail because the hostinfo that sent the packet does not contain the vpnip that is associated with the snat entry
|
||||
@@ -441,7 +441,7 @@ func (c *certificateV2) validate() error {
|
||||
}
|
||||
} else if network.Addr().Is4() {
|
||||
if !hasV4Networks {
|
||||
return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
//return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
332
firewall.go
332
firewall.go
@@ -2,6 +2,7 @@ package nebula
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -22,10 +23,29 @@ import (
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
var ErrCannotSNAT = errors.New("cannot snat this packet")
|
||||
var ErrSNATIdentityMismatch = errors.New("refusing to SNAT for mismatched host")
|
||||
|
||||
type FirewallInterface interface {
|
||||
AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, cidr, localCidr string, caName string, caSha string) error
|
||||
}
|
||||
|
||||
type snatInfo struct {
|
||||
//Src is the source IP+port to write into unsafe-route-bound packet
|
||||
Src netip.AddrPort
|
||||
//SrcVpnIp is the overlay IP associated with this flow. It's needed to associate reply traffic so we can get it back to the right host.
|
||||
SrcVpnIp netip.Addr
|
||||
//SnatPort is the port to rewrite into an overlay-bound packet
|
||||
SnatPort uint16
|
||||
}
|
||||
|
||||
func (s *snatInfo) Valid() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return s.Src.IsValid()
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
Expires time.Time // Time when this conntrack entry will expire
|
||||
|
||||
@@ -34,6 +54,9 @@ type conn struct {
|
||||
// fields pack for free after the uint32 above
|
||||
incoming bool
|
||||
rulesVersion uint16
|
||||
|
||||
//for SNAT support
|
||||
snat *snatInfo
|
||||
}
|
||||
|
||||
// TODO: need conntrack max tracked connections handling
|
||||
@@ -66,6 +89,7 @@ type Firewall struct {
|
||||
defaultLocalCIDRAny bool
|
||||
incomingMetrics firewallMetrics
|
||||
outgoingMetrics firewallMetrics
|
||||
snatAddr netip.Addr
|
||||
|
||||
l *logrus.Logger
|
||||
}
|
||||
@@ -83,6 +107,15 @@ type FirewallConntrack struct {
|
||||
TimerWheel *TimerWheel[firewall.Packet]
|
||||
}
|
||||
|
||||
func (ct *FirewallConntrack) dupeConnUnlocked(fp firewall.Packet, c *conn, timeout time.Duration) {
|
||||
if _, ok := ct.Conns[fp]; !ok {
|
||||
ct.TimerWheel.Advance(time.Now())
|
||||
ct.TimerWheel.Add(fp, timeout)
|
||||
}
|
||||
|
||||
ct.Conns[fp] = c
|
||||
}
|
||||
|
||||
// FirewallTable is the entry point for a rule, the evaluation order is:
|
||||
// Proto AND port AND (CA SHA or CA name) AND local CIDR AND (group OR groups OR name OR remote CIDR)
|
||||
type FirewallTable struct {
|
||||
@@ -131,7 +164,7 @@ type firewallLocalCIDR struct {
|
||||
|
||||
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
||||
// The certificate provided should be the highest version loaded in memory.
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall {
|
||||
func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate, snatAddr netip.Addr) *Firewall {
|
||||
//TODO: error on 0 duration
|
||||
var tmin, tmax time.Duration
|
||||
|
||||
@@ -149,12 +182,14 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
tmax = defaultTimeout
|
||||
}
|
||||
|
||||
hasV4Networks := false
|
||||
routableNetworks := new(bart.Lite)
|
||||
var assignedNetworks []netip.Prefix
|
||||
for _, network := range c.Networks() {
|
||||
nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen())
|
||||
routableNetworks.Insert(nprefix)
|
||||
assignedNetworks = append(assignedNetworks, network)
|
||||
hasV4Networks = hasV4Networks || network.Addr().Is4()
|
||||
}
|
||||
|
||||
hasUnsafeNetworks := false
|
||||
@@ -163,6 +198,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
hasUnsafeNetworks = true
|
||||
}
|
||||
|
||||
if !hasUnsafeNetworks || hasV4Networks {
|
||||
snatAddr = netip.Addr{} //disable using the special snat address if it doesn't make sense to use it
|
||||
}
|
||||
|
||||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
@@ -176,6 +215,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
routableNetworks: routableNetworks,
|
||||
assignedNetworks: assignedNetworks,
|
||||
hasUnsafeNetworks: hasUnsafeNetworks,
|
||||
snatAddr: snatAddr,
|
||||
l: l,
|
||||
|
||||
incomingMetrics: firewallMetrics{
|
||||
@@ -191,7 +231,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
||||
}
|
||||
}
|
||||
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) {
|
||||
func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C, snatAddr netip.Addr) (*Firewall, error) {
|
||||
certificate := cs.getCertificate(cert.Version2)
|
||||
if certificate == nil {
|
||||
certificate = cs.getCertificate(cert.Version1)
|
||||
@@ -201,14 +241,7 @@ func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firew
|
||||
panic("No certificate available to reconfigure the firewall")
|
||||
}
|
||||
|
||||
fw := NewFirewall(
|
||||
l,
|
||||
c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
|
||||
c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
|
||||
c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
|
||||
certificate,
|
||||
//TODO: max_connections
|
||||
)
|
||||
fw := NewFirewall(l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), certificate, snatAddr)
|
||||
|
||||
fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false)
|
||||
|
||||
@@ -314,6 +347,11 @@ func (f *Firewall) GetRuleHashes() string {
|
||||
return "SHA:" + f.GetRuleHash() + ",FNV:" + strconv.FormatUint(uint64(f.GetRuleHashFNV()), 10)
|
||||
}
|
||||
|
||||
func (f *Firewall) ShouldUnSNAT(fp *firewall.Packet) bool {
|
||||
// f.snatAddr is only valid if we're a snat-capable router
|
||||
return f.snatAddr.IsValid() && fp.RemoteAddr == f.snatAddr
|
||||
}
|
||||
|
||||
func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw FirewallInterface) error {
|
||||
var table string
|
||||
if inbound {
|
||||
@@ -414,50 +452,204 @@ var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate
|
||||
var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses")
|
||||
var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
|
||||
|
||||
func (f *Firewall) unSnat(data []byte, fp *firewall.Packet) netip.Addr {
|
||||
c := f.peek(*fp) //unfortunately this needs to lock. Surely there's a better way.
|
||||
if c == nil {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if !c.snat.Valid() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
oldIP := netip.AddrPortFrom(f.snatAddr, fp.RemotePort)
|
||||
rewritePacket(data, fp, oldIP, c.snat.Src, 16, 2)
|
||||
return c.snat.SrcVpnIp
|
||||
}
|
||||
|
||||
func rewritePacket(data []byte, fp *firewall.Packet, oldIP netip.AddrPort, newIP netip.AddrPort, ipOffset int, portOffset int) {
|
||||
//change address
|
||||
copy(data[ipOffset:], newIP.Addr().AsSlice())
|
||||
recalcIPv4Checksum(data, oldIP.Addr(), newIP.Addr())
|
||||
ipHeaderLen := int(data[0]&0x0F) * 4
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoICMP:
|
||||
binary.BigEndian.PutUint16(data[ipHeaderLen+4:ipHeaderLen+6], newIP.Port()) //we use the ID field as a "port" for ICMP
|
||||
icmpCode := uint16(data[ipHeaderLen+1]) //todo not snatting on code yet (but Linux would)
|
||||
recalcICMPv4Checksum(data, icmpCode, icmpCode, oldIP.Port(), newIP.Port())
|
||||
case firewall.ProtoUDP:
|
||||
dstport := ipHeaderLen + portOffset
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
|
||||
recalcUDPv4Checksum(data, oldIP, newIP)
|
||||
case firewall.ProtoTCP:
|
||||
dstport := ipHeaderLen + portOffset
|
||||
binary.BigEndian.PutUint16(data[dstport:dstport+2], newIP.Port())
|
||||
recalcTCPv4Checksum(data, oldIP, newIP)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) findUsableSNATPort(fp *firewall.Packet, c *conn) error {
|
||||
oldPort := fp.RemotePort
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
defer conntrack.Unlock()
|
||||
for numPortsChecked := 0; numPortsChecked < 0x7ff; numPortsChecked++ {
|
||||
_, ok := conntrack.Conns[*fp]
|
||||
if !ok {
|
||||
//yay, we can use this port
|
||||
//track the snatted flow with the same expiration as the unsnatted version
|
||||
conntrack.dupeConnUnlocked(*fp, c, f.packetTimeout(*fp))
|
||||
return nil
|
||||
}
|
||||
//increment and retry. There's probably better strategies out there
|
||||
fp.RemotePort++
|
||||
if fp.RemotePort < 0x7ff {
|
||||
fp.RemotePort += 0x7ff // keep it ephemeral for now
|
||||
}
|
||||
}
|
||||
|
||||
//if we made it here, we failed
|
||||
fp.RemotePort = oldPort
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
|
||||
func (f *Firewall) applySnat(data []byte, fp *firewall.Packet, c *conn, hostinfo *HostInfo) error {
|
||||
if !f.snatAddr.IsValid() {
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
if c.snat.Valid() {
|
||||
//old flow: make sure it came from the right place
|
||||
if !slices.Contains(hostinfo.vpnAddrs, c.snat.SrcVpnIp) {
|
||||
return ErrSNATIdentityMismatch
|
||||
}
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
fp.RemotePort = c.snat.SnatPort
|
||||
} else if hostinfo.vpnAddrs[0].Is6() {
|
||||
//we got a new flow
|
||||
c.snat = &snatInfo{
|
||||
Src: netip.AddrPortFrom(fp.RemoteAddr, fp.RemotePort),
|
||||
SrcVpnIp: hostinfo.vpnAddrs[0],
|
||||
}
|
||||
fp.RemoteAddr = f.snatAddr
|
||||
//find a new port to use, if needed
|
||||
err := f.findUsableSNATPort(fp, c)
|
||||
if err != nil {
|
||||
c.snat = nil
|
||||
return err
|
||||
}
|
||||
c.snat.SnatPort = fp.RemotePort //may have been updated inside f.findUsableSNATPort
|
||||
} else {
|
||||
return ErrCannotSNAT
|
||||
}
|
||||
|
||||
newIP := netip.AddrPortFrom(f.snatAddr, c.snat.SnatPort)
|
||||
rewritePacket(data, fp, c.snat.Src, newIP, 12, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) identifyNetworkType(h *HostInfo, fp firewall.Packet) NetworkType {
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] == fp.RemoteAddr {
|
||||
return NetworkTypeVPN
|
||||
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() {
|
||||
return NetworkTypeUncheckedSNATPeer
|
||||
} else {
|
||||
return NetworkTypeInvalidPeer
|
||||
}
|
||||
} else if nwType, ok := h.networks.Lookup(fp.RemoteAddr); ok {
|
||||
//todo check for if fp.RemoteAddr is our f.snatAddr here too? Does that need a special case?
|
||||
return nwType //will return NetworkTypeVPN or NetworkTypeUnsafe
|
||||
} else if fp.IsIPv4() && h.HasOnlyV6Addresses() { //todo surely I'm smart enough to avoid writing these branches twice
|
||||
return NetworkTypeUncheckedSNATPeer
|
||||
} else {
|
||||
return NetworkTypeInvalidPeer
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) allowNetworkType(nwType NetworkType) error {
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
return nil
|
||||
case NetworkTypeInvalidPeer:
|
||||
return ErrInvalidRemoteIP
|
||||
case NetworkTypeVPNPeer:
|
||||
//todo we might need a specialSnatMode case in here to handle routers with v4 addresses when we don't also have a v4 address?
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
return nil // nothing special, one day this may have different FW rules
|
||||
case NetworkTypeUncheckedSNATPeer:
|
||||
if f.snatAddr.IsValid() {
|
||||
return nil //todo is this enough?
|
||||
} else {
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
default:
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) willingToHandleLocalAddr(incoming bool, fp firewall.Packet, remoteNwType NetworkType) error {
|
||||
if f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
return nil //easy, this should handle NetworkTypeVPN in all cases, and NetworkTypeUnsafe on the router side
|
||||
}
|
||||
|
||||
//watch out, when incoming, this function decides if we will deliver a packet locally
|
||||
//when outgoing, much less important, it just decides if we're willing to tx
|
||||
switch remoteNwType {
|
||||
// we never want to accept unconntracked inbound traffic from these network types, but outbound is okay.
|
||||
// It's the recipient's job to validate and accept or deny the packet.
|
||||
case NetworkTypeUncheckedSNATPeer, NetworkTypeUnsafe:
|
||||
//NetworkTypeUnsafe needed here to allow inbound from an unsafe-router
|
||||
if incoming {
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
}
|
||||
|
||||
// Drop returns an error if the packet should be dropped, explaining why. It
|
||||
// returns nil if the packet should not be dropped.
|
||||
func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
func (f *Firewall) Drop(fp firewall.Packet, pkt []byte, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
}
|
||||
|
||||
snatmode := fp.IsIPv4() && h.HasOnlyV6Addresses() && f.snatAddr.IsValid()
|
||||
if snatmode {
|
||||
//if this is an IPv4 packet from a V6 only host, and we're configured to snat that kind of traffic, it must be snatted,
|
||||
//so it can never be in the localcache, which lacks SNAT data
|
||||
//nil out the pointer to avoid ever using it
|
||||
localCache = nil
|
||||
}
|
||||
|
||||
// Check if we spoke to this tuple, if we did then allow this packet
|
||||
if f.inConns(fp, h, caPool, localCache) {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return nil //packet matched the cache, we're not snatting, we can return early
|
||||
}
|
||||
}
|
||||
c := f.inConns(fp, h, caPool, localCache)
|
||||
if c != nil {
|
||||
if incoming && snatmode {
|
||||
return f.applySnat(pkt, &fp, c, h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure remote address matches nebula certificate, and determine how to treat it
|
||||
if h.networks == nil {
|
||||
// Simple case: Certificate has one address and no unsafe networks
|
||||
if h.vpnAddrs[0] != fp.RemoteAddr {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
} else {
|
||||
nwType, ok := h.networks.Lookup(fp.RemoteAddr)
|
||||
if !ok {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrInvalidRemoteIP
|
||||
}
|
||||
switch nwType {
|
||||
case NetworkTypeVPN:
|
||||
break // nothing special
|
||||
case NetworkTypeVPNPeer:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrPeerRejected // reject for now, one day this may have different FW rules
|
||||
case NetworkTypeUnsafe:
|
||||
break // nothing special, one day this may have different FW rules
|
||||
default:
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return ErrUnknownNetworkType //should never happen
|
||||
}
|
||||
remoteNetworkType := f.identifyNetworkType(h, fp)
|
||||
if err := f.allowNetworkType(remoteNetworkType); err != nil {
|
||||
f.metrics(incoming).droppedRemoteAddr.Inc(1)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure we are supposed to be handling this local ip address
|
||||
if !f.routableNetworks.Contains(fp.LocalAddr) {
|
||||
if err := f.willingToHandleLocalAddr(incoming, fp, remoteNetworkType); err != nil {
|
||||
f.metrics(incoming).droppedLocalAddr.Inc(1)
|
||||
return ErrInvalidLocalIP
|
||||
}
|
||||
|
||||
table := f.OutRules
|
||||
if incoming {
|
||||
table = f.InRules
|
||||
return err
|
||||
}
|
||||
|
||||
// We now know which firewall table to check against
|
||||
@@ -467,9 +659,14 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
|
||||
}
|
||||
|
||||
// We always want to conntrack since it is a faster operation
|
||||
f.addConn(fp, incoming)
|
||||
c = f.addConn(fp, incoming)
|
||||
|
||||
return nil
|
||||
if incoming && remoteNetworkType == NetworkTypeUncheckedSNATPeer {
|
||||
return f.applySnat(pkt, &fp, c, h)
|
||||
} else {
|
||||
//outgoing snat is handled before this function is called
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firewall) metrics(incoming bool) firewallMetrics {
|
||||
@@ -496,12 +693,23 @@ func (f *Firewall) EmitStats() {
|
||||
metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
|
||||
if localCache != nil {
|
||||
if _, ok := localCache[fp]; ok {
|
||||
return true
|
||||
}
|
||||
func (f *Firewall) peek(fp firewall.Packet) *conn {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
// Purge every time we test
|
||||
ep, has := conntrack.TimerWheel.Purge()
|
||||
if has {
|
||||
f.evict(ep)
|
||||
}
|
||||
|
||||
c := conntrack.Conns[fp]
|
||||
|
||||
conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) *conn {
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
|
||||
@@ -515,7 +723,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
|
||||
if !ok {
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.rulesVersion != f.rulesVersion {
|
||||
@@ -538,7 +746,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
}
|
||||
delete(conntrack.Conns, fp)
|
||||
conntrack.Unlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -568,12 +776,11 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
|
||||
localCache[fp] = struct{}{}
|
||||
}
|
||||
|
||||
return true
|
||||
return c
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
func (f *Firewall) packetTimeout(fp firewall.Packet) time.Duration {
|
||||
var timeout time.Duration
|
||||
c := &conn{}
|
||||
|
||||
switch fp.Protocol {
|
||||
case firewall.ProtoTCP:
|
||||
@@ -583,7 +790,13 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
default:
|
||||
timeout = f.DefaultTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func (f *Firewall) addConn(fp firewall.Packet, incoming bool) *conn {
|
||||
c := &conn{}
|
||||
|
||||
timeout := f.packetTimeout(fp)
|
||||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
@@ -597,7 +810,9 @@ func (f *Firewall) addConn(fp firewall.Packet, incoming bool) {
|
||||
c.rulesVersion = f.rulesVersion
|
||||
c.Expires = time.Now().Add(timeout)
|
||||
conntrack.Conns[fp] = c
|
||||
|
||||
conntrack.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
// Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
|
||||
@@ -682,6 +897,13 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCer
|
||||
|
||||
var port int32
|
||||
|
||||
if p.Protocol == firewall.ProtoICMP {
|
||||
// port numbers are re-used for connection tracking and SNAT,
|
||||
// but we don't want to actually filter on them for ICMP
|
||||
// ICMP6 is omitted because we don't attempt to parse code/identifier/etc out of ICMP6
|
||||
return fp[firewall.PortAny].match(p, c, caPool)
|
||||
}
|
||||
|
||||
if p.Fragment {
|
||||
port = firewall.PortFragment
|
||||
} else if incoming {
|
||||
|
||||
@@ -31,6 +31,10 @@ type Packet struct {
|
||||
Fragment bool
|
||||
}
|
||||
|
||||
func (fp *Packet) IsIPv4() bool {
|
||||
return fp.LocalAddr.Is4() && fp.RemoteAddr.Is4()
|
||||
}
|
||||
|
||||
func (fp *Packet) Copy() *Packet {
|
||||
return &Packet{
|
||||
LocalAddr: fp.LocalAddr,
|
||||
|
||||
292
firewall_test.go
292
firewall_test.go
@@ -21,7 +21,7 @@ import (
|
||||
func TestNewFirewall(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
c := &dummyCert{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
conntrack := fw.Conntrack
|
||||
assert.NotNil(t, conntrack)
|
||||
assert.NotNil(t, conntrack.Conns)
|
||||
@@ -36,23 +36,23 @@ func TestNewFirewall(t *testing.T) {
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
|
||||
fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{})
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
|
||||
fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{})
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
|
||||
fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{})
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
|
||||
fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{})
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
|
||||
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{})
|
||||
assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
||||
assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
|
||||
}
|
||||
@@ -63,7 +63,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
l.SetOutput(ob)
|
||||
|
||||
c := &dummyCert{}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
assert.NotNil(t, fw.InRules)
|
||||
assert.NotNil(t, fw.OutRules)
|
||||
|
||||
@@ -79,56 +79,56 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
|
||||
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
|
||||
assert.Nil(t, fw.InRules.UDP[1].Any.Any)
|
||||
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
|
||||
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
|
||||
//no matter what port is given for icmp, it should end up as "any"
|
||||
assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
|
||||
assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
|
||||
assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
|
||||
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
|
||||
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
|
||||
ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
|
||||
assert.True(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
|
||||
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
anyIp, err := netip.ParsePrefix("0.0.0.0/0")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
|
||||
assert.False(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
anyIp6, err := netip.ParsePrefix("::/0")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -150,28 +150,28 @@ func TestFirewall_AddRule(t *testing.T) {
|
||||
table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
|
||||
assert.False(t, ok)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
|
||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
|
||||
assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
|
||||
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
|
||||
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
|
||||
|
||||
// Test error conditions
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
|
||||
require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
|
||||
require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
|
||||
}
|
||||
@@ -208,49 +208,49 @@ func TestFirewall_Drop(t *testing.T) {
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropV6(t *testing.T) {
|
||||
@@ -287,49 +287,49 @@ func TestFirewall_DropV6(t *testing.T) {
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, &c)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
|
||||
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
// test remote mismatch
|
||||
oldRemote := p.RemoteAddr
|
||||
p.RemoteAddr = netip.MustParseAddr("fd12::56")
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
|
||||
p.RemoteAddr = oldRemote
|
||||
|
||||
// ensure signer doesn't get in the way of group checks
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caSha doesn't drop on match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
|
||||
// ensure ca name doesn't get in the way of group checks
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
|
||||
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// test caName doesn't drop on match
|
||||
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func BenchmarkFirewallTable_match(b *testing.B) {
|
||||
@@ -532,15 +532,15 @@ func TestFirewall_Drop2(t *testing.T) {
|
||||
}
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// h1/c1 lacks the proper groups
|
||||
require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
|
||||
// c has the proper groups
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3(t *testing.T) {
|
||||
@@ -612,24 +612,24 @@ func TestFirewall_Drop3(t *testing.T) {
|
||||
}
|
||||
h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// c1 should pass because host match
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
// c2 should pass because ca sha match
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
|
||||
// c3 should fail because no match
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
|
||||
|
||||
// Test a remote address match
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_Drop3V6(t *testing.T) {
|
||||
@@ -664,10 +664,10 @@ func TestFirewall_Drop3V6(t *testing.T) {
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
// Test a remote address match
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
cp := cert.NewCAPool()
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
}
|
||||
|
||||
func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
@@ -704,35 +704,35 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
|
||||
}
|
||||
h.buildNetworks(myVpnNetworksTable, c.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
|
||||
// Allow outbound because conntrack
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw := fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Allow outbound because conntrack and new rules allow port 10
|
||||
require.NoError(t, fw.Drop(p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
|
||||
|
||||
oldFw = fw
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
|
||||
fw.Conntrack = oldFw.Conntrack
|
||||
fw.rulesVersion = oldFw.rulesVersion + 1
|
||||
|
||||
// Drop outbound because conntrack doesn't match new ruleset
|
||||
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
}
|
||||
|
||||
func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
@@ -771,19 +771,19 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("ICMP allowed", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports", func(t *testing.T) {
|
||||
@@ -791,29 +791,29 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Any proto, some ports allowed", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports, still blocked", func(t *testing.T) {
|
||||
p := templ.Copy()
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, still blocked", func(t *testing.T) {
|
||||
@@ -821,12 +821,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
|
||||
t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
|
||||
@@ -834,16 +834,16 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
p.LocalPort = 80
|
||||
p.RemotePort = 80
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
assert.Equal(t, fw.Drop(*p, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
|
||||
//now also allow outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
})
|
||||
t.Run("Any proto, any port", func(t *testing.T) {
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
t.Run("zero ports, allowed", func(t *testing.T) {
|
||||
resetConntrack(fw)
|
||||
@@ -851,12 +851,12 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
p.LocalPort = 0
|
||||
p.RemotePort = 0
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
})
|
||||
|
||||
t.Run("nonzero ports, allowed", func(t *testing.T) {
|
||||
@@ -865,15 +865,15 @@ func TestFirewall_ICMPPortBehavior(t *testing.T) {
|
||||
p.LocalPort = 0xabcd
|
||||
p.RemotePort = 0x1234
|
||||
// Drop outbound
|
||||
assert.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
// Allow inbound
|
||||
resetConntrack(fw)
|
||||
require.NoError(t, fw.Drop(*p, true, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
|
||||
//now also allow outbound
|
||||
require.NoError(t, fw.Drop(*p, false, &h, cp, nil))
|
||||
require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
|
||||
//different ID is blocked
|
||||
p.RemotePort++
|
||||
require.Equal(t, fw.Drop(*p, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -908,7 +908,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
}
|
||||
h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
|
||||
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
|
||||
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
|
||||
cp := cert.NewCAPool()
|
||||
@@ -922,7 +922,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) {
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
|
||||
}
|
||||
|
||||
func BenchmarkLookup(b *testing.B) {
|
||||
@@ -1047,53 +1047,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
|
||||
conf := config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
|
||||
|
||||
// Test both port and code
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
|
||||
|
||||
// Test missing host, group, cidr, ca_name and ca_sha
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
|
||||
|
||||
// Test code/port error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
|
||||
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
|
||||
|
||||
// Test proto error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
|
||||
|
||||
// Test cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test local_cidr parse error
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
|
||||
|
||||
// Test both group and groups
|
||||
conf = config.NewC(l)
|
||||
conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
|
||||
_, err = NewFirewallFromConfig(l, cs, conf)
|
||||
_, err = NewFirewallFromConfig(l, cs, conf, netip.Addr{})
|
||||
require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
|
||||
}
|
||||
|
||||
@@ -1336,7 +1336,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
||||
t.Helper()
|
||||
cp := cert.NewCAPool()
|
||||
resetConntrack(fw)
|
||||
err := fw.Drop(c.p, true, c.h, cp, nil)
|
||||
err := fw.Drop(c.p, nil, true, c.h, cp, nil)
|
||||
if c.err == nil {
|
||||
require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
|
||||
} else {
|
||||
@@ -1344,7 +1344,7 @@ func (c *testcase) Test(t *testing.T, fw *Firewall) {
|
||||
}
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
|
||||
c1 := dummyCert{
|
||||
name: "host1",
|
||||
networks: theirPrefixes,
|
||||
@@ -1364,6 +1364,11 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
|
||||
h.vpnAddrs[i] = theirPrefixes[i].Addr()
|
||||
}
|
||||
h.buildNetworks(setup.myVpnNetworksTable, &c1)
|
||||
return &h
|
||||
}
|
||||
|
||||
func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
|
||||
h := buildHostinfo(setup, theirPrefixes...)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: theirPrefixes[0].Addr(),
|
||||
@@ -1373,9 +1378,9 @@ func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) te
|
||||
Fragment: false,
|
||||
}
|
||||
return testcase{
|
||||
h: &h,
|
||||
h: h,
|
||||
p: p,
|
||||
c: &c1,
|
||||
c: h.ConnectionState.peerCert.Certificate,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
@@ -1397,12 +1402,25 @@ func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testse
|
||||
return newSetupFromCert(t, l, c)
|
||||
}
|
||||
|
||||
func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
|
||||
c := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
out := newSetupFromCert(t, l, c)
|
||||
out.fw.snatAddr = snatAddr
|
||||
return out
|
||||
}
|
||||
|
||||
func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
|
||||
myVpnNetworksTable := new(bart.Lite)
|
||||
for _, prefix := range c.Networks() {
|
||||
myVpnNetworksTable.Insert(prefix)
|
||||
}
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
|
||||
return testsetup{
|
||||
@@ -1532,3 +1550,59 @@ func resetConntrack(fw *Firewall) {
|
||||
fw.Conntrack.Conns = map[firewall.Packet]*conn{}
|
||||
fw.Conntrack.Unlock()
|
||||
}
|
||||
|
||||
func TestFirewall_SNAT(t *testing.T) {
|
||||
t.Parallel()
|
||||
l := test.NewLogger()
|
||||
ob := &bytes.Buffer{}
|
||||
l.SetOutput(ob)
|
||||
cp := cert.NewCAPool()
|
||||
myPrefix := netip.MustParsePrefix("1.1.1.1/8")
|
||||
|
||||
MyCert := dummyCert{
|
||||
name: "me",
|
||||
networks: []netip.Prefix{myPrefix},
|
||||
groups: []string{"default-group"},
|
||||
issuer: "signer-shasum",
|
||||
}
|
||||
|
||||
theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
|
||||
snatAddr := netip.MustParseAddr("169.254.55.96")
|
||||
t.Run("allow inbound all matching", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
myCert := MyCert.Copy()
|
||||
setup := newSnatSetup(t, l, myPrefix, snatAddr)
|
||||
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr)
|
||||
require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
|
||||
resetConntrack(setup.fw)
|
||||
h := buildHostinfo(setup, theirPrefix)
|
||||
p := firewall.Packet{
|
||||
LocalAddr: setup.c.Networks()[0].Addr(), //todo?
|
||||
RemoteAddr: h.vpnAddrs[0],
|
||||
LocalPort: 10,
|
||||
RemotePort: 90,
|
||||
Protocol: firewall.ProtoUDP,
|
||||
Fragment: false,
|
||||
}
|
||||
require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
|
||||
})
|
||||
//t.Run("allow inbound unsafe route", func(t *testing.T) {
|
||||
// t.Parallel()
|
||||
// unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
|
||||
// c := dummyCert{
|
||||
// name: "me",
|
||||
// networks: []netip.Prefix{myPrefix},
|
||||
// unsafeNetworks: []netip.Prefix{unsafePrefix},
|
||||
// groups: []string{"default-group"},
|
||||
// issuer: "signer-shasum",
|
||||
// }
|
||||
// unsafeSetup := newSetupFromCert(t, l, c)
|
||||
// tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
|
||||
// tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
|
||||
// tc.err = ErrNoMatchingRule
|
||||
// tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
|
||||
// require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
|
||||
// tc.err = nil
|
||||
// tc.Test(t, unsafeSetup.fw) //should pass
|
||||
//})
|
||||
}
|
||||
|
||||
12
hostmap.go
12
hostmap.go
@@ -224,6 +224,9 @@ const (
|
||||
NetworkTypeVPNPeer
|
||||
// NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks()
|
||||
NetworkTypeUnsafe
|
||||
// NetworkTypeUncheckedSNATPeer is used to indicate traffic we're willing to route, but never deliver to a NetworkTypeVPN
|
||||
NetworkTypeUncheckedSNATPeer
|
||||
NetworkTypeInvalidPeer
|
||||
)
|
||||
|
||||
type HostInfo struct {
|
||||
@@ -277,6 +280,15 @@ type HostInfo struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (i *HostInfo) HasOnlyV6Addresses() bool {
|
||||
for _, vpnIp := range i.vpnAddrs {
|
||||
if !vpnIp.Is6() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type ViaSender struct {
|
||||
UdpAddr netip.AddrPort
|
||||
relayHI *HostInfo // relayHI is the host info object of the relay
|
||||
|
||||
29
inside.go
29
inside.go
@@ -48,9 +48,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
hostinfo, ready := f.getHostinfo(packet, fwPacket)
|
||||
|
||||
if hostinfo == nil {
|
||||
f.rejectInside(packet, out, q)
|
||||
@@ -66,10 +64,9 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
return
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, packet, false, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason == nil {
|
||||
f.sendNoMetrics(header.Message, 0, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, packet, nb, out, q)
|
||||
|
||||
} else {
|
||||
f.rejectInside(packet, out, q)
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
@@ -81,6 +78,26 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) getHostinfo(packet []byte, fwPacket *firewall.Packet) (*HostInfo, bool) {
|
||||
if f.firewall.ShouldUnSNAT(fwPacket) {
|
||||
//unsnat packet re-writing also happens here, would be nice to not,
|
||||
//but we need to do the unsnat lookup to find the hostinfo so we can run the firewall checks
|
||||
destVpnAddr := f.firewall.unSnat(packet, fwPacket)
|
||||
if destVpnAddr.IsValid() {
|
||||
//because this was a snatted packet, we know it has an on-overlay destination, so no routing should be required.
|
||||
return f.getOrHandshakeNoRouting(destVpnAddr, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
} else {
|
||||
return nil, false
|
||||
}
|
||||
} else { //if we didn't need to unsnat
|
||||
return f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
|
||||
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Interface) rejectInside(packet []byte, out []byte, q int) {
|
||||
if !f.firewall.InSendReject {
|
||||
return
|
||||
@@ -218,7 +235,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
|
||||
}
|
||||
|
||||
// check if packet is in outbound fw rules
|
||||
dropReason := f.firewall.Drop(*fp, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
dropReason := f.firewall.Drop(*fp, nil, false, hostinfo, f.pki.GetCAPool(), nil)
|
||||
if dropReason != nil {
|
||||
if f.l.Level >= logrus.DebugLevel {
|
||||
f.l.WithField("fwPacket", fp).
|
||||
|
||||
@@ -56,6 +56,7 @@ type Interface struct {
|
||||
inside overlay.Device
|
||||
pki *PKI
|
||||
firewall *Firewall
|
||||
snatAddr netip.Addr
|
||||
connectionManager *connectionManager
|
||||
handshakeManager *HandshakeManager
|
||||
serveDns bool
|
||||
@@ -339,7 +340,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
||||
return
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c)
|
||||
fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c, f.firewall.snatAddr)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while creating firewall during reload")
|
||||
return
|
||||
|
||||
6
main.go
6
main.go
@@ -66,7 +66,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
|
||||
}
|
||||
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c)
|
||||
snatAddr := netip.MustParseAddr("169.254.55.96") //todo get this from tun!
|
||||
fw, err := NewFirewallFromConfig(l, pki.getCertState(), c, snatAddr)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
|
||||
}
|
||||
@@ -135,7 +136,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
||||
deviceFactory = overlay.NewDeviceFromConfig
|
||||
}
|
||||
|
||||
tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines)
|
||||
cs := pki.getCertState()
|
||||
tun, err = deviceFactory(c, l, cs.myVpnNetworks, cs.GetDefaultCertificate().UnsafeNetworks(), routines)
|
||||
if err != nil {
|
||||
return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
|
||||
}
|
||||
|
||||
@@ -509,7 +509,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
|
||||
return false
|
||||
}
|
||||
|
||||
dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
dropReason := f.firewall.Drop(*fwPacket, out, true, hostinfo, f.pki.GetCAPool(), localCache)
|
||||
if dropReason != nil {
|
||||
// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
|
||||
// This gives us a buffer to build the reject packet in
|
||||
|
||||
@@ -22,22 +22,22 @@ func (e *NameError) Error() string {
|
||||
}
|
||||
|
||||
// TODO: We may be able to remove routines
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error)
|
||||
type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error)
|
||||
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
switch {
|
||||
case c.GetBool("tun.disabled", false):
|
||||
tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l)
|
||||
return tun, nil
|
||||
|
||||
default:
|
||||
return newTun(c, l, vpnNetworks, routines > 1)
|
||||
return newTun(c, l, vpnNetworks, unsafeNetworks, routines > 1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewFdDeviceFromConfig(fd *int) DeviceFactory {
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks)
|
||||
return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return newTunFromFd(c, l, *fd, vpnNetworks, unsafeNetworks)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ type tun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
// XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly.
|
||||
// Be sure not to call file.Fd() as it will set the fd to blocking mode.
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
@@ -53,7 +53,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in Android")
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ type ifreqAlias6 struct {
|
||||
Lifetime addrLifetime
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
name := c.GetString("tun.dev", "")
|
||||
ifIndex := -1
|
||||
if name != "" && name != "utun" {
|
||||
@@ -153,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) {
|
||||
return
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Darwin")
|
||||
}
|
||||
|
||||
|
||||
@@ -199,11 +199,11 @@ func (t *tun) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open existing tun device
|
||||
var fd int
|
||||
var err error
|
||||
|
||||
@@ -28,11 +28,11 @@ type tun struct {
|
||||
l *logrus.Logger
|
||||
}
|
||||
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ []netip.Prefix, _ bool) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTun not supported in iOS")
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/tun")
|
||||
t := &tun{
|
||||
vpnNetworks: vpnNetworks,
|
||||
|
||||
@@ -26,14 +26,15 @@ import (
|
||||
|
||||
type tun struct {
|
||||
io.ReadWriteCloser
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
fd int
|
||||
Device string
|
||||
vpnNetworks []netip.Prefix
|
||||
unsafeNetworks []netip.Prefix
|
||||
MaxMTU int
|
||||
DefaultMTU int
|
||||
TXQueueLen int
|
||||
deviceIndex int
|
||||
ioctlFd uintptr
|
||||
|
||||
Routes atomic.Pointer[[]Route]
|
||||
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
|
||||
@@ -71,10 +72,10 @@ type ifreqQLEN struct {
|
||||
pad [8]byte
|
||||
}
|
||||
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
file := os.NewFile(uintptr(deviceFd), "/dev/net/tun")
|
||||
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -84,7 +85,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []net
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, multiqueue bool) (*tun, error) {
|
||||
fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
// If /dev/net/tun doesn't exist, try to create it (will happen in docker)
|
||||
@@ -123,7 +124,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
name := strings.Trim(string(req.Name[:]), "\x00")
|
||||
|
||||
file := os.NewFile(uintptr(fd), "/dev/net/tun")
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks)
|
||||
t, err := newTunGeneric(c, l, file, vpnNetworks, unsafeNetworks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -133,11 +134,12 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueu
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) {
|
||||
func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix) (*tun, error) {
|
||||
t := &tun{
|
||||
ReadWriteCloser: file,
|
||||
fd: int(file.Fd()),
|
||||
vpnNetworks: vpnNetworks,
|
||||
unsafeNetworks: unsafeNetworks,
|
||||
TXQueueLen: c.GetInt("tun.tx_queue", 500),
|
||||
useSystemRoutes: c.GetBool("tun.use_system_route_table", false),
|
||||
useSystemRoutesBufferSize: c.GetInt("tun.use_system_route_table_buffer_size", 0),
|
||||
@@ -427,6 +429,27 @@ func (t *tun) setMTU() {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tun) setSnatRoute() error {
|
||||
snataddr := netip.MustParsePrefix("169.254.55.96/32") //todo get this from elsewhere? Or maybe we should pick it, and feed it back out to the firewall?
|
||||
dr := &net.IPNet{
|
||||
IP: snataddr.Masked().Addr().AsSlice(),
|
||||
Mask: net.CIDRMask(snataddr.Bits(), snataddr.Addr().BitLen()),
|
||||
}
|
||||
|
||||
nr := netlink.Route{
|
||||
LinkIndex: t.deviceIndex,
|
||||
Dst: dr,
|
||||
//todo do we need these other options?
|
||||
//MTU: t.DefaultMTU,
|
||||
//AdvMSS: t.advMSS(Route{}),
|
||||
Scope: unix.RT_SCOPE_LINK,
|
||||
//Protocol: unix.RTPROT_KERNEL,
|
||||
Table: unix.RT_TABLE_MAIN,
|
||||
Type: unix.RTN_UNICAST,
|
||||
}
|
||||
return netlink.RouteReplace(&nr)
|
||||
}
|
||||
|
||||
func (t *tun) setDefaultRoute(cidr netip.Prefix) error {
|
||||
dr := &net.IPNet{
|
||||
IP: cidr.Masked().Addr().AsSlice(),
|
||||
@@ -503,6 +526,18 @@ func (t *tun) addRoutes(logErrors bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
onlyV6Addresses := false
|
||||
for _, n := range t.vpnNetworks {
|
||||
if n.Addr().Is6() {
|
||||
onlyV6Addresses = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.unsafeNetworks) != 0 && onlyV6Addresses {
|
||||
return t.setSnatRoute()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -70,11 +70,11 @@ type tun struct {
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in NetBSD")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
||||
@@ -63,11 +63,11 @@ type tun struct {
|
||||
|
||||
var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`)
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*tun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in openbsd")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*tun, error) {
|
||||
// Try to open tun device
|
||||
var err error
|
||||
deviceName := c.GetString("tun.dev", "")
|
||||
|
||||
@@ -28,7 +28,7 @@ type TestTun struct {
|
||||
TxPackets chan []byte // Packets transmitted outside by nebula
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*TestTun, error) {
|
||||
_, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -49,7 +49,7 @@ func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (*TestTun, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported")
|
||||
}
|
||||
|
||||
|
||||
@@ -38,11 +38,11 @@ type winTun struct {
|
||||
tun *wintun.NativeTun
|
||||
}
|
||||
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) {
|
||||
func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix, _ []netip.Prefix) (Device, error) {
|
||||
return nil, fmt.Errorf("newTunFromFd not supported in Windows")
|
||||
}
|
||||
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, _ bool) (*winTun, error) {
|
||||
err := checkWinTunExists()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not load the wintun driver: %w", err)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/slackhq/nebula/routing"
|
||||
)
|
||||
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, unsafeNetworks []netip.Prefix, routines int) (Device, error) {
|
||||
return NewUserDevice(vpnNetworks)
|
||||
}
|
||||
|
||||
|
||||
91
snat.go
Normal file
91
snat.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func recalcIPv4Checksum(data []byte, oldSrcIP netip.Addr, newSrcIP netip.Addr) {
|
||||
oldChecksum := binary.BigEndian.Uint16(data[10:12])
|
||||
//because of how checksums work, we can re-use this function
|
||||
checksum := calcNewTransportChecksum(oldChecksum, oldSrcIP, 0, newSrcIP, 0)
|
||||
binary.BigEndian.PutUint16(data[10:12], checksum)
|
||||
}
|
||||
|
||||
func calcNewTransportChecksum(oldChecksum uint16, oldSrcIP netip.Addr, oldSrcPort uint16, newSrcIP netip.Addr, newSrcPort uint16) uint16 {
|
||||
oldIP := binary.BigEndian.Uint32(oldSrcIP.AsSlice())
|
||||
newIP := binary.BigEndian.Uint32(newSrcIP.AsSlice())
|
||||
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old IP (as two 16-bit words)
|
||||
checksum += uint32(^uint16(oldIP >> 16))
|
||||
checksum += uint32(^uint16(oldIP & 0xFFFF))
|
||||
|
||||
// Subtract old port
|
||||
checksum += uint32(^oldSrcPort)
|
||||
|
||||
// Add new IP (as two 16-bit words)
|
||||
checksum += uint32(newIP >> 16)
|
||||
checksum += uint32(newIP & 0xFFFF)
|
||||
|
||||
// Add new port
|
||||
checksum += uint32(newSrcPort)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcV4TransportChecksum(offsetInsideHeader int, data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldcsum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewTransportChecksum(oldcsum, oldSrcIP.Addr(), oldSrcIP.Port(), newSrcIP.Addr(), newSrcIP.Port())
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
|
||||
func recalcUDPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 6
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func recalcTCPv4Checksum(data []byte, oldSrcIP netip.AddrPort, newSrcIP netip.AddrPort) {
|
||||
const offsetInsideHeader = 16
|
||||
recalcV4TransportChecksum(offsetInsideHeader, data, oldSrcIP, newSrcIP)
|
||||
}
|
||||
|
||||
func calcNewICMPChecksum(oldChecksum uint16, oldCode uint16, newCode uint16, oldID uint16, newID uint16) uint16 {
|
||||
// Start with inverted checksum
|
||||
checksum := uint32(^oldChecksum)
|
||||
|
||||
// Subtract old stuff
|
||||
checksum += uint32(^oldCode)
|
||||
checksum += uint32(^oldID)
|
||||
|
||||
// Add new stuff
|
||||
checksum += uint32(newCode)
|
||||
checksum += uint32(newID)
|
||||
|
||||
// Fold carries
|
||||
for checksum > 0xFFFF {
|
||||
checksum = (checksum & 0xFFFF) + (checksum >> 16)
|
||||
}
|
||||
|
||||
// Return ones' complement
|
||||
return ^uint16(checksum)
|
||||
}
|
||||
|
||||
func recalcICMPv4Checksum(data []byte, oldCode uint16, newCode uint16, oldID uint16, newID uint16) {
|
||||
const offsetInsideHeader = 2
|
||||
ipHeaderOffset := int(data[0]&0x0F) * 4
|
||||
offset := ipHeaderOffset + offsetInsideHeader
|
||||
oldChecksum := binary.BigEndian.Uint16(data[offset : offset+2])
|
||||
checksum := calcNewICMPChecksum(oldChecksum, oldCode, newCode, oldID, newID)
|
||||
binary.BigEndian.PutUint16(data[offset:offset+2], checksum)
|
||||
}
|
||||
Reference in New Issue
Block a user