From 1ad0f57c1ed387ea9c8d5698ce7abaa3dc96caf3 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 27 Jan 2025 16:40:11 -0600 Subject: [PATCH] Remove unusable networks from remote tunnels at handshake time (#1318) --- firewall.go | 6 ++---- firewall_test.go | 15 +++++++------- handshake_ix.go | 54 ++++++++++++++++++++++++++++++++++++++++-------- hostmap.go | 16 ++++++++------ interface.go | 2 +- 5 files changed, 66 insertions(+), 27 deletions(-) diff --git a/firewall.go b/firewall.go index 0aae7d6..d3b9eb6 100644 --- a/firewall.go +++ b/firewall.go @@ -8,7 +8,6 @@ import ( "hash/fnv" "net/netip" "reflect" - "slices" "strconv" "strings" "sync" @@ -438,9 +437,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * return ErrInvalidRemoteIP } } else { - // Simple case: Certificate has one IP and no subnets - //TODO: we can make this more performant - if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) { + // Simple case: Certificate has one address and no unsafe networks + if h.vpnAddrs[0] != fp.RemoteAddr { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } diff --git a/firewall_test.go b/firewall_test.go index a0d08ac..4dd2c9a 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -152,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.buildNetworks(&c) + h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -332,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -342,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}}, } h1 := HostInfo{ + vpnAddrs: []netip.Addr{network.Addr()}, ConnectionState: &ConnectionState{ peerCert: &c1, }, } - h1.buildNetworks(c1.Certificate) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -394,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.buildNetworks(c1.Certificate) + h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -409,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.buildNetworks(c2.Certificate) + h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks()) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -424,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.buildNetworks(c3.Certificate) + h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -471,7 +472,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.buildNetworks(c.Certificate) + h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) diff --git a/handshake_ix.go b/handshake_ix.go index c77145e..356c034 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -172,6 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix certName := remoteCert.Certificate.Name() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -189,15 +190,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } if addr.IsValid() { + // addr can be invalid when the tunnel is being relayed. + // We only want to apply the remote allow list for direct tunnels here if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) { f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } + // vpnAddrs outside our vpn networks are of no use to us, filter them out + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) vpnAddrs = append(vpnAddrs, vpnAddr) } + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return + } + myIndex, err := generateIndex(f.l) if err != nil { f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). @@ -294,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.buildNetworks(remoteCert.Certificate) + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -431,7 +449,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { - //TODO: this is kind of nonsense now + // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) { f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false @@ -492,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha e = e.WithField("cert", remoteCert) } - e.Info("Invalid vpn ip from host") + e.Info("Empty networks from host") return true } @@ -516,9 +534,26 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - vpnAddrs := make([]netip.Addr, len(vpnNetworks)) - for i, n := range vpnNetworks { - vpnAddrs[i] = n.Addr() + var vpnAddrs []netip.Addr + var filteredNetworks []netip.Prefix + for _, network := range vpnNetworks { + // vpnAddrs outside our vpn networks are of no use to us, filter them out + vpnAddr := network.Addr() + if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { + continue + } + + filteredNetworks = append(filteredNetworks, network) + vpnAddrs = append(vpnAddrs, vpnAddr) + } + + if len(vpnAddrs) == 0 { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake") + return true } // Ensure the right host responded @@ -558,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -569,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha Info("Handshake message received") // Build up the radix for the firewall if we have subnets in the cert - hostinfo.buildNetworks(remoteCert.Certificate) + hostinfo.vpnAddrs = vpnAddrs + hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks()) - // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp + // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) diff --git a/hostmap.go b/hostmap.go index 2869c6b..eca9dd7 100644 --- a/hostmap.go +++ b/hostmap.go @@ -215,8 +215,12 @@ type HostInfo struct { ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnAddrs []netip.Addr - recvError atomic.Uint32 + + // vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks + // The host may have other vpn addresses that are outside our + // vpn networks but were removed because they are not usable + vpnAddrs []netip.Addr + recvError atomic.Uint32 // networks are both all vpn and unsafe networks assigned to this host networks *bart.Table[struct{}] @@ -712,18 +716,18 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) buildNetworks(c cert.Certificate) { - if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { +func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) { + if len(networks) == 1 && len(unsafeNetworks) == 0 { // Simple case, no CIDRTree needed return } i.networks = new(bart.Table[struct{}]) - for _, network := range c.Networks() { + for _, network := range networks { i.networks.Insert(network, struct{}{}) } - for _, network := range c.UnsafeNetworks() { + for _, network := range unsafeNetworks { i.networks.Insert(network, struct{}{}) } } diff --git a/interface.go b/interface.go index 19a6864..21e198c 100644 --- a/interface.go +++ b/interface.go @@ -64,7 +64,7 @@ type Interface struct { myBroadcastAddrsTable *bart.Table[struct{}] myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate - myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate + myVpnNetworks []netip.Prefix // A list of networks assigned to us via our certificate myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate dropLocalBroadcast bool dropMulticast bool