mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
WIP
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
@@ -38,6 +39,9 @@ type Result struct {
|
||||
HandshakeTime uint64
|
||||
MessageIndex uint64 // number of messages exchanged during the handshake
|
||||
Initiator bool
|
||||
|
||||
MultiportRx bool
|
||||
MultiportTx bool
|
||||
}
|
||||
|
||||
// Machine drives a Noise handshake through N messages. It handles Noise
|
||||
@@ -66,6 +70,8 @@ type Machine struct {
|
||||
remoteCertSet bool
|
||||
payloadSet bool
|
||||
failed bool
|
||||
|
||||
multiport config.MultiPortConfig
|
||||
}
|
||||
|
||||
// NewMachine creates a handshake state machine. The subtype determines both
|
||||
@@ -79,6 +85,7 @@ func NewMachine(
|
||||
allocIndex IndexAllocator,
|
||||
initiator bool,
|
||||
subtype header.MessageSubType,
|
||||
multiport config.MultiPortConfig,
|
||||
) (*Machine, error) {
|
||||
info, err := subtypeInfoFor(subtype)
|
||||
if err != nil {
|
||||
@@ -106,6 +113,8 @@ func NewMachine(
|
||||
result: &Result{
|
||||
Initiator: initiator,
|
||||
},
|
||||
|
||||
multiport: multiport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -296,7 +305,7 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
||||
}
|
||||
|
||||
// Assert the payload contains exactly what we expect
|
||||
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
|
||||
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0 || payload.InitiatorMultiPort != nil || payload.ResponderMultiPort != nil
|
||||
if hasPayloadData != flags.expectsPayload {
|
||||
m.failed = true
|
||||
return ErrUnexpectedContent
|
||||
@@ -312,8 +321,16 @@ func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
||||
if flags.expectsPayload {
|
||||
if m.result.Initiator {
|
||||
m.result.RemoteIndex = payload.ResponderIndex
|
||||
if payload.ResponderMultiPort != nil {
|
||||
m.result.MultiportRx = payload.ResponderMultiPort.RxSupported
|
||||
m.result.MultiportTx = payload.ResponderMultiPort.TxSupported
|
||||
}
|
||||
} else {
|
||||
m.result.RemoteIndex = payload.InitiatorIndex
|
||||
if payload.InitiatorMultiPort != nil {
|
||||
m.result.MultiportRx = payload.InitiatorMultiPort.RxSupported
|
||||
m.result.MultiportTx = payload.InitiatorMultiPort.TxSupported
|
||||
}
|
||||
}
|
||||
m.result.HandshakeTime = payload.Time
|
||||
m.payloadSet = true
|
||||
@@ -387,11 +404,28 @@ func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
|
||||
|
||||
if m.result.Initiator {
|
||||
p.InitiatorIndex = m.result.LocalIndex
|
||||
if m.multiport.Rx || m.multiport.Tx {
|
||||
p.InitiatorMultiPort = &PayloadMultiPortDetails{
|
||||
RxSupported: m.multiport.Rx,
|
||||
TxSupported: m.multiport.Tx,
|
||||
BasePort: uint32(m.multiport.TxBasePort),
|
||||
TotalPorts: uint32(m.multiport.TxPorts),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.ResponderIndex = m.result.LocalIndex
|
||||
p.InitiatorIndex = m.result.RemoteIndex
|
||||
if m.multiport.Rx || m.multiport.Tx {
|
||||
p.ResponderMultiPort = &PayloadMultiPortDetails{
|
||||
RxSupported: m.multiport.Rx,
|
||||
TxSupported: m.multiport.Tx,
|
||||
BasePort: uint32(m.multiport.TxBasePort),
|
||||
TotalPorts: uint32(m.multiport.TxPorts),
|
||||
}
|
||||
}
|
||||
}
|
||||
p.Time = uint64(time.Now().UnixNano())
|
||||
|
||||
}
|
||||
if flags.expectsCert {
|
||||
cred := m.getCred(m.myVersion)
|
||||
|
||||
@@ -20,6 +20,16 @@ type Payload struct {
|
||||
ResponderIndex uint32
|
||||
Time uint64
|
||||
CertVersion uint32
|
||||
|
||||
InitiatorMultiPort *PayloadMultiPortDetails
|
||||
ResponderMultiPort *PayloadMultiPortDetails
|
||||
}
|
||||
|
||||
type PayloadMultiPortDetails struct {
|
||||
RxSupported bool
|
||||
TxSupported bool
|
||||
BasePort uint32
|
||||
TotalPorts uint32
|
||||
}
|
||||
|
||||
// Proto field numbers for NebulaHandshakeDetails
|
||||
@@ -29,6 +39,17 @@ const (
|
||||
fieldResponderIndex = 3 // uint32
|
||||
fieldTime = 5 // uint64
|
||||
fieldCertVersion = 8 // uint32
|
||||
|
||||
fieldInitiatorMultiPort = 6 // MultiPortDetails
|
||||
fieldResponderMultiPort = 7 // MultiPortDetails
|
||||
)
|
||||
|
||||
// Proto field numbers for MultiPortDetails
|
||||
const (
|
||||
fieldMultiportRxSupported = 1 // bool
|
||||
fieldMultiportTxSupported = 2 // bool
|
||||
fieldMultiportBasePort = 3 // uint32
|
||||
fieldMultiportTotalPorts = 4 // uint32
|
||||
)
|
||||
|
||||
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
|
||||
@@ -57,6 +78,16 @@ func MarshalPayload(out []byte, p Payload) []byte {
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.CertVersion))
|
||||
}
|
||||
if p.InitiatorMultiPort != nil {
|
||||
details = protowire.AppendTag(details, fieldInitiatorMultiPort, protowire.BytesType)
|
||||
details = protowire.AppendVarint(details, uint64(p.InitiatorMultiPort.size()))
|
||||
details = p.InitiatorMultiPort.marshal(details)
|
||||
}
|
||||
if p.ResponderMultiPort != nil {
|
||||
details = protowire.AppendTag(details, fieldResponderMultiPort, protowire.BytesType)
|
||||
details = protowire.AppendVarint(details, uint64(p.ResponderMultiPort.size()))
|
||||
details = p.ResponderMultiPort.marshal(details)
|
||||
}
|
||||
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
@@ -64,6 +95,23 @@ func MarshalPayload(out []byte, p Payload) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func (p PayloadMultiPortDetails) marshal(details []byte) []byte {
|
||||
details = protowire.AppendTag(details, fieldMultiportRxSupported, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, protowire.EncodeBool(p.RxSupported))
|
||||
details = protowire.AppendTag(details, fieldMultiportTxSupported, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, protowire.EncodeBool(p.TxSupported))
|
||||
details = protowire.AppendTag(details, fieldMultiportBasePort, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.BasePort))
|
||||
details = protowire.AppendTag(details, fieldMultiportTotalPorts, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.TotalPorts))
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
func (p PayloadMultiPortDetails) size() int {
|
||||
return 4 + 2 + protowire.SizeVarint(uint64(p.BasePort)) + protowire.SizeVarint(uint64(p.TotalPorts))
|
||||
}
|
||||
|
||||
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
|
||||
func UnmarshalPayload(b []byte) (Payload, error) {
|
||||
var p Payload
|
||||
@@ -161,6 +209,97 @@ func unmarshalPayloadDetails(p *Payload, b []byte) error {
|
||||
}
|
||||
p.CertVersion = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldInitiatorMultiPort:
|
||||
if typ != protowire.BytesType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
d, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
p.InitiatorMultiPort = new(PayloadMultiPortDetails)
|
||||
if err := unmarshalPayloadMultiPortDetails(p.InitiatorMultiPort, d); err != nil {
|
||||
return err
|
||||
}
|
||||
case fieldResponderMultiPort:
|
||||
if typ != protowire.BytesType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
d, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
p.ResponderMultiPort = new(PayloadMultiPortDetails)
|
||||
if err := unmarshalPayloadMultiPortDetails(p.ResponderMultiPort, d); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func unmarshalPayloadMultiPortDetails(p *PayloadMultiPortDetails, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
|
||||
// For known field numbers, reject any non-matching wire type as a
|
||||
// hard error rather than silently skipping. The caller will catch
|
||||
// missing-field cases downstream, but a wire-type mismatch on a tag
|
||||
// we know is a peer protocol violation worth flagging here.
|
||||
// Repeated occurrences of a singular field follow proto3 last-wins.
|
||||
switch num {
|
||||
case fieldMultiportRxSupported:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.RxSupported = protowire.DecodeBool(v)
|
||||
b = b[n:]
|
||||
case fieldMultiportTxSupported:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.TxSupported = protowire.DecodeBool(v)
|
||||
b = b[n:]
|
||||
case fieldMultiportBasePort:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.BasePort = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldMultiportTotalPorts:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.TotalPorts = uint32(v)
|
||||
b = b[n:]
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
|
||||
Reference in New Issue
Block a user