From f597aa71e325eb52f57d2ac404862ecc79822f79 Mon Sep 17 00:00:00 2001 From: JackDoan Date: Tue, 21 Oct 2025 11:03:13 -0500 Subject: [PATCH] firewall can distinguish if the host connecting has an overlapping network, is a VPN peer without an overlapping network, or is a unsafe network --- firewall.go | 30 ++++++++++++++++++++++-------- firewall_test.go | 37 ++++++++++++++++++++++++++----------- handshake_ix.go | 4 ++-- hostmap.go | 33 ++++++++++++++++++++++++++------- 4 files changed, 76 insertions(+), 28 deletions(-) diff --git a/firewall.go b/firewall.go index 3359082..a349d2f 100644 --- a/firewall.go +++ b/firewall.go @@ -417,6 +417,8 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } +var ErrUnknownNetworkType = errors.New("unknown network type") +var ErrPeerRejected = errors.New("remote IP is not within a subnet that we handle") var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets") var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") @@ -429,19 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } - // TODO if we don't have a network in common with this packet's source IP, (and it's not for an unsafe_network), do we reject it? - // Make sure remote address matches nebula certificate - if h.networks != nil { - if !h.networks.Contains(fp.RemoteAddr) { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { + // Make sure remote address matches nebula certificate, and determine how to treat it + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } + } else { + nwType, ok := h.networks.Lookup(fp.RemoteAddr) + if !ok { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP + } + switch nwType { + case NetworkTypeVPN: + break // nothing special + case NetworkTypeVPNPeer: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + break // nothing special, one day this may have different FW rules + default: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrUnknownNetworkType //should never happen + } } // Make sure we are supposed to be handling this local ip address diff --git a/firewall_test.go b/firewall_test.go index a0cb3c8..f8eec68 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -149,7 +150,8 @@ func TestFirewall_Drop(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"), @@ -174,7 +176,7 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -226,6 +228,9 @@ func TestFirewall_DropV6(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) + p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"), @@ -250,7 +255,7 @@ func TestFirewall_DropV6(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -453,6 +458,8 @@ func TestFirewall_Drop2(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -478,7 +485,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -493,7 +500,7 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -510,6 +517,8 @@ func TestFirewall_Drop3(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -541,7 +550,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -556,7 +565,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) + h2.buildNetworks(myVpnNetworksTable, c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -571,7 +580,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) + h3.buildNetworks(myVpnNetworksTable, c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -597,6 +606,8 @@ func TestFirewall_Drop3V6(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), @@ -620,7 +631,7 @@ func TestFirewall_Drop3V6(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) // Test a remote address match fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -633,6 +644,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -659,7 +672,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -696,6 +709,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) c := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -717,7 +732,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { }, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) diff --git a/handshake_ix.go b/handshake_ix.go index 4b93ba0..c5c44f4 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -323,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -632,7 +632,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate.Networks(), remoteCert.Certificate.UnsafeNetworks()) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) diff --git a/hostmap.go b/hostmap.go index ea221a8..fc69c40 100644 --- a/hostmap.go +++ b/hostmap.go @@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.relayForByIdx[idx] = r } +type NetworkType uint8 + +const ( + NetworkTypeUnknown NetworkType = iota + // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate + NetworkTypeVPN + // NetworkTypeVPNPeer is a network that does not overlap one of our networks + NetworkTypeVPNPeer + // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() + NetworkTypeUnsafe +) + type HostInfo struct { remote netip.AddrPort remotes *RemoteList @@ -224,7 +236,7 @@ type HostInfo struct { vpnAddrs []netip.Addr // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. - networks *bart.Lite + networks *bart.Table[NetworkType] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -728,20 +740,27 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } -func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { +func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, networks, unsafeNetworks []netip.Prefix) { if len(networks) == 1 && len(unsafeNetworks) == 0 { - // Simple case, no CIDRTree needed - return + if myVpnNetworksTable.Contains(networks[0].Addr()) { + return // Simple case, no CIDRTree needed + } } - i.networks = new(bart.Lite) + i.networks = new(bart.Table[NetworkType]) for _, network := range networks { + var nwType NetworkType + if myVpnNetworksTable.Contains(network.Addr()) { + nwType = NetworkTypeVPN + } else { + nwType = NetworkTypeVPNPeer + } nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - i.networks.Insert(nprefix) + i.networks.Insert(nprefix, nwType) } for _, network := range unsafeNetworks { - i.networks.Insert(network) + i.networks.Insert(network, NetworkTypeUnsafe) } }