diff --git a/handshake_ix.go b/handshake_ix.go index fd0b456..17f2c15 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return true } -func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { +func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) { cs := f.pki.getCertState() crt := cs.GetDefaultCertificate() if crt == nil { - f.l.WithField("udpAddr", addr). + f.l.WithField("from", via). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", cs.initiatingVersion). Error("Unable to handshake with host because no certificate is available") @@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Error("Failed to create connection state") return @@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Error("Failed to call noise.ReadMessage") return @@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Error("Failed unmarshal handshake message") return @@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake did not contain a certificate") return @@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet fp = "" } - e := f.l.WithError(err).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("from", via). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). WithField("certVpnNetworks", rc.Networks()). WithField("certFingerprint", fp) @@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if myCertOtherVersion == nil { if f.l.Level >= logrus.DebugLevel { f.l.WithError(err).WithFields(m{ - "udpAddr": addr, + "from": via, "handshake": m{"stage": 1, "style": "ix_psk0"}, "cert": remoteCert, }).Debug("Might be unable to handshake with host due to missing certificate version") @@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("cert", remoteCert). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("No networks in certificate") @@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet vpnAddrs := make([]netip.Addr, len(vpnNetworks)) for i, network := range vpnNetworks { if f.myVpnAddrsTable.Contains(network.Addr()) { - f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). + f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } } - if addr.IsValid() { - // addr can be invalid when the tunnel is being relayed. + if !via.IsRelayed { // We only want to apply the remote allow list for direct tunnels here - if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). + Debug("lighthouse.remote_allow_list denied incoming handshake") return } } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msgRxL := f.l.WithFields(m{ "vpnAddrs": vpnAddrs, - "udpAddr": addr, + "from": via, "certName": certName, "certVersion": certVersion, "fingerprint": fingerprint, @@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -329,7 +329,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.eKey = NewNebulaCipherState(eKey) hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) - hostinfo.SetRemote(addr) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) + } hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) @@ -337,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet switch err { case ErrAlreadySeen: // Update remote if preferred - if existing.SetRemoteIfPreferred(f.hostMap, addr) { + if existing.SetRemoteIfPreferred(f.hostMap, via) { // Send a test packet to ensure the other side has also switched to // the preferred remote f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) @@ -345,21 +347,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet msg = existing.HandshakePacket[2] f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr.IsValid() { - err := f.outside.WriteTo(msg, addr) + if !via.IsRelayed { + err := f.outside.WriteTo(msg, via.UdpAddr) if err != nil { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } return } else { - if via == nil { - f.l.Error("Handshake send failed: both addr and via are nil.") + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") return } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) @@ -371,7 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("oldHandshakeTime", existing.lastHandshakeTime). @@ -387,7 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -400,7 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). @@ -414,30 +416,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet // Do the send f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1) - if addr.IsValid() { - err = f.outside.WriteTo(msg, addr) + if !via.IsRelayed { + err = f.outside.WriteTo(msg, via.UdpAddr) + log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). + WithField("certName", certName). + WithField("certVersion", certVersion). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if err != nil { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - WithError(err).Error("Failed to send handshake") + log.WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("certVersion", certVersion). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). - WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). - Info("Handshake message sent") + log.Info("Handshake message sent") } } else { - if via == nil { - f.l.Error("Handshake send failed: both addr and via are nil.") + if via.relay == nil { + f.l.Error("Handshake send failed: both addr and via.relay are nil.") return } hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) @@ -462,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } -func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { +func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool { if hh == nil { // Nothing here to tear down, got a bogus stage 2 packet return true @@ -472,10 +467,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha defer hh.Unlock() hostinfo := hh.hostinfo - if addr.IsValid() { + if !via.IsRelayed { // The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list. - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -483,7 +478,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -492,7 +487,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -504,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again @@ -513,7 +508,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Handshake did not contain a certificate") @@ -527,7 +522,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha fp = "" } - e := f.l.WithError(err).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithField("certFingerprint", fp). @@ -542,7 +537,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } if len(remoteCert.Certificate.Networks()) == 0 { - f.l.WithError(err).WithField("udpAddr", addr). + f.l.WithError(err).WithField("from", via). WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("cert", remoteCert). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -565,8 +560,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.eKey = NewNebulaCipherState(eKey) // Make sure the current udpAddr being used is set for responding - if addr.IsValid() { - hostinfo.SetRemote(addr) + if !via.IsRelayed { + hostinfo.SetRemote(via.UdpAddr) } else { hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } @@ -588,7 +583,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Ensure the right host responded if !correctHostResponded { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). - WithField("udpAddr", addr). + WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -602,7 +597,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes - newHH.hostinfo.remotes.BlockRemote(addr) + newHH.hostinfo.remotes.BlockRemote(via) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). WithField("vpnNetworks", vpnNetworks). @@ -625,7 +620,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via). WithField("certName", certName). WithField("certVersion", certVersion). WithField("fingerprint", fingerprint). diff --git a/handshake_manager.go b/handshake_manager.go index cae27a2..8b1ce83 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) { } } -func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { +func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { // First remote allow list check before we know the vpnIp - if addr.IsValid() { - if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) { - hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + if !via.IsRelayed { + if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { + hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake") return } } @@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, case header.HandshakeIXPSK0: switch h.MessageCounter { case 1: - ixHandshakeStage1(hm.f, addr, via, packet, h) + ixHandshakeStage1(hm.f, via, packet, h) case 2: newHostinfo := hm.queryIndex(h.RemoteIndex) - tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h) + tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) if tearDown && newHostinfo != nil { hm.DeleteHostInfo(newHostinfo.hostinfo) } diff --git a/hostmap.go b/hostmap.go index 9f8cd5e..c39e69d 100644 --- a/hostmap.go +++ b/hostmap.go @@ -1,7 +1,9 @@ package nebula import ( + "encoding/json" "errors" + "fmt" "net" "net/netip" "slices" @@ -276,9 +278,25 @@ type HostInfo struct { } type ViaSender struct { + UdpAddr netip.AddrPort relayHI *HostInfo // relayHI is the host info object of the relay remoteIdx uint32 // remoteIdx is the index included in the header of the received packet relay *Relay // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us. + IsRelayed bool // IsRelayed is true if the packet was sent through a relay +} + +func (v ViaSender) String() string { + if v.IsRelayed { + return fmt.Sprintf("%s (relayed)", v.UdpAddr) + } + return v.UdpAddr.String() +} + +func (v ViaSender) MarshalJSON() ([]byte, error) { + if v.IsRelayed { + return json.Marshal(m{"direct": v.UdpAddr}) + } + return json.Marshal(m{"relay": v.UdpAddr}) } type cachedPacket struct { @@ -694,6 +712,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate { return nil } +// TODO: Maybe use ViaSender here? func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { @@ -704,14 +723,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam // time on the HostInfo will also be updated. -func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool { - if !newRemote.IsValid() { - // relays have nil udp Addrs +func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool { + if via.IsRelayed { return false } + currentRemote := i.remote if !currentRemote.IsValid() { - i.SetRemote(newRemote) + i.SetRemote(via.UdpAddr) return true } @@ -724,7 +743,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b return false } - if l.Contains(newRemote.Addr()) { + if l.Contains(via.UdpAddr.Addr()) { newIsPreferred = true } } @@ -734,7 +753,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b i.lastRoam = time.Now() i.lastRoamRemote = currentRemote - i.SetRemote(newRemote) + i.SetRemote(via.UdpAddr) return true } diff --git a/interface.go b/interface.go index 082906d..844fefd 100644 --- a/interface.go +++ b/interface.go @@ -272,7 +272,7 @@ func (f *Interface) listenOut(i int) { nb := make([]byte, 12, 12) li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { - f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) }) } diff --git a/outside.go b/outside.go index 5ff87bd..ac7fa6d 100644 --- a/outside.go +++ b/outside.go @@ -19,21 +19,21 @@ const ( minFwPacketLen = 4 ) -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors if len(packet) > 1 { - f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err) + f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err) } return } //l.Error("in packet ", header, packet[HeaderLen:]) - if ip.IsValid() { - if f.myVpnNetworksTable.Contains(ip.Addr()) { + if !via.IsRelayed { + if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") + f.l.WithField("from", via).Debug("Refusing to process double encrypted packet") } return } @@ -55,7 +55,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] switch h.Type { case header.Message: // TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case. - if !f.handleEncrypted(ci, ip, h) { + //TODO: RELAY-WORK ip could be relayed here + if !f.handleEncrypted(ci, via, h) { return } @@ -79,7 +80,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] // Successfully validated the thing. Get rid of the Relay header. signedPayload = signedPayload[header.Len:] // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, ip) + f.handleHostRoaming(hostinfo, via) // Track usage of both the HostInfo and the Relay for the received & authenticated packet f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) @@ -96,7 +97,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case TerminalType: // If I am the target of this relay, process the unwrapped packet // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) return case ForwardingType: // Find the target HostInfo relay object @@ -126,31 +134,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case header.LightHouse: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt lighthouse packet") return } - lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) + //NOTE: via should never be a relayed from here + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic case header.Test: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + //TODO: RELAY-WORK ip could be relayed here + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + //TODO: RELAY-WORK ip could be relayed here + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt test packet") return @@ -159,7 +170,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if h.Subtype == header.TestRequest { // This testRequest might be from TryPromoteBest, so we should roam // to the new IP address before responding - f.handleHostRoaming(hostinfo, ip) + //TODO: RELAY-WORK ip could be relayed here + f.handleHostRoaming(hostinfo, via) + //TODO: RELAY-WORK ip could be relayed here f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) } @@ -170,34 +183,41 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] case header.Handshake: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(ip, via, packet, h) + //TODO: RELAY-WORK ip could be relayed here + f.handshakeManager.HandleIncoming(via, packet, h) return case header.RecvError: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(ip, h) + //TODO: RELAY-WORK we should probably support recv_error better in the relays, pass via directly to handleRecvError + f.handleRecvError(via.UdpAddr, h) return case header.CloseTunnel: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, ip, h) { + //TODO: RELAY-WORK ip could be relayed here + if !f.handleEncrypted(ci, via, h) { return } - hostinfo.logger(f.l).WithField("udpAddr", ip). + //TODO: RELAY-WORK ip could be relayed here + hostinfo.logger(f.l).WithField("from", via). Info("Close tunnel received, tearing down.") f.closeTunnel(hostinfo) return case header.Control: - if !f.handleEncrypted(ci, ip, h) { + //TODO: RELAY-WORK ip could be relayed here + if !f.handleEncrypted(ci, via, h) { return } d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) if err != nil { - hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip). + + //TODO: RELAY-WORK ip could be relayed here + hostinfo.logger(f.l).WithError(err).WithField("from", via). WithField("packet", packet). Error("Failed to decrypt Control packet") return @@ -207,11 +227,12 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) - hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip) + //TODO: RELAY-WORK ip could be relayed here + hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via) return } - f.handleHostRoaming(hostinfo, ip) + f.handleHostRoaming(hostinfo, via) f.connectionManager.In(hostinfo) } @@ -230,36 +251,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) { - if udpAddr.IsValid() && hostinfo.remote != udpAddr { - if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { + if !via.IsRelayed && hostinfo.remote != via.UdpAddr { + if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(udpAddr) + hostinfo.SetRemote(via.UdpAddr) } } // handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool { +func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect if ci == nil { - if addr.IsValid() { - f.maybeSendRecvError(addr, h.RemoteIndex) + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) } return false } diff --git a/remote_list.go b/remote_list.go index a17003b..1304fd5 100644 --- a/remote_list.go +++ b/remote_list.go @@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap { } // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list -func (r *RemoteList) BlockRemote(bad netip.AddrPort) { - if !bad.IsValid() { - // relays can have nil udp Addrs +func (r *RemoteList) BlockRemote(bad ViaSender) { + if bad.IsRelayed { return } + r.Lock() defer r.Unlock() // Check if we already blocked this addr - if r.unlockedIsBad(bad) { + if r.unlockedIsBad(bad.UdpAddr) { return } // We copy here because we are taking something else's memory and we can't trust everything - r.badRemotes = append(r.badRemotes, bad) + r.badRemotes = append(r.badRemotes, bad.UdpAddr) // Mark the next interaction must recollect/dedupe r.shouldRebuild = true