diff --git a/config/multiport.go b/config/multiport.go new file mode 100644 index 00000000..0f435d38 --- /dev/null +++ b/config/multiport.go @@ -0,0 +1,10 @@ +package config + +type MultiPortConfig struct { + Tx bool + Rx bool + TxBasePort uint16 + TxPorts int + TxHandshake bool + TxHandshakeDelay int64 +} diff --git a/handshake/machine.go b/handshake/machine.go index 25ed3a5a..31f6a08b 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -8,6 +8,7 @@ import ( "github.com/flynn/noise" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -38,6 +39,9 @@ type Result struct { HandshakeTime uint64 MessageIndex uint64 // number of messages exchanged during the handshake Initiator bool + + MultiportRx bool + MultiportTx bool } // Machine drives a Noise handshake through N messages. It handles Noise @@ -66,6 +70,8 @@ type Machine struct { remoteCertSet bool payloadSet bool failed bool + + multiport config.MultiPortConfig } // NewMachine creates a handshake state machine. The subtype determines both @@ -79,6 +85,7 @@ func NewMachine( allocIndex IndexAllocator, initiator bool, subtype header.MessageSubType, + multiport config.MultiPortConfig, ) (*Machine, error) { info, err := subtypeInfoFor(subtype) if err != nil { @@ -106,6 +113,8 @@ func NewMachine( result: &Result{ Initiator: initiator, }, + + multiport: multiport, }, nil } @@ -296,7 +305,7 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error { } // Assert the payload contains exactly what we expect - hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 + hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 || payload.InitiatorMultiPort != nil || payload.ResponderMultiPort != nil if hasPayloadData != flags.expectsPayload { m.failed = true return ErrUnexpectedContent @@ -312,8 +321,16 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error { if flags.expectsPayload { if m.result.Initiator { m.result.RemoteIndex = payload.ResponderIndex + if payload.ResponderMultiPort != nil { + m.result.MultiportRx = payload.ResponderMultiPort.RxSupported + m.result.MultiportTx = payload.ResponderMultiPort.TxSupported + } } else { m.result.RemoteIndex = payload.InitiatorIndex + if payload.InitiatorMultiPort != nil { + m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported + m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported + } } m.result.HandshakeTime = payload.Time m.payloadSet = true @@ -387,11 +404,28 @@ func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) { if m.result.Initiator { p.InitiatorIndex = m.result.LocalIndex + if m.multiport.Rx || m.multiport.Tx { + p.InitiatorMultiPort = &PayloadMultiPortDetails{ + RxSupported: m.multiport.Rx, + TxSupported: m.multiport.Tx, + BasePort: uint32(m.multiport.TxBasePort), + TotalPorts: uint32(m.multiport.TxPorts), + } + } } else { p.ResponderIndex = m.result.LocalIndex p.InitiatorIndex = m.result.RemoteIndex + if m.multiport.Rx || m.multiport.Tx { + p.ResponderMultiPort = &PayloadMultiPortDetails{ + RxSupported: m.multiport.Rx, + TxSupported: m.multiport.Tx, + BasePort: uint32(m.multiport.TxBasePort), + TotalPorts: uint32(m.multiport.TxPorts), + } + } } p.Time = uint64(time.Now().UnixNano()) + } if flags.expectsCert { cred := m.getCred(m.myVersion) diff --git a/handshake/payload.go b/handshake/payload.go index 4567fc0d..313b066f 100644 --- a/handshake/payload.go +++ b/handshake/payload.go @@ -20,6 +20,16 @@ type Payload struct { ResponderIndex uint32 Time uint64 CertVersion uint32 + + InitiatorMultiPort *PayloadMultiPortDetails + ResponderMultiPort *PayloadMultiPortDetails +} + +type PayloadMultiPortDetails struct { + RxSupported bool + TxSupported bool + BasePort uint32 + TotalPorts uint32 } // Proto field numbers for NebulaHandshakeDetails @@ -29,6 +39,17 @@ const ( fieldResponderIndex = 3 // uint32 fieldTime = 5 // uint64 fieldCertVersion = 8 // uint32 + + fieldInitiatorMultiPort = 6 // MultiPortDetails + fieldResponderMultiPort = 7 // MultiPortDetails +) + +// Proto field numbers for MultiPortDetails +const ( + fieldMultiportRxSupported = 1 // bool + fieldMultiportTxSupported = 2 // bool + fieldMultiportBasePort = 3 // uint32 + fieldMultiportTotalPorts = 4 // uint32 ) // MarshalPayload encodes a handshake payload in protobuf wire format compatible @@ -57,6 +78,16 @@ func MarshalPayload(out []byte, p Payload) []byte { details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.CertVersion)) } + if p.InitiatorMultiPort != nil { + details = protowire.AppendTag(details, fieldInitiatorMultiPort, protowire.BytesType) + details = protowire.AppendVarint(details, uint64(p.InitiatorMultiPort.size())) + details = p.InitiatorMultiPort.marshal(details) + } + if p.ResponderMultiPort != nil { + details = protowire.AppendTag(details, fieldResponderMultiPort, protowire.BytesType) + details = protowire.AppendVarint(details, uint64(p.ResponderMultiPort.size())) + details = p.ResponderMultiPort.marshal(details) + } out = protowire.AppendTag(out, 1, protowire.BytesType) out = protowire.AppendBytes(out, details) @@ -64,6 +95,23 @@ func MarshalPayload(out []byte, p Payload) []byte { return out } +func (p PayloadMultiPortDetails) marshal(details []byte) []byte { + details = protowire.AppendTag(details, fieldMultiportRxSupported, protowire.VarintType) + details = protowire.AppendVarint(details, protowire.EncodeBool(p.RxSupported)) + details = protowire.AppendTag(details, fieldMultiportTxSupported, protowire.VarintType) + details = protowire.AppendVarint(details, protowire.EncodeBool(p.TxSupported)) + details = protowire.AppendTag(details, fieldMultiportBasePort, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.BasePort)) + details = protowire.AppendTag(details, fieldMultiportTotalPorts, protowire.VarintType) + details = protowire.AppendVarint(details, uint64(p.TotalPorts)) + + return details +} + +func (p PayloadMultiPortDetails) size() int { + return 4 + 2 + protowire.SizeVarint(uint64(p.BasePort)) + protowire.SizeVarint(uint64(p.TotalPorts)) +} + // UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. func UnmarshalPayload(b []byte) (Payload, error) { var p Payload @@ -161,6 +209,97 @@ func unmarshalPayloadDetails(p *Payload, b []byte) error { } p.CertVersion = uint32(v) b = b[n:] + case fieldInitiatorMultiPort: + if typ != protowire.BytesType { + return errInvalidHandshakeDetails + } + d, n := protowire.ConsumeBytes(b) + if n < 0 { + return errInvalidHandshakeMessage + } + b = b[n:] + p.InitiatorMultiPort = new(PayloadMultiPortDetails) + if err := unmarshalPayloadMultiPortDetails(p.InitiatorMultiPort, d); err != nil { + return err + } + case fieldResponderMultiPort: + if typ != protowire.BytesType { + return errInvalidHandshakeDetails + } + d, n := protowire.ConsumeBytes(b) + if n < 0 { + return errInvalidHandshakeMessage + } + b = b[n:] + p.ResponderMultiPort = new(PayloadMultiPortDetails) + if err := unmarshalPayloadMultiPortDetails(p.ResponderMultiPort, d); err != nil { + return err + } + default: + n := protowire.ConsumeFieldValue(num, typ, b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + } + } + return nil +} + +func unmarshalPayloadMultiPortDetails(p *PayloadMultiPortDetails, b []byte) error { + for len(b) > 0 { + num, typ, n := protowire.ConsumeTag(b) + if n < 0 { + return errInvalidHandshakeDetails + } + b = b[n:] + + // For known field numbers, reject any non-matching wire type as a + // hard error rather than silently skipping. The caller will catch + // missing-field cases downstream, but a wire-type mismatch on a tag + // we know is a peer protocol violation worth flagging here. + // Repeated occurrences of a singular field follow proto3 last-wins. + switch num { + case fieldMultiportRxSupported: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.RxSupported = protowire.DecodeBool(v) + b = b[n:] + case fieldMultiportTxSupported: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.TxSupported = protowire.DecodeBool(v) + b = b[n:] + case fieldMultiportBasePort: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.BasePort = uint32(v) + b = b[n:] + case fieldMultiportTotalPorts: + if typ != protowire.VarintType { + return errInvalidHandshakeDetails + } + v, n := protowire.ConsumeVarint(b) + if n < 0 || v > math.MaxUint32 { + return errInvalidHandshakeDetails + } + p.TotalPorts = uint32(v) + b = b[n:] default: n := protowire.ConsumeFieldValue(num, typ, b) if n < 0 { diff --git a/handshake_manager.go b/handshake_manager.go index 7a33cd9a..76920a72 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -14,6 +14,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/handshake" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -71,7 +72,7 @@ type HandshakeManager struct { f *Interface l *slog.Logger - multiPort MultiPortConfig + multiPort config.MultiPortConfig udpRaw *udp.RawConn // can be used to trigger outbound handshake for the given vpnIp @@ -697,6 +698,7 @@ func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool { v, cs.GetCredential, hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) }, true, header.HandshakeIXPSK0, + hm.multiPort, ) if err != nil { hm.f.l.Error("Failed to create handshake machine", @@ -738,6 +740,7 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head v, cs.GetCredential, hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) }, false, header.HandshakeIXPSK0, + hm.multiPort, ) if err != nil { f.l.Error("Failed to create handshake machine", "from", via, "error", err) @@ -786,6 +789,8 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head relayForByAddr: map[netip.Addr]*Relay{}, relayForByIdx: map[uint32]*Relay{}, }, + multiportTx: hm.multiPort.Tx && result.MultiportRx, + multiportRx: hm.multiPort.Rx && result.MultiportTx, } msg := "Handshake message received" @@ -802,6 +807,8 @@ func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *head "initiatorIndex", result.RemoteIndex, "responderIndex", result.LocalIndex, "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, + "multiportTx", hostinfo.multiportTx, + "multiportRx", hostinfo.multiportRx, ) // packet aliases the listener's incoming buffer, so this copy must stay. @@ -914,6 +921,9 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } + hostinfo.multiportTx = hm.multiPort.Tx && result.MultiportRx + hostinfo.multiportRx = hm.multiPort.Rx && result.MultiportTx + // Verify correct host responded (initiator check) vpnAddrs := make([]netip.Addr, len(vpnNetworks)) correctHostResponded := false @@ -987,6 +997,8 @@ func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostIn "handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())}, "durationNs", duration, "sentCachedPackets", len(hh.packetStore), + "multiportTx", hostinfo.multiportTx, + "multiportRx", hostinfo.multiportRx, ) hostinfo.vpnAddrs = vpnAddrs diff --git a/interface.go b/interface.go index 799ea034..cbfbea00 100644 --- a/interface.go +++ b/interface.go @@ -98,7 +98,7 @@ type Interface struct { triggerShutdown func() udpRaw *udp.RawConn - multiPort MultiPortConfig + multiPort config.MultiPortConfig metricHandshakes metrics.Histogram messageMetrics *MessageMetrics