fix multiport handshakes from non-baseport

This commit is contained in:
Wade Simmons
2026-05-07 10:39:36 -04:00
parent c72a37c16f
commit 8e607a91f4
2 changed files with 28 additions and 5 deletions

View File

@@ -3,6 +3,7 @@ package handshake
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"math"
"slices" "slices"
"time" "time"
@@ -40,8 +41,9 @@ type Result struct {
MessageIndex uint64 // number of messages exchanged during the handshake MessageIndex uint64 // number of messages exchanged during the handshake
Initiator bool Initiator bool
MultiportRx bool MultiportRx bool
MultiportTx bool MultiportTx bool
MultiportBasePort uint16
} }
// Machine drives a Noise handshake through N messages. It handles Noise // Machine drives a Noise handshake through N messages. It handles Noise
@@ -324,12 +326,18 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
if payload.ResponderMultiPort != nil { if payload.ResponderMultiPort != nil {
m.result.MultiportRx = payload.ResponderMultiPort.RxSupported m.result.MultiportRx = payload.ResponderMultiPort.RxSupported
m.result.MultiportTx = payload.ResponderMultiPort.TxSupported m.result.MultiportTx = payload.ResponderMultiPort.TxSupported
if payload.ResponderMultiPort.BasePort <= math.MaxUint16 {
m.result.MultiportBasePort = uint16(payload.ResponderMultiPort.BasePort)
}
} }
} else { } else {
m.result.RemoteIndex = payload.InitiatorIndex m.result.RemoteIndex = payload.InitiatorIndex
if payload.InitiatorMultiPort != nil { if payload.InitiatorMultiPort != nil {
m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported
m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported
if payload.InitiatorMultiPort.BasePort <= math.MaxUint16 {
m.result.MultiportBasePort = uint16(payload.InitiatorMultiPort.BasePort)
}
} }
} }
m.result.HandshakeTime = payload.Time m.result.HandshakeTime = payload.Time

View File

@@ -765,6 +765,12 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head
return return
} }
if !via.IsRelayed && result.MultiportTx && result.MultiportBasePort != via.UdpAddr.Port() {
// The other side sent us a handshake from a different port, make sure
// we send responses back to the BasePort
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), result.MultiportBasePort)
}
remoteCert := result.RemoteCert remoteCert := result.RemoteCert
if remoteCert == nil { if remoteCert == nil {
f.l.Error("Handshake did not produce a peer certificate", "from", via) f.l.Error("Handshake did not produce a peer certificate", "from", via)
@@ -900,6 +906,14 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
return return
} }
if !via.IsRelayed && result.MultiportTx && result.MultiportBasePort != via.UdpAddr.Port() {
// The other side sent us a handshake from a different port, make sure
// we send responses back to the BasePort
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), result.MultiportBasePort)
}
hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx
hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx
// Handshake complete; build the ConnectionState now that we have keys and a verified peer cert. // Handshake complete; build the ConnectionState now that we have keys and a verified peer cert.
hostinfo.ConnectionState = newConnectionStateFromResult(result) hostinfo.ConnectionState = newConnectionStateFromResult(result)
@@ -921,9 +935,6 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx
hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx
// Verify correct host responded (initiator check) // Verify correct host responded (initiator check)
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
correctHostResponded := false correctHostResponded := false
@@ -1141,6 +1152,10 @@ func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hos
switch err { switch err {
case ErrAlreadySeen: case ErrAlreadySeen:
if hostinfo.multiportRx {
// The other host is sending to us with multiport, so only grab the IP
via.UdpAddr = netip.AddrPortFrom(via.UdpAddr.Addr(), hostinfo.remote.Port())
}
if existing.SetRemoteIfPreferred(f.hostMap, via) { if existing.SetRemoteIfPreferred(f.hostMap, via) {
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
} }