diff --git a/allow_list.go b/allow_list.go index c574b6e..cfdd983 100644 --- a/allow_list.go +++ b/allow_list.go @@ -250,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error return remoteAllowRanges, nil } -func (al *AllowList) Allow(ip netip.Addr) bool { +func (al *AllowList) Allow(addr netip.Addr) bool { if al == nil { return true } - result, _ := al.cidrTree.Lookup(ip) + result, _ := al.cidrTree.Lookup(addr) return result } -func (al *LocalAllowList) Allow(ip netip.Addr) bool { +func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool { if al == nil { return true } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(udpAddr) } func (al *LocalAllowList) AllowName(name string) bool { @@ -281,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool { return !al.nameRules[0].Allow } -func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool { +func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool { if al == nil { return true } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(vpnAddr) } -func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool { - if !al.getInsideAllowList(vpnIp).Allow(ip) { +func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool { + if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { return false } - return al.AllowList.Allow(ip) + return al.AllowList.Allow(udpAddr) } -func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList { +func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool { + if !al.AllowList.Allow(udpAddr) { + return false + } + + for _, vpnAddr := range vpnAddrs { + if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) { + return false + } + } + + return true +} + +func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList { if al.insideAllowLists != nil { - inside, ok := al.insideAllowLists.Lookup(vpnIp) + inside, ok := al.insideAllowLists.Lookup(vpnAddr) if ok { return inside } diff --git a/handshake_ix.go b/handshake_ix.go index 356c034..9b8b3e9 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -189,15 +189,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - if addr.IsValid() { - // addr can be invalid when the tunnel is being relayed. - // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) { - f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") - return - } - } - // vpnAddrs outside our vpn networks are of no use to us, filter them out if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok { continue @@ -216,6 +207,15 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } + if addr.IsValid() { + // addr can be invalid when the tunnel is being relayed. + // We only want to apply the remote allow list for direct tunnels here + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + myIndex, err := generateIndex(f.l) if err != nil { f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). @@ -450,8 +450,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) { - f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } diff --git a/handshake_manager.go b/handshake_manager.go index 85ed173..6d3ed12 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -138,7 +138,7 @@ func (hm *HandshakeManager) Run(ctx context.Context) { func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp if addr.IsValid() { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return } diff --git a/outside.go b/outside.go index dbf75f1..1e9cde1 100644 --- a/outside.go +++ b/outside.go @@ -231,26 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) { - if vpnAddr.IsValid() && hostinfo.remote != vpnAddr { - //TODO: CERT-V2 this is weird now that we can have multiple vpn addrs - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { + if udpAddr.IsValid() && hostinfo.remote != udpAddr { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + + if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(vpnAddr) + hostinfo.SetRemote(udpAddr) } }