diff --git a/handshake/machine.go b/handshake/machine.go index 31f6a08b..564e5ad3 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -3,6 +3,7 @@ package handshake import ( "bytes" "fmt" + "math" "slices" "time" @@ -40,8 +41,9 @@ type Result struct { MessageIndex uint64 // number of messages exchanged during the handshake Initiator bool - MultiportRx bool - MultiportTx bool + MultiportRx bool + MultiportTx bool + MultiportBasePort uint16 } // Machine drives a Noise handshake through N messages. It handles Noise @@ -324,12 +326,18 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error { if payload.ResponderMultiPort != nil { m.result.MultiportRx = payload.ResponderMultiPort.RxSupported m.result.MultiportTx = payload.ResponderMultiPort.TxSupported + if payload.ResponderMultiPort.BasePort <= math.MaxUint16 { + m.result.MultiportBasePort = uint16(payload.ResponderMultiPort.BasePort) + } } } else { m.result.RemoteIndex = payload.InitiatorIndex if payload.InitiatorMultiPort != nil { m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported + if payload.InitiatorMultiPort.BasePort <= math.MaxUint16 { + m.result.MultiportBasePort = uint16(payload.InitiatorMultiPort.BasePort) + } } } m.result.HandshakeTime = payload.Time diff --git a/handshake_manager.go b/handshake_manager.go index 76920a72..f68f815a 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -765,6 +765,12 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head return } + if !via.IsRelayed && result.MultiportTx && result.MultiportBasePort != via.UdpAddr.Port() { + // The other side sent us a handshake from a different port, make sure + // we send responses back to the BasePort + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), result.MultiportBasePort) + } + remoteCert := result.RemoteCert if remoteCert == nil { f.l.Error("Handshake did not produce a peer certificate", "from", via) @@ -900,6 +906,14 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn return } + if !via.IsRelayed && result.MultiportTx && result.MultiportBasePort != via.UdpAddr.Port() { + // The other side sent us a handshake from a different port, make sure + // we send responses back to the BasePort + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), result.MultiportBasePort) + } + hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx + hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx + // Handshake complete; build the ConnectionState now that we have keys and a verified peer cert. hostinfo.ConnectionState = newConnectionStateFromResult(result) @@ -921,9 +935,6 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } - hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx - hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx - // Verify correct host responded (initiator check) vpnAddrs := make([]netip.Addr, len(vpnNetworks)) correctHostResponded := false @@ -1141,6 +1152,10 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos switch err { case ErrAlreadySeen: + if hostinfo.multiportRx { + // The other host is sending to us with multiport, so only grab the IP + via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port()) + } if existing.SetRemoteIfPreferred(f.hostMap, via) { f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }