diff --git a/handshake_ix.go b/handshake_ix.go index d53e5a7..026bfbd 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -459,7 +459,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.connectionManager.AddTrafficWatch(hostinfo) - hostinfo.remotes.ResetBlockedRemotes() + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) return } @@ -667,7 +667,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore))) } - hostinfo.remotes.ResetBlockedRemotes() + hostinfo.remotes.RefreshFromHandshake(vpnAddrs) f.metricHandshakes.Update(duration) return false diff --git a/lighthouse.go b/lighthouse.go index 809b04e..57f9f1e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -487,7 +487,7 @@ func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(vpnAddrs) + return lh.unlockedGetRemoteList(vpnAddrs) //todo CERT-V2 this contains addrmap lookups we could potentially skip } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing @@ -570,7 +570,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetAddrs() { - if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { + if !lh.shouldAdd([]netip.Addr{vpnAddr}, addrPort.Addr()) { continue } switch { @@ -645,18 +645,17 @@ func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { } } - //TODO lighthouse.remote_allow_ranges is almost certainly broken in a multiple-address-per-cert scenario - am := NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + am := NewRemoteList(allAddrs, lh.shouldAdd) for _, addr := range allAddrs { lh.addrMap[addr] = am } return am } -func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { - allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) +func (lh *LightHouse) shouldAdd(vpnAddrs []netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().AllowAll(vpnAddrs, to) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + lh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", to).WithField("allow", allow). Trace("remoteAllowList.Allow") } if !allow { diff --git a/remote_list.go b/remote_list.go index 6baed29..a17003b 100644 --- a/remote_list.go +++ b/remote_list.go @@ -190,7 +190,7 @@ type RemoteList struct { // The full list of vpn addresses assigned to this host vpnAddrs []netip.Addr - // A deduplicated set of addresses. Any accessor should lock beforehand. + // A deduplicated set of underlay addresses. Any accessor should lock beforehand. addrs []netip.AddrPort // A set of relay addresses. VpnIp addresses that the remote identified as relays. @@ -201,8 +201,10 @@ type RemoteList struct { // For learned addresses, this is the vpnIp that sent the packet cache map[netip.Addr]*cache - hr *hostnamesResults - shouldAdd func(netip.Addr) bool + hr *hostnamesResults + + // shouldAdd is a nillable function that decides if x should be added to addrs. + shouldAdd func(vpnAddrs []netip.Addr, x netip.Addr) bool // This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip. // They should not be tried again during a handshake @@ -213,7 +215,7 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func([]netip.Addr, netip.Addr) bool) *RemoteList { r := &RemoteList{ vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), @@ -368,6 +370,15 @@ func (r *RemoteList) CopyBlockedRemotes() []netip.AddrPort { return c } +// RefreshFromHandshake locks and updates the RemoteList to account for data learned upon a completed handshake +func (r *RemoteList) RefreshFromHandshake(vpnAddrs []netip.Addr) { + r.Lock() + r.badRemotes = nil + r.vpnAddrs = make([]netip.Addr, len(vpnAddrs)) + copy(r.vpnAddrs, vpnAddrs) + r.Unlock() +} + // ResetBlockedRemotes locks and clears the blocked remotes list func (r *RemoteList) ResetBlockedRemotes() { r.Lock() @@ -577,7 +588,7 @@ func (r *RemoteList) unlockedCollect() { dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { - if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { + if r.shouldAdd == nil || r.shouldAdd(r.vpnAddrs, addr.Addr()) { if !r.unlockedIsBad(addr) { addrs = append(addrs, addr) }