This commit is contained in:
Wade Simmons
2026-05-06 16:13:53 -04:00
parent bb3c70da2e
commit 610fcdb9bf
5 changed files with 198 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
)
@@ -38,6 +39,9 @@ type Result struct {
HandshakeTime uint64
MessageIndex uint64 // number of messages exchanged during the handshake
Initiator bool
MultiportRx bool
MultiportTx bool
}
// Machine drives a Noise handshake through N messages. It handles Noise
@@ -66,6 +70,8 @@ type Machine struct {
remoteCertSet bool
payloadSet bool
failed bool
multiport config.MultiPortConfig
}
// NewMachine creates a handshake state machine. The subtype determines both
@@ -79,6 +85,7 @@ func NewMachine(
allocIndex IndexAllocator,
initiator bool,
subtype header.MessageSubType,
multiport config.MultiPortConfig,
) (*Machine, error) {
info, err := subtypeInfoFor(subtype)
if err != nil {
@@ -106,6 +113,8 @@ func NewMachine(
result: &Result{
Initiator: initiator,
},
multiport: multiport,
}, nil
}
@@ -296,7 +305,7 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
}
// Assert the payload contains exactly what we expect
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 || payload.InitiatorMultiPort != nil || payload.ResponderMultiPort != nil
if hasPayloadData != flags.expectsPayload {
m.failed = true
return ErrUnexpectedContent
@@ -312,8 +321,16 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
if flags.expectsPayload {
if m.result.Initiator {
m.result.RemoteIndex = payload.ResponderIndex
if payload.ResponderMultiPort != nil {
m.result.MultiportRx = payload.ResponderMultiPort.RxSupported
m.result.MultiportTx = payload.ResponderMultiPort.TxSupported
}
} else {
m.result.RemoteIndex = payload.InitiatorIndex
if payload.InitiatorMultiPort != nil {
m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported
m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported
}
}
m.result.HandshakeTime = payload.Time
m.payloadSet = true
@@ -387,11 +404,28 @@ func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
if m.result.Initiator {
p.InitiatorIndex = m.result.LocalIndex
if m.multiport.Rx || m.multiport.Tx {
p.InitiatorMultiPort = &PayloadMultiPortDetails{
RxSupported: m.multiport.Rx,
TxSupported: m.multiport.Tx,
BasePort: uint32(m.multiport.TxBasePort),
TotalPorts: uint32(m.multiport.TxPorts),
}
}
} else {
p.ResponderIndex = m.result.LocalIndex
p.InitiatorIndex = m.result.RemoteIndex
if m.multiport.Rx || m.multiport.Tx {
p.ResponderMultiPort = &PayloadMultiPortDetails{
RxSupported: m.multiport.Rx,
TxSupported: m.multiport.Tx,
BasePort: uint32(m.multiport.TxBasePort),
TotalPorts: uint32(m.multiport.TxPorts),
}
}
}
p.Time = uint64(time.Now().UnixNano())
}
if flags.expectsCert {
cred := m.getCred(m.myVersion)