From a89f95182c9db812b880c8df7fa520120841c682 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 12 Nov 2025 13:40:20 -0600 Subject: [PATCH] Firewall types and cross-stack subnet stuff (#1509) * firewall can distinguish if the host connecting has an overlapping network, is a VPN peer without an overlapping network, or is a unsafe network * Cross stack subnet stuff (#1512) * experiment with not filtering out non-common addresses in hostinfo.networks * allow handshakes without overlaps * unsafe network test * change HostInfo.buildNetworks argument to reference the cert --- control_tester.go | 4 + e2e/handshakes_test.go | 170 ++++++++++++++++++++++++++++++++++ e2e/helpers_test.go | 62 +++++++++++-- e2e/tunnels_test.go | 47 ++++++++++ firewall.go | 33 +++++-- firewall_test.go | 203 ++++++++++++++++++++++++++++++++++++++--- handshake_ix.go | 118 +++++++++++------------- handshake_manager.go | 4 +- hostmap.go | 40 +++++--- inside.go | 11 ++- lighthouse.go | 6 +- 11 files changed, 582 insertions(+), 116 deletions(-) diff --git a/control_tester.go b/control_tester.go index 451dac5..7403a74 100644 --- a/control_tester.go +++ b/control_tester.go @@ -174,6 +174,10 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } +func (c *Control) GetF() *Interface { + return c.f +} + func (c *Control) GetCertState() *CertState { return c.f.pki.getCertState() } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 7e751c5..757cfd1 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -97,6 +97,41 @@ func TestGoodHandshake(t *testing.T) { theirControl.Stop() } +func TestGoodHandshakeNoOverlap(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "2001::69/24", nil) //look ma, cross-stack! + + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + empty := []byte{} + t.Log("do something to cause a handshake") + myControl.GetF().SendMessageToVpnAddr(header.Test, header.MessageNone, theirVpnIpNet[0].Addr(), empty, empty, empty) + + t.Log("Have them consume my stage 0 packet. They have a tunnel now") + theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) + + t.Log("Get their stage 1 packet") + stage1Packet := theirControl.GetFromUDP(true) + + t.Log("Have me consume their stage 1 packet. I have a tunnel now") + myControl.InjectUDPPacket(stage1Packet) + + t.Log("Wait until we see a test packet come through to make sure we give the tunnel time to complete") + myControl.WaitForType(header.Test, 0, theirControl) + + t.Log("Make sure our host infos are correct") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + func TestWrongResponderHandshake(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) @@ -464,6 +499,35 @@ func TestRelays(t *testing.T) { r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) } +func TestRelaysDontCareAboutIps(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "2001::9999/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) +} + func TestReestablishRelays(t *testing.T) { ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) @@ -1227,3 +1291,109 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestV2NonPrimaryWithOffNetLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "2001::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[0].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[0].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} + +func TestGoodHandshakeUnsafeDest(t *testing.T) { + unsafePrefix := "192.168.6.0/24" + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdpAndUnsafeNetworks(cert.Version2, ca, caKey, "spooky", "10.128.0.2/24", netip.MustParseAddrPort("10.64.0.2:4242"), unsafePrefix, nil) + route := m{"route": unsafePrefix, "via": theirVpnIpNet[0].Addr().String()} + myCfg := m{ + "tun": m{ + "unsafe_routes": []m{route}, + }, + } + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version2, ca, caKey, "me", "10.128.0.1/24", myCfg) + t.Logf("my config %v", myConfig) + // Put their info in our lighthouse + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + spookyDest := netip.MustParseAddr("192.168.6.4") + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") + myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + t.Log("Have them consume my stage 0 packet. They have a tunnel now") + theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) + + t.Log("Get their stage 1 packet so that we can play with it") + stage1Packet := theirControl.GetFromUDP(true) + + t.Log("I consume a garbage packet with a proper nebula header for our tunnel") + // this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel + badPacket := stage1Packet.Copy() + badPacket.Data = badPacket.Data[:len(badPacket.Data)-header.Len] + myControl.InjectUDPPacket(badPacket) + + t.Log("Have me consume their real stage 1 packet. I have a tunnel now") + myControl.InjectUDPPacket(stage1Packet) + + t.Log("Wait until we see my cached packet come through") + myControl.WaitForType(1, 0, theirControl) + + t.Log("Make sure our host infos are correct") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + + t.Log("Get that cached packet and make sure it looks right") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) + + //reply + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) + //wait for reply + theirControl.WaitForType(1, 0, myControl) + theirCachedPacket := myControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from the spookyman"), theirCachedPacket, spookyDest, myVpnIpNet[0].Addr(), 80, 80) + + t.Log("Do a bidirectional tunnel test") + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + myControl.Stop() + theirControl.Stop() +} diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 8c19d3b..cb9fc37 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -22,6 +22,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.yaml.in/yaml/v3" ) @@ -29,8 +30,6 @@ type m = map[string]any // newSimpleServer creates a nebula instance with many assumptions func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { - l := NewTestLogger() - var vpnNetworks []netip.Prefix for _, sn := range strings.Split(sVpnNetworks, ",") { vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) @@ -56,7 +55,54 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) + return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides) +} + +func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + return newSimpleServerWithUdpAndUnsafeNetworks(v, caCrt, caKey, name, sVpnNetworks, udpAddr, "", overrides) +} + +func newSimpleServerWithUdpAndUnsafeNetworks(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, sUnsafeNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") + } + + firewallInbound := []m{{ + "proto": "any", + "port": "any", + "host": "any", + }} + + var unsafeNetworks []netip.Prefix + if sUnsafeNetworks != "" { + firewallInbound = []m{{ + "proto": "any", + "port": "any", + "host": "any", + "local_cidr": "0.0.0.0/0", + }} + + for _, sn := range strings.Split(sUnsafeNetworks, ",") { + x, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + unsafeNetworks = append(unsafeNetworks, x) + } + } + + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, unsafeNetworks, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -76,11 +122,7 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name "port": "any", "host": "any", }}, - "inbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, + "inbound": firewallInbound, }, //"handshakes": m{ // "try_interval": "1s", @@ -266,10 +308,10 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn // Get both host infos //TODO: CERT-V2 we may want to loop over each vpnAddr and assert all the things hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) - assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") + require.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) - assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") + require.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index f1e9ca7..e89cf86 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -318,3 +318,50 @@ func TestCertMismatchCorrection(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCrossStackRelaysWork(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}}) + theirUdp := netip.MustParseAddrPort("10.0.0.2:4242") + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}}) + + //myVpnV4 := myVpnIpNet[0] + myVpnV6 := myVpnIpNet[1] + relayVpnV4 := relayVpnIpNet[0] + relayVpnV6 := relayVpnIpNet[1] + theirVpnV6 := theirVpnIpNet[0] + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()}) + relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) + + t.Log("reply?") + theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + p = r.RouteForAllUntilTxTun(myControl) + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) + + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + //t.Log("finish up") + //myControl.Stop() + //theirControl.Stop() + //relayControl.Stop() +} diff --git a/firewall.go b/firewall.go index 971c156..6bf470d 100644 --- a/firewall.go +++ b/firewall.go @@ -417,8 +417,10 @@ func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, c *config.C, fw return nil } -var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets") -var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs") +var ErrUnknownNetworkType = errors.New("unknown network type") +var ErrPeerRejected = errors.New("remote address is not within a network that we handle") +var ErrInvalidRemoteIP = errors.New("remote address is not in remote certificate networks") +var ErrInvalidLocalIP = errors.New("local address is not in list of handled local addresses") var ErrNoMatchingRule = errors.New("no matching rule in firewall table") // Drop returns an error if the packet should be dropped, explaining why. It @@ -429,18 +431,31 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return nil } - // Make sure remote address matches nebula certificate - if h.networks != nil { - if !h.networks.Contains(fp.RemoteAddr) { - f.metrics(incoming).droppedRemoteAddr.Inc(1) - return ErrInvalidRemoteIP - } - } else { + // Make sure remote address matches nebula certificate, and determine how to treat it + if h.networks == nil { // Simple case: Certificate has one address and no unsafe networks if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } + } else { + nwType, ok := h.networks.Lookup(fp.RemoteAddr) + if !ok { + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrInvalidRemoteIP + } + switch nwType { + case NetworkTypeVPN: + break // nothing special + case NetworkTypeVPNPeer: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrPeerRejected // reject for now, one day this may have different FW rules + case NetworkTypeUnsafe: + break // nothing special, one day this may have different FW rules + default: + f.metrics(incoming).droppedRemoteAddr.Inc(1) + return ErrUnknownNetworkType //should never happen + } } // Make sure we are supposed to be handling this local ip address diff --git a/firewall_test.go b/firewall_test.go index a0cb3c8..6a4e00c 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/gaissmai/bart" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" @@ -149,7 +151,8 @@ func TestFirewall_Drop(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) - + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), RemoteAddr: netip.MustParseAddr("1.2.3.4"), @@ -174,7 +177,7 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -226,6 +229,9 @@ func TestFirewall_DropV6(t *testing.T) { ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) + p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), RemoteAddr: netip.MustParseAddr("fd12::34"), @@ -250,7 +256,7 @@ func TestFirewall_DropV6(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")}, } - h.buildNetworks(c.networks, c.unsafeNetworks) + h.buildNetworks(myVpnNetworksTable, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -453,6 +459,8 @@ func TestFirewall_Drop2(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -478,7 +486,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -493,7 +501,7 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -510,6 +518,8 @@ func TestFirewall_Drop3(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -541,7 +551,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -556,7 +566,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) + h2.buildNetworks(myVpnNetworksTable, c2.Certificate) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -571,7 +581,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) + h3.buildNetworks(myVpnNetworksTable, c3.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -597,6 +607,8 @@ func TestFirewall_Drop3V6(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("fd12::34"), @@ -620,7 +632,7 @@ func TestFirewall_Drop3V6(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate) // Test a remote address match fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -633,6 +645,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8")) p := firewall.Packet{ LocalAddr: netip.MustParseAddr("1.2.3.4"), @@ -659,7 +673,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) + h.buildNetworks(myVpnNetworksTable, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -696,6 +710,8 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { l := test.NewLogger() ob := &bytes.Buffer{} l.SetOutput(ob) + myVpnNetworksTable := new(bart.Lite) + myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24")) c := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -717,7 +733,7 @@ func TestFirewall_DropIPSpoofing(t *testing.T) { }, vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()}, } - h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) + h1.buildNetworks(myVpnNetworksTable, c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) @@ -1047,6 +1063,171 @@ func TestFirewall_convertRule(t *testing.T) { assert.Equal(t, "group1", r.Group) } +type testcase struct { + h *HostInfo + p firewall.Packet + c cert.Certificate + err error +} + +func (c *testcase) Test(t *testing.T, fw *Firewall) { + t.Helper() + cp := cert.NewCAPool() + resetConntrack(fw) + err := fw.Drop(c.p, true, c.h, cp, nil) + if c.err == nil { + require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr) + } else { + require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr) + } +} + +func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase { + c1 := dummyCert{ + name: "host1", + networks: theirPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + h := HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{ + Certificate: &c1, + InvertedGroups: map[string]struct{}{"default-group": {}}, + }, + }, + vpnAddrs: make([]netip.Addr, len(theirPrefixes)), + } + for i := range theirPrefixes { + h.vpnAddrs[i] = theirPrefixes[i].Addr() + } + h.buildNetworks(setup.myVpnNetworksTable, &c1) + p := firewall.Packet{ + LocalAddr: setup.c.Networks()[0].Addr(), //todo? + RemoteAddr: theirPrefixes[0].Addr(), + LocalPort: 10, + RemotePort: 90, + Protocol: firewall.ProtoUDP, + Fragment: false, + } + return testcase{ + h: &h, + p: p, + c: &c1, + err: err, + } +} + +type testsetup struct { + c dummyCert + myVpnNetworksTable *bart.Lite + fw *Firewall +} + +func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup { + c := dummyCert{ + name: "me", + networks: myPrefixes, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + + return newSetupFromCert(t, l, c) +} + +func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup { + myVpnNetworksTable := new(bart.Lite) + for _, prefix := range c.Networks() { + myVpnNetworksTable.Insert(prefix) + } + fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + + return testsetup{ + c: c, + fw: fw, + myVpnNetworksTable: myVpnNetworksTable, + } +} + +func TestFirewall_Drop_EnforceIPMatch(t *testing.T) { + t.Parallel() + l := test.NewLogger() + ob := &bytes.Buffer{} + l.SetOutput(ob) + + myPrefix := netip.MustParsePrefix("1.1.1.1/8") + // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out + t.Run("allow inbound all matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24")) + tc.Test(t, setup.fw) + }) + t.Run("allow inbound local matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8") + tc.Test(t, setup.fw) + }) + t.Run("block inbound remote mismatched", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24")) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("Block a vpn peer packet", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24")) + tc.Test(t, setup.fw) + }) + twoPrefixes := []netip.Prefix{ + netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"), + } + t.Run("allow inbound one matching", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, nil, twoPrefixes...) + tc.Test(t, setup.fw) + }) + t.Run("block inbound multimismatch", func(t *testing.T) { + t.Parallel() + setup := newSetup(t, l, myPrefix) + tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...) + tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9") + tc.Test(t, setup.fw) + }) + t.Run("allow inbound 2nd one matching", func(t *testing.T) { + t.Parallel() + setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24")) + tc := buildTestCase(setup2, nil, twoPrefixes...) + tc.p.RemoteAddr = twoPrefixes[1].Addr() + tc.Test(t, setup2.fw) + }) + t.Run("allow inbound unsafe route", func(t *testing.T) { + t.Parallel() + unsafePrefix := netip.MustParsePrefix("192.168.0.0/24") + c := dummyCert{ + name: "me", + networks: []netip.Prefix{myPrefix}, + unsafeNetworks: []netip.Prefix{unsafePrefix}, + groups: []string{"default-group"}, + issuer: "signer-shasum", + } + unsafeSetup := newSetupFromCert(t, l, c) + tc := buildTestCase(unsafeSetup, nil, twoPrefixes...) + tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3") + tc.err = ErrNoMatchingRule + tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off + require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", "")) + tc.err = nil + tc.Test(t, unsafeSetup.fw) //should pass + }) +} + type addRuleCall struct { incoming bool proto uint8 diff --git a/handshake_ix.go b/handshake_ix.go index 00b1d40..fd0b456 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,7 +2,6 @@ package nebula import ( "net/netip" - "slices" "time" "github.com/flynn/noise" @@ -192,17 +191,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() certVersion := remoteCert.Certificate.Version() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() + vpnNetworks := remoteCert.Certificate.Networks() - for _, network := range remoteCert.Certificate.Networks() { - vpnAddr := network.Addr() - if f.myVpnAddrsTable.Contains(vpnAddr) { - f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { + if f.myVpnAddrsTable.Contains(network.Addr()) { + f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -210,24 +209,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } - - // vpnAddrs outside our vpn networks are of no use to us, filter them out - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) - } - - if len(vpnAddrs) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") - return } if addr.IsValid() { @@ -264,26 +249,30 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet }, } - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Info("Handshake message received") + msgRxL := f.l.WithFields(m{ + "vpnAddrs": vpnAddrs, + "udpAddr": addr, + "certName": certName, + "certVersion": certVersion, + "fingerprint": fingerprint, + "issuer": issuer, + "initiatorIndex": hs.Details.InitiatorIndex, + "responderIndex": hs.Details.ResponderIndex, + "remoteIndex": h.RemoteIndex, + "handshake": m{"stage": 1, "style": "ix_psk0"}, + }) + + if anyVpnAddrsInCommon { + msgRxL.Info("Handshake message received") + } else { + //todo warn if not lighthouse or relay? + msgRxL.Info("Handshake message received, but no vpnNetworks in common.") + } hs.Details.ResponderIndex = myIndex hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) if hs.Details.Cert == nil { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("certVersion", ci.myCert.Version()). + msgRxL.WithField("myCertVersion", ci.myCert.Version()). Error("Unable to handshake with host because no certificate handshake bytes is available") return } @@ -341,7 +330,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -582,31 +571,22 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - var vpnAddrs []netip.Addr - var filteredNetworks []netip.Prefix - for _, network := range vpnNetworks { - // vpnAddrs outside our vpn networks are of no use to us, filter them out - vpnAddr := network.Addr() - if !f.myVpnNetworksTable.Contains(vpnAddr) { - continue + correctHostResponded := false + anyVpnAddrsInCommon := false + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, network := range vpnNetworks { + vpnAddrs[i] = network.Addr() + if f.myVpnNetworksTable.Contains(network.Addr()) { + anyVpnAddrsInCommon = true + } + if hostinfo.vpnAddrs[0] == network.Addr() { + // todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not? + correctHostResponded = true } - - filteredNetworks = append(filteredNetworks, network) - vpnAddrs = append(vpnAddrs, vpnAddr) - } - - if len(vpnAddrs) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") - return true } // Ensure the right host responded - if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + if !correctHostResponded { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr). WithField("certName", certName). @@ -618,6 +598,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip + //TODO is hostinfo.vpnAddrs[0] always the address to use? f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes @@ -644,7 +625,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -652,12 +633,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("durationNs", duration). - WithField("sentCachedPackets", len(hh.packetStore)). - Info("Handshake message received") + WithField("sentCachedPackets", len(hh.packetStore)) + if anyVpnAddrsInCommon { + msgRxL.Info("Handshake message received") + } else { + //todo warn if not lighthouse or relay? + msgRxL.Info("Handshake message received, but no vpnNetworks in common.") + } // Build up the radix for the firewall if we have subnets in the cert hostinfo.vpnAddrs = vpnAddrs - hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) + hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) diff --git a/handshake_manager.go b/handshake_manager.go index ee72d71..cae27a2 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -269,12 +269,12 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself + // Don't relay through the host I'm trying to connect to if relay == vpnIp { continue } - // Don't relay through the host I'm trying to connect to + // Don't relay to myself if hm.f.myVpnAddrsTable.Contains(relay) { continue } diff --git a/hostmap.go b/hostmap.go index 66b4851..9f8cd5e 100644 --- a/hostmap.go +++ b/hostmap.go @@ -212,6 +212,18 @@ func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.relayForByIdx[idx] = r } +type NetworkType uint8 + +const ( + NetworkTypeUnknown NetworkType = iota + // NetworkTypeVPN is a network that overlaps one or more of the vpnNetworks in our certificate + NetworkTypeVPN + // NetworkTypeVPNPeer is a network that does not overlap one of our networks + NetworkTypeVPNPeer + // NetworkTypeUnsafe is a network from Certificate.UnsafeNetworks() + NetworkTypeUnsafe +) + type HostInfo struct { remote netip.AddrPort remotes *RemoteList @@ -225,8 +237,8 @@ type HostInfo struct { // vpn networks but were removed because they are not usable vpnAddrs []netip.Addr - // networks are both all vpn and unsafe networks assigned to this host - networks *bart.Lite + // networks is a combination of specific vpn addresses (not prefixes!) and full unsafe networks assigned to this host. + networks *bart.Table[NetworkType] relayState RelayState // HandshakePacket records the packets used to create this hostinfo @@ -730,20 +742,26 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } -func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { - if len(networks) == 1 && len(unsafeNetworks) == 0 { - // Simple case, no CIDRTree needed - return +// buildNetworks fills in the networks field of HostInfo. It accepts a cert.Certificate so you never ever mix the network types up. +func (i *HostInfo) buildNetworks(myVpnNetworksTable *bart.Lite, c cert.Certificate) { + if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { + if myVpnNetworksTable.Contains(c.Networks()[0].Addr()) { + return // Simple case, no BART needed + } } - i.networks = new(bart.Lite) - for _, network := range networks { + i.networks = new(bart.Table[NetworkType]) + for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - i.networks.Insert(nprefix) + if myVpnNetworksTable.Contains(network.Addr()) { + i.networks.Insert(nprefix, NetworkTypeVPN) + } else { + i.networks.Insert(nprefix, NetworkTypeVPNPeer) + } } - for _, network := range unsafeNetworks { - i.networks.Insert(network) + for _, network := range c.UnsafeNetworks() { + i.networks.Insert(network, NetworkTypeUnsafe) } } diff --git a/inside.go b/inside.go index d24ed31..0d53f95 100644 --- a/inside.go +++ b/inside.go @@ -120,9 +120,10 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established +// Handshake will attempt to initiate a tunnel with the provided vpn address. This is a no-op if the tunnel is already established or being established +// it does not check if it is within our vpn networks! func (f *Interface) Handshake(vpnAddr netip.Addr) { - f.getOrHandshakeNoRouting(vpnAddr, nil) + f.handshakeManager.GetOrHandshake(vpnAddr, nil) } // getOrHandshakeNoRouting returns nil if the vpnAddr is not routable. @@ -138,7 +139,6 @@ func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback fu // getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel. func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - destinationAddr := fwPacket.RemoteAddr hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback) @@ -231,9 +231,10 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr. +// This function ignores myVpnNetworksTable, and will always attempt to treat the address as a vpnAddr func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) { + hostInfo, ready := f.handshakeManager.GetOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) diff --git a/lighthouse.go b/lighthouse.go index 4a191e6..1510b94 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -360,7 +360,8 @@ func (lh *LightHouse) parseLighthouses(c *config.C) ([]netip.Addr, error) { } if !lh.myVpnNetworksTable.Contains(addr) { - return nil, util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) + lh.l.WithFields(m{"vpnAddr": addr, "networks": lh.myVpnNetworks}). + Warn("lighthouse host is not within our networks, lighthouse functionality will work but layer 3 network traffic to the lighthouse will not") } out[i] = addr } @@ -431,7 +432,8 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc } if !lh.myVpnNetworksTable.Contains(vpnAddr) { - return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) + lh.l.WithFields(m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}). + Warn("static_host_map key is not within our networks, layer 3 network traffic to this host will not work") } vals, ok := v.([]any)