package handshake import ( "errors" "math" "google.golang.org/protobuf/encoding/protowire" ) var ( errInvalidHandshakeMessage = errors.New("invalid handshake message") errInvalidHandshakeDetails = errors.New("invalid handshake details") ) // Payload represents the decoded fields of a handshake message. // Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. type Payload struct { Cert []byte InitiatorIndex uint32 ResponderIndex uint32 Time uint64 CertVersion uint32 InitiatorMultiPort *PayloadMultiPortDetails ResponderMultiPort *PayloadMultiPortDetails } type PayloadMultiPortDetails struct { RxSupported bool TxSupported bool BasePort uint32 TotalPorts uint32 } // Proto field numbers for NebulaHandshakeDetails const ( fieldCert = 1 // bytes fieldInitiatorIndex = 2 // uint32 fieldResponderIndex = 3 // uint32 fieldTime = 5 // uint64 fieldCertVersion = 8 // uint32 fieldInitiatorMultiPort = 6 // MultiPortDetails fieldResponderMultiPort = 7 // MultiPortDetails ) // Proto field numbers for MultiPortDetails const ( fieldMultiportRxSupported = 1 // bool fieldMultiportTxSupported = 2 // bool fieldMultiportBasePort = 3 // uint32 fieldMultiportTotalPorts = 4 // uint32 ) // MarshalPayload encodes a handshake payload in protobuf wire format compatible // with NebulaHandshake{Details: NebulaHandshakeDetails{...}}. // Returns out (which may be nil), with the marshalled Payload appended to it. func MarshalPayload(out []byte, p Payload) []byte { var details []byte if len(p.Cert) > 0 { details = protowire.AppendTag(details, fieldCert, protowire.BytesType) details = protowire.AppendBytes(details, p.Cert) } if p.InitiatorIndex != 0 { details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.InitiatorIndex)) } if p.ResponderIndex != 0 { details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.ResponderIndex)) } if p.Time != 0 { details = protowire.AppendTag(details, fieldTime, protowire.VarintType) details = protowire.AppendVarint(details, p.Time) } if p.CertVersion != 0 { details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.CertVersion)) } if p.InitiatorMultiPort != nil { details = protowire.AppendTag(details, fieldInitiatorMultiPort, protowire.BytesType) details = protowire.AppendVarint(details, uint64(p.InitiatorMultiPort.size())) details = p.InitiatorMultiPort.marshal(details) } if p.ResponderMultiPort != nil { details = protowire.AppendTag(details, fieldResponderMultiPort, protowire.BytesType) details = protowire.AppendVarint(details, uint64(p.ResponderMultiPort.size())) details = p.ResponderMultiPort.marshal(details) } out = protowire.AppendTag(out, 1, protowire.BytesType) out = protowire.AppendBytes(out, details) return out } func (p PayloadMultiPortDetails) marshal(details []byte) []byte { details = protowire.AppendTag(details, fieldMultiportRxSupported, protowire.VarintType) details = protowire.AppendVarint(details, protowire.EncodeBool(p.RxSupported)) details = protowire.AppendTag(details, fieldMultiportTxSupported, protowire.VarintType) details = protowire.AppendVarint(details, protowire.EncodeBool(p.TxSupported)) details = protowire.AppendTag(details, fieldMultiportBasePort, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.BasePort)) details = protowire.AppendTag(details, fieldMultiportTotalPorts, protowire.VarintType) details = protowire.AppendVarint(details, uint64(p.TotalPorts)) return details } func (p PayloadMultiPortDetails) size() int { return 4 + 2 + protowire.SizeVarint(uint64(p.BasePort)) + protowire.SizeVarint(uint64(p.TotalPorts)) } // UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message. func UnmarshalPayload(b []byte) (Payload, error) { var p Payload for len(b) > 0 { num, typ, n := protowire.ConsumeTag(b) if n < 0 { return p, errInvalidHandshakeMessage } b = b[n:] switch { case num == 1 && typ == protowire.BytesType: details, n := protowire.ConsumeBytes(b) if n < 0 { return p, errInvalidHandshakeMessage } b = b[n:] if err := unmarshalPayloadDetails(&p, details); err != nil { return p, err } default: n := protowire.ConsumeFieldValue(num, typ, b) if n < 0 { return p, errInvalidHandshakeMessage } b = b[n:] } } return p, nil } func unmarshalPayloadDetails(p *Payload, b []byte) error { for len(b) > 0 { num, typ, n := protowire.ConsumeTag(b) if n < 0 { return errInvalidHandshakeDetails } b = b[n:] // For known field numbers, reject any non-matching wire type as a // hard error rather than silently skipping. The caller will catch // missing-field cases downstream, but a wire-type mismatch on a tag // we know is a peer protocol violation worth flagging here. // Repeated occurrences of a singular field follow proto3 last-wins. switch num { case fieldCert: if typ != protowire.BytesType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeBytes(b) if n < 0 { return errInvalidHandshakeDetails } p.Cert = append([]byte(nil), v...) b = b[n:] case fieldInitiatorIndex: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.InitiatorIndex = uint32(v) b = b[n:] case fieldResponderIndex: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.ResponderIndex = uint32(v) b = b[n:] case fieldTime: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 { return errInvalidHandshakeDetails } p.Time = v b = b[n:] case fieldCertVersion: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.CertVersion = uint32(v) b = b[n:] case fieldInitiatorMultiPort: if typ != protowire.BytesType { return errInvalidHandshakeDetails } d, n := protowire.ConsumeBytes(b) if n < 0 { return errInvalidHandshakeMessage } b = b[n:] p.InitiatorMultiPort = new(PayloadMultiPortDetails) if err := unmarshalPayloadMultiPortDetails(p.InitiatorMultiPort, d); err != nil { return err } case fieldResponderMultiPort: if typ != protowire.BytesType { return errInvalidHandshakeDetails } d, n := protowire.ConsumeBytes(b) if n < 0 { return errInvalidHandshakeMessage } b = b[n:] p.ResponderMultiPort = new(PayloadMultiPortDetails) if err := unmarshalPayloadMultiPortDetails(p.ResponderMultiPort, d); err != nil { return err } default: n := protowire.ConsumeFieldValue(num, typ, b) if n < 0 { return errInvalidHandshakeDetails } b = b[n:] } } return nil } func unmarshalPayloadMultiPortDetails(p *PayloadMultiPortDetails, b []byte) error { for len(b) > 0 { num, typ, n := protowire.ConsumeTag(b) if n < 0 { return errInvalidHandshakeDetails } b = b[n:] // For known field numbers, reject any non-matching wire type as a // hard error rather than silently skipping. The caller will catch // missing-field cases downstream, but a wire-type mismatch on a tag // we know is a peer protocol violation worth flagging here. // Repeated occurrences of a singular field follow proto3 last-wins. switch num { case fieldMultiportRxSupported: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.RxSupported = protowire.DecodeBool(v) b = b[n:] case fieldMultiportTxSupported: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.TxSupported = protowire.DecodeBool(v) b = b[n:] case fieldMultiportBasePort: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.BasePort = uint32(v) b = b[n:] case fieldMultiportTotalPorts: if typ != protowire.VarintType { return errInvalidHandshakeDetails } v, n := protowire.ConsumeVarint(b) if n < 0 || v > math.MaxUint32 { return errInvalidHandshakeDetails } p.TotalPorts = uint32(v) b = b[n:] default: n := protowire.ConsumeFieldValue(num, typ, b) if n < 0 { return errInvalidHandshakeDetails } b = b[n:] } } return nil }