This commit is contained in:
Wade Simmons
2026-05-06 16:13:53 -04:00
parent bb3c70da2e
commit 610fcdb9bf
5 changed files with 198 additions and 3 deletions

10
config/multiport.go Normal file
View File

@@ -0,0 +1,10 @@
package config
type MultiPortConfig struct {
Tx bool
Rx bool
TxBasePort uint16
TxPorts int
TxHandshake bool
TxHandshakeDelay int64
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/flynn/noise" "github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
) )
@@ -38,6 +39,9 @@ type Result struct {
HandshakeTime uint64 HandshakeTime uint64
MessageIndex uint64 // number of messages exchanged during the handshake MessageIndex uint64 // number of messages exchanged during the handshake
Initiator bool Initiator bool
MultiportRx bool
MultiportTx bool
} }
// Machine drives a Noise handshake through N messages. It handles Noise // Machine drives a Noise handshake through N messages. It handles Noise
@@ -66,6 +70,8 @@ type Machine struct {
remoteCertSet bool remoteCertSet bool
payloadSet bool payloadSet bool
failed bool failed bool
multiport config.MultiPortConfig
} }
// NewMachine creates a handshake state machine. The subtype determines both // NewMachine creates a handshake state machine. The subtype determines both
@@ -79,6 +85,7 @@ func NewMachine(
allocIndex IndexAllocator, allocIndex IndexAllocator,
initiator bool, initiator bool,
subtype header.MessageSubType, subtype header.MessageSubType,
multiport config.MultiPortConfig,
) (*Machine, error) { ) (*Machine, error) {
info, err := subtypeInfoFor(subtype) info, err := subtypeInfoFor(subtype)
if err != nil { if err != nil {
@@ -106,6 +113,8 @@ func NewMachine(
result: &Result{ result: &Result{
Initiator: initiator, Initiator: initiator,
}, },
multiport: multiport,
}, nil }, nil
} }
@@ -296,7 +305,7 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
} }
// Assert the payload contains exactly what we expect // Assert the payload contains exactly what we expect
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 || payload.InitiatorMultiPort != nil || payload.ResponderMultiPort != nil
if hasPayloadData != flags.expectsPayload { if hasPayloadData != flags.expectsPayload {
m.failed = true m.failed = true
return ErrUnexpectedContent return ErrUnexpectedContent
@@ -312,8 +321,16 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
if flags.expectsPayload { if flags.expectsPayload {
if m.result.Initiator { if m.result.Initiator {
m.result.RemoteIndex = payload.ResponderIndex m.result.RemoteIndex = payload.ResponderIndex
if payload.ResponderMultiPort != nil {
m.result.MultiportRx = payload.ResponderMultiPort.RxSupported
m.result.MultiportTx = payload.ResponderMultiPort.TxSupported
}
} else { } else {
m.result.RemoteIndex = payload.InitiatorIndex m.result.RemoteIndex = payload.InitiatorIndex
if payload.InitiatorMultiPort != nil {
m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported
m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported
}
} }
m.result.HandshakeTime = payload.Time m.result.HandshakeTime = payload.Time
m.payloadSet = true m.payloadSet = true
@@ -387,11 +404,28 @@ func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
if m.result.Initiator { if m.result.Initiator {
p.InitiatorIndex = m.result.LocalIndex p.InitiatorIndex = m.result.LocalIndex
if m.multiport.Rx || m.multiport.Tx {
p.InitiatorMultiPort = &PayloadMultiPortDetails{
RxSupported: m.multiport.Rx,
TxSupported: m.multiport.Tx,
BasePort: uint32(m.multiport.TxBasePort),
TotalPorts: uint32(m.multiport.TxPorts),
}
}
} else { } else {
p.ResponderIndex = m.result.LocalIndex p.ResponderIndex = m.result.LocalIndex
p.InitiatorIndex = m.result.RemoteIndex p.InitiatorIndex = m.result.RemoteIndex
if m.multiport.Rx || m.multiport.Tx {
p.ResponderMultiPort = &PayloadMultiPortDetails{
RxSupported: m.multiport.Rx,
TxSupported: m.multiport.Tx,
BasePort: uint32(m.multiport.TxBasePort),
TotalPorts: uint32(m.multiport.TxPorts),
}
}
} }
p.Time = uint64(time.Now().UnixNano()) p.Time = uint64(time.Now().UnixNano())
} }
if flags.expectsCert { if flags.expectsCert {
cred := m.getCred(m.myVersion) cred := m.getCred(m.myVersion)

View File

@@ -20,6 +20,16 @@ type Payload struct {
ResponderIndex uint32 ResponderIndex uint32
Time uint64 Time uint64
CertVersion uint32 CertVersion uint32
InitiatorMultiPort *PayloadMultiPortDetails
ResponderMultiPort *PayloadMultiPortDetails
}
type PayloadMultiPortDetails struct {
RxSupported bool
TxSupported bool
BasePort uint32
TotalPorts uint32
} }
// Proto field numbers for NebulaHandshakeDetails // Proto field numbers for NebulaHandshakeDetails
@@ -29,6 +39,17 @@ const (
fieldResponderIndex = 3 // uint32 fieldResponderIndex = 3 // uint32
fieldTime = 5 // uint64 fieldTime = 5 // uint64
fieldCertVersion = 8 // uint32 fieldCertVersion = 8 // uint32
fieldInitiatorMultiPort = 6 // MultiPortDetails
fieldResponderMultiPort = 7 // MultiPortDetails
)
// Proto field numbers for MultiPortDetails
const (
fieldMultiportRxSupported = 1 // bool
fieldMultiportTxSupported = 2 // bool
fieldMultiportBasePort = 3 // uint32
fieldMultiportTotalPorts = 4 // uint32
) )
// MarshalPayload encodes a handshake payload in protobuf wire format compatible // MarshalPayload encodes a handshake payload in protobuf wire format compatible
@@ -57,6 +78,16 @@ func MarshalPayload(out []byte, p Payload) []byte {
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.CertVersion)) details = protowire.AppendVarint(details, uint64(p.CertVersion))
} }
if p.InitiatorMultiPort != nil {
details = protowire.AppendTag(details, fieldInitiatorMultiPort, protowire.BytesType)
details = protowire.AppendVarint(details, uint64(p.InitiatorMultiPort.size()))
details = p.InitiatorMultiPort.marshal(details)
}
if p.ResponderMultiPort != nil {
details = protowire.AppendTag(details, fieldResponderMultiPort, protowire.BytesType)
details = protowire.AppendVarint(details, uint64(p.ResponderMultiPort.size()))
details = p.ResponderMultiPort.marshal(details)
}
out = protowire.AppendTag(out, 1, protowire.BytesType) out = protowire.AppendTag(out, 1, protowire.BytesType)
out = protowire.AppendBytes(out, details) out = protowire.AppendBytes(out, details)
@@ -64,6 +95,23 @@ func MarshalPayload(out []byte, p Payload) []byte {
return out return out
} }
func (p PayloadMultiPortDetails) marshal(details []byte) []byte {
details = protowire.AppendTag(details, fieldMultiportRxSupported, protowire.VarintType)
details = protowire.AppendVarint(details, protowire.EncodeBool(p.RxSupported))
details = protowire.AppendTag(details, fieldMultiportTxSupported, protowire.VarintType)
details = protowire.AppendVarint(details, protowire.EncodeBool(p.TxSupported))
details = protowire.AppendTag(details, fieldMultiportBasePort, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.BasePort))
details = protowire.AppendTag(details, fieldMultiportTotalPorts, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.TotalPorts))
return details
}
func (p PayloadMultiPortDetails) size() int {
return 4 + 2 + protowire.SizeVarint(uint64(p.BasePort)) + protowire.SizeVarint(uint64(p.TotalPorts))
}
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. // UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
func UnmarshalPayload(b []byte) (Payload, error) { func UnmarshalPayload(b []byte) (Payload, error) {
var p Payload var p Payload
@@ -161,6 +209,97 @@ func unmarshalPayloadDetails(p *Payload, b []byte) error {
} }
p.CertVersion = uint32(v) p.CertVersion = uint32(v)
b = b[n:] b = b[n:]
case fieldInitiatorMultiPort:
if typ != protowire.BytesType {
return errInvalidHandshakeDetails
}
d, n := protowire.ConsumeBytes(b)
if n < 0 {
return errInvalidHandshakeMessage
}
b = b[n:]
p.InitiatorMultiPort = new(PayloadMultiPortDetails)
if err := unmarshalPayloadMultiPortDetails(p.InitiatorMultiPort, d); err != nil {
return err
}
case fieldResponderMultiPort:
if typ != protowire.BytesType {
return errInvalidHandshakeDetails
}
d, n := protowire.ConsumeBytes(b)
if n < 0 {
return errInvalidHandshakeMessage
}
b = b[n:]
p.ResponderMultiPort = new(PayloadMultiPortDetails)
if err := unmarshalPayloadMultiPortDetails(p.ResponderMultiPort, d); err != nil {
return err
}
default:
n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
}
}
return nil
}
func unmarshalPayloadMultiPortDetails(p *PayloadMultiPortDetails, b []byte) error {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
// For known field numbers, reject any non-matching wire type as a
// hard error rather than silently skipping. The caller will catch
// missing-field cases downstream, but a wire-type mismatch on a tag
// we know is a peer protocol violation worth flagging here.
// Repeated occurrences of a singular field follow proto3 last-wins.
switch num {
case fieldMultiportRxSupported:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.RxSupported = protowire.DecodeBool(v)
b = b[n:]
case fieldMultiportTxSupported:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.TxSupported = protowire.DecodeBool(v)
b = b[n:]
case fieldMultiportBasePort:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.BasePort = uint32(v)
b = b[n:]
case fieldMultiportTotalPorts:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.TotalPorts = uint32(v)
b = b[n:]
default: default:
n := protowire.ConsumeFieldValue(num, typ, b) n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 { if n < 0 {

View File

@@ -14,6 +14,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
@@ -71,7 +72,7 @@ type HandshakeManager struct {
f *Interface f *Interface
l *slog.Logger l *slog.Logger
multiPort MultiPortConfig multiPort config.MultiPortConfig
udpRaw *udp.RawConn udpRaw *udp.RawConn
// can be used to trigger outbound handshake for the given vpnIp // can be used to trigger outbound handshake for the given vpnIp
@@ -697,6 +698,7 @@ func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool {
v, cs.GetCredential, v, cs.GetCredential,
hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) }, hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) },
true, header.HandshakeIXPSK0, true, header.HandshakeIXPSK0,
hm.multiPort,
) )
if err != nil { if err != nil {
hm.f.l.Error("Failed to create handshake machine", hm.f.l.Error("Failed to create handshake machine",
@@ -738,6 +740,7 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head
v, cs.GetCredential, v, cs.GetCredential,
hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) }, hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) },
false, header.HandshakeIXPSK0, false, header.HandshakeIXPSK0,
hm.multiPort,
) )
if err != nil { if err != nil {
f.l.Error("Failed to create handshake machine", "from", via, "error", err) f.l.Error("Failed to create handshake machine", "from", via, "error", err)
@@ -786,6 +789,8 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head
relayForByAddr: map[netip.Addr]*Relay{}, relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{}, relayForByIdx: map[uint32]*Relay{},
}, },
multiportTx: hm.multiPort.Tx && result.MultiportRx,
multiportRx: hm.multiPort.Rx && result.MultiportTx,
} }
msg := "Handshake message received" msg := "Handshake message received"
@@ -802,6 +807,8 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head
"initiatorIndex", result.RemoteIndex, "initiatorIndex", result.RemoteIndex,
"responderIndex", result.LocalIndex, "responderIndex", result.LocalIndex,
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
"multiportTx", hostinfo.multiportTx,
"multiportRx", hostinfo.multiportRx,
) )
// packet aliases the listener's incoming buffer, so this copy must stay. // packet aliases the listener's incoming buffer, so this copy must stay.
@@ -914,6 +921,9 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
} }
hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx
hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx
// Verify correct host responded (initiator check) // Verify correct host responded (initiator check)
vpnAddrs := make([]netip.Addr, len(vpnNetworks)) vpnAddrs := make([]netip.Addr, len(vpnNetworks))
correctHostResponded := false correctHostResponded := false
@@ -987,6 +997,8 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
"durationNs", duration, "durationNs", duration,
"sentCachedPackets", len(hh.packetStore), "sentCachedPackets", len(hh.packetStore),
"multiportTx", hostinfo.multiportTx,
"multiportRx", hostinfo.multiportRx,
) )
hostinfo.vpnAddrs = vpnAddrs hostinfo.vpnAddrs = vpnAddrs

View File

@@ -98,7 +98,7 @@ type Interface struct {
triggerShutdown func() triggerShutdown func()
udpRaw *udp.RawConn udpRaw *udp.RawConn
multiPort MultiPortConfig multiPort config.MultiPortConfig
metricHandshakes metrics.Histogram metricHandshakes metrics.Histogram
messageMetrics *MessageMetrics messageMetrics *MessageMetrics