mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Handshake state machine (#1656)
This commit is contained in:
57
handshake/credential.go
Normal file
57
handshake/credential.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
// Credential holds everything needed to participate in a handshake
|
||||
// at a given cert version. Version and Curve are read from Cert; the public
|
||||
// half of the static keypair likewise comes from Cert.PublicKey().
|
||||
type Credential struct {
|
||||
Cert cert.Certificate // the certificate
|
||||
Bytes []byte // pre-marshaled certificate bytes
|
||||
privateKey []byte // static private key (public half lives in Cert)
|
||||
cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash)
|
||||
}
|
||||
|
||||
// NewCredential creates a Credential with all material needed for handshake
|
||||
// participation. The cipherSuite should be pre-built by the caller with the
|
||||
// appropriate DH function, cipher, and hash.
|
||||
func NewCredential(
|
||||
c cert.Certificate,
|
||||
hsBytes []byte,
|
||||
privateKey []byte,
|
||||
cipherSuite noise.CipherSuite,
|
||||
) *Credential {
|
||||
return &Credential{
|
||||
Cert: c,
|
||||
Bytes: hsBytes,
|
||||
privateKey: privateKey,
|
||||
cipherSuite: cipherSuite,
|
||||
}
|
||||
}
|
||||
|
||||
// buildHandshakeState creates a noise.HandshakeState from this credential.
|
||||
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
|
||||
return noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: hc.cipherSuite,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
|
||||
PresharedKey: []byte{},
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCredentialFunc returns the handshake credential for the given version,
|
||||
// or nil if that version is not available.
|
||||
//
|
||||
// Implementations must return credentials drawn from a snapshot stable for
|
||||
// the lifetime of any single Machine. The Machine may call this multiple
|
||||
// times during a handshake (e.g. when negotiating to the peer's version)
|
||||
// and assumes the underlying static keypair is consistent across calls.
|
||||
type GetCredentialFunc func(v cert.Version) *Credential
|
||||
21
handshake/errors.go
Normal file
21
handshake/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handshake
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInitiateOnResponder = errors.New("initiate called on responder")
|
||||
ErrInitiateAlreadyCalled = errors.New("initiate already called")
|
||||
ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators")
|
||||
ErrPacketTooShort = errors.New("packet too short")
|
||||
ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake")
|
||||
ErrIncompleteHandshake = errors.New("handshake completed without receiving required content")
|
||||
ErrMachineFailed = errors.New("handshake machine has failed")
|
||||
ErrUnknownSubtype = errors.New("unknown handshake subtype")
|
||||
ErrMissingContent = errors.New("expected handshake content but message was empty")
|
||||
ErrUnexpectedContent = errors.New("received unexpected handshake content")
|
||||
ErrIndexAllocation = errors.New("failed to allocate local index")
|
||||
ErrNoCredential = errors.New("no handshake credential available for cert version")
|
||||
ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key")
|
||||
ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
|
||||
ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype")
|
||||
)
|
||||
29
handshake/handshake.proto
Normal file
29
handshake/handshake.proto
Normal file
@@ -0,0 +1,29 @@
|
||||
// This file documents the wire format the nebula handshake speaks. It is
|
||||
// not run through protoc; the encoder/decoder in payload.go is hand-written
|
||||
// against this shape directly to keep the parser narrow and panic-free.
|
||||
//
|
||||
// Any change to the wire format must be reflected here, and adding a new
|
||||
// field requires updating MarshalPayload / unmarshalPayloadDetails together
|
||||
// with the field-uniqueness and wire-type checks in those functions.
|
||||
|
||||
syntax = "proto3";
|
||||
package nebula.handshake;
|
||||
|
||||
message NebulaHandshake {
|
||||
NebulaHandshakeDetails Details = 1;
|
||||
bytes Hmac = 2;
|
||||
}
|
||||
|
||||
message NebulaHandshakeDetails {
|
||||
bytes Cert = 1;
|
||||
uint32 InitiatorIndex = 2;
|
||||
uint32 ResponderIndex = 3;
|
||||
// Cookie was reserved for an anti-DoS mechanism that was never
|
||||
// implemented. No released version of nebula has ever populated it; the
|
||||
// hand-written parser silently skips it on read.
|
||||
uint64 Cookie = 4 [deprecated = true];
|
||||
uint64 Time = 5;
|
||||
uint32 CertVersion = 8;
|
||||
// reserved for WIP multiport
|
||||
reserved 6, 7;
|
||||
}
|
||||
116
handshake/helpers_test.go
Normal file
116
handshake/helpers_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
ct "github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCertState holds cert material for a test peer.
|
||||
type testCertState struct {
|
||||
version cert.Version
|
||||
creds map[cert.Version]*Credential
|
||||
}
|
||||
|
||||
func (s *testCertState) getCredential(v cert.Version) *Credential {
|
||||
return s.creds[v]
|
||||
}
|
||||
|
||||
func newTestCertState(
|
||||
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||
) *testCertState {
|
||||
return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
|
||||
}
|
||||
|
||||
func newTestCertStateWithCipher(
|
||||
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||
cipher noise.CipherFunc,
|
||||
) *testCertState {
|
||||
t.Helper()
|
||||
c, _, rawPrivKey, _ := ct.NewTestCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
|
||||
)
|
||||
|
||||
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
hsBytes, err := c.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
|
||||
return &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(c, hsBytes, priv, ncs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testVerifier(pool *cert.CAPool) CertVerifier {
|
||||
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||
return pool.VerifyCertificate(time.Now(), c)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestMachine(
|
||||
t *testing.T,
|
||||
cs *testCertState,
|
||||
verifier CertVerifier,
|
||||
initiator bool,
|
||||
localIndex uint32,
|
||||
) *Machine {
|
||||
t.Helper()
|
||||
m, err := NewMachine(
|
||||
cs.version, cs.getCredential,
|
||||
verifier, func() (uint32, error) { return localIndex, nil },
|
||||
initiator, header.HandshakeIXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return m
|
||||
}
|
||||
|
||||
func initiateHandshake(
|
||||
t *testing.T,
|
||||
initCS *testCertState, initVerifier CertVerifier,
|
||||
respCS *testCertState, respVerifier CertVerifier,
|
||||
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
|
||||
t.Helper()
|
||||
initM = newTestMachine(t, initCS, initVerifier, true, 100)
|
||||
msg1, merr := initM.Initiate(nil)
|
||||
require.NoError(t, merr)
|
||||
|
||||
respM = newTestMachine(t, respCS, respVerifier, false, 200)
|
||||
resp, respResult, err = respM.ProcessPacket(nil, msg1)
|
||||
return
|
||||
}
|
||||
|
||||
func doFullHandshake(
|
||||
t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
|
||||
) (initResult, respResult *Result) {
|
||||
t.Helper()
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, respResult, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
require.NotEmpty(t, resp)
|
||||
|
||||
_, initResult, err = initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult)
|
||||
|
||||
return initResult, respResult
|
||||
}
|
||||
444
handshake/machine.go
Normal file
444
handshake/machine.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// IndexAllocator is called by the Machine to allocate a local index for the
|
||||
// handshake. It is called at most once, when the first outgoing message that
|
||||
// carries a payload is built.
|
||||
//
|
||||
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
|
||||
// "no index assigned" on the wire and in the payload-presence checks. If an
|
||||
// allocator ever returned 0, a legitimate handshake's payload could be
|
||||
// indistinguishable from an empty one and would be rejected.
|
||||
type IndexAllocator func() (uint32, error)
|
||||
|
||||
// CertVerifier is called by the Machine after reconstructing the peer's
|
||||
// certificate from the handshake. The verifier performs all validation
|
||||
// (CA trust, expiry, policy checks, allow lists).
|
||||
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
|
||||
|
||||
// Result contains the results of a successful handshake.
|
||||
// Returned by ProcessPacket when the handshake is complete.
|
||||
type Result struct {
|
||||
EKey *noise.CipherState
|
||||
DKey *noise.CipherState
|
||||
MyCert cert.Certificate
|
||||
RemoteCert *cert.CachedCertificate
|
||||
RemoteIndex uint32
|
||||
LocalIndex uint32
|
||||
HandshakeTime uint64
|
||||
MessageIndex uint64 // number of messages exchanged during the handshake
|
||||
Initiator bool
|
||||
}
|
||||
|
||||
// Machine drives a Noise handshake through N messages. It handles Noise
|
||||
// protocol operations, certificate reconstruction, and payload encoding.
|
||||
// Certificate validation is delegated to the caller via CertVerifier.
|
||||
//
|
||||
// A Machine is not safe for concurrent use. The caller must ensure that
|
||||
// Initiate and ProcessPacket are not called concurrently.
|
||||
//
|
||||
// Error contract: when ProcessPacket or Initiate returns an error, callers
|
||||
// must check Failed() to decide what to do next. If Failed() is false the
|
||||
// underlying noise state was not advanced (the packet was rejected before
|
||||
// ReadMessage took effect, or the rejection is non-fatal like a stale
|
||||
// retransmit) and the Machine can accept another packet. If Failed() is
|
||||
// true the Machine is unrecoverable and the caller must abandon it.
|
||||
type Machine struct {
|
||||
hs *noise.HandshakeState
|
||||
getCred GetCredentialFunc
|
||||
allocIndex IndexAllocator
|
||||
verifier CertVerifier
|
||||
result *Result
|
||||
msgs []msgFlags
|
||||
myVersion cert.Version
|
||||
subtype header.MessageSubType
|
||||
indexAllocated bool
|
||||
remoteCertSet bool
|
||||
payloadSet bool
|
||||
failed bool
|
||||
}
|
||||
|
||||
// NewMachine creates a handshake state machine. The subtype determines both
|
||||
// the noise pattern and the per-message content layout. The credential for
|
||||
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
|
||||
// IndexAllocator is called lazily when the first outgoing payload is built.
|
||||
func NewMachine(
|
||||
version cert.Version,
|
||||
getCred GetCredentialFunc,
|
||||
verifier CertVerifier,
|
||||
allocIndex IndexAllocator,
|
||||
initiator bool,
|
||||
subtype header.MessageSubType,
|
||||
) (*Machine, error) {
|
||||
info, err := subtypeInfoFor(subtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cred := getCred(version)
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
|
||||
}
|
||||
|
||||
hs, err := cred.buildHandshakeState(initiator, info.pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build noise state: %w", err)
|
||||
}
|
||||
|
||||
return &Machine{
|
||||
hs: hs,
|
||||
subtype: subtype,
|
||||
msgs: info.msgs,
|
||||
getCred: getCred,
|
||||
allocIndex: allocIndex,
|
||||
verifier: verifier,
|
||||
myVersion: version,
|
||||
result: &Result{
|
||||
Initiator: initiator,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Failed returns true if the Machine is in an unrecoverable state.
|
||||
func (m *Machine) Failed() bool {
|
||||
return m.failed
|
||||
}
|
||||
|
||||
// Subtype returns the handshake subtype this Machine was built for.
|
||||
func (m *Machine) Subtype() header.MessageSubType {
|
||||
return m.subtype
|
||||
}
|
||||
|
||||
// MessageIndex returns the noise handshake message index, which equals the
|
||||
// wire counter of the most recently sent or received message.
|
||||
func (m *Machine) MessageIndex() int {
|
||||
return m.hs.MessageIndex()
|
||||
}
|
||||
|
||||
// requireComplete checks that both a peer cert and payload have been received.
|
||||
// Marks the machine as failed if not.
|
||||
func (m *Machine) requireComplete() error {
|
||||
if !m.payloadSet || !m.remoteCertSet {
|
||||
m.failed = true
|
||||
return ErrIncompleteHandshake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// myMsgFlags returns the flags for the current outgoing message.
|
||||
func (m *Machine) myMsgFlags() msgFlags {
|
||||
idx := m.hs.MessageIndex()
|
||||
if idx < len(m.msgs) {
|
||||
return m.msgs[idx]
|
||||
}
|
||||
return msgFlags{}
|
||||
}
|
||||
|
||||
// peerMsgFlags returns the flags for the message we just read.
|
||||
func (m *Machine) peerMsgFlags() msgFlags {
|
||||
idx := m.hs.MessageIndex() - 1
|
||||
if idx >= 0 && idx < len(m.msgs) {
|
||||
return m.msgs[idx]
|
||||
}
|
||||
return msgFlags{}
|
||||
}
|
||||
|
||||
// Initiate produces the first handshake message. Only valid for initiators,
|
||||
// and must be called exactly once before ProcessPacket.
|
||||
//
|
||||
// out is a destination buffer the message is appended to and returned. Pass
|
||||
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||
// buf[:0]) with sufficient capacity to avoid allocation.
|
||||
//
|
||||
// An error return may not indicate a fatal condition, check Failed() to
|
||||
// determine if the Machine can still be used.
|
||||
func (m *Machine) Initiate(out []byte) ([]byte, error) {
|
||||
if m.failed {
|
||||
return nil, ErrMachineFailed
|
||||
}
|
||||
if !m.result.Initiator {
|
||||
m.failed = true
|
||||
return nil, ErrInitiateOnResponder
|
||||
}
|
||||
if m.hs.MessageIndex() != 0 {
|
||||
m.failed = true
|
||||
return nil, ErrInitiateAlreadyCalled
|
||||
}
|
||||
|
||||
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
|
||||
// header counter 1 and remote index 0, which is what the initial message needs.
|
||||
out, _, _, err := m.buildResponse(out)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ProcessPacket handles an incoming handshake message. It advances the Noise
|
||||
// state, validates the peer certificate via the verifier, and optionally
|
||||
// produces a response.
|
||||
//
|
||||
// out is a destination buffer the response is appended to and returned. Pass
|
||||
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
|
||||
// is nil when no outgoing message is produced (handshake complete on this
|
||||
// side, or final message of a multi-message pattern).
|
||||
//
|
||||
// Returns a non-nil Result when the handshake is complete.
|
||||
// An error return may not indicate a fatal condition, check Failed() to
|
||||
// determine if the Machine can still be used.
|
||||
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
|
||||
if m.failed {
|
||||
return nil, nil, ErrMachineFailed
|
||||
}
|
||||
if len(packet) < header.Len {
|
||||
return nil, nil, ErrPacketTooShort
|
||||
}
|
||||
// Reject packets whose subtype doesn't match the one this Machine was
|
||||
// built for. A pending handshake that suddenly receives a different
|
||||
// subtype on its index is either a stray packet that matched by chance
|
||||
// or a peer protocol violation; drop it without failing the Machine so
|
||||
// the legitimate retransmit can still complete.
|
||||
if header.MessageSubType(packet[1]) != m.subtype {
|
||||
return nil, nil, ErrSubtypeMismatch
|
||||
}
|
||||
if m.result.Initiator && m.hs.MessageIndex() == 0 {
|
||||
m.failed = true
|
||||
return nil, nil, ErrInitiateNotCalled
|
||||
}
|
||||
|
||||
// The (eKey, dKey) ordering here is correct for IX, where the initiator
|
||||
// completes the handshake by reading the responder's stage-2 message.
|
||||
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
|
||||
// For 3-message patterns where a responder finishes by reading the final
|
||||
// message, this ordering would be wrong; revisit when XX/pqIX lands.
|
||||
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
// Noise ReadMessage failed. The noise library checkpoints and rolls back
|
||||
// on failure, so the Machine is still alive. The caller can retry with
|
||||
// a different packet.
|
||||
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
|
||||
}
|
||||
|
||||
// From here on, noise state has advanced. Any error is fatal.
|
||||
flags := m.peerMsgFlags()
|
||||
|
||||
if err := m.processPayload(msg, flags); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If ReadMessage derived keys, the handshake is complete. Noise should
|
||||
// always produce both keys together; asymmetry is a protocol invariant
|
||||
// violation.
|
||||
if eKey != nil || dKey != nil {
|
||||
if eKey == nil || dKey == nil {
|
||||
m.failed = true
|
||||
return nil, nil, ErrAsymmetricCipherKeys
|
||||
}
|
||||
if err := m.requireComplete(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return nil, m.completed(eKey, dKey), nil
|
||||
}
|
||||
|
||||
// ReadMessage didn't complete, produce the next outgoing message
|
||||
out, dk, ek, err := m.buildResponse(out)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if ek != nil || dk != nil {
|
||||
if ek == nil || dk == nil {
|
||||
m.failed = true
|
||||
return nil, nil, ErrAsymmetricCipherKeys
|
||||
}
|
||||
if err := m.requireComplete(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return out, m.completed(ek, dk), nil
|
||||
}
|
||||
|
||||
return out, nil, nil
|
||||
}
|
||||
|
||||
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
|
||||
m.result.EKey = eKey
|
||||
m.result.DKey = dKey
|
||||
m.result.MessageIndex = uint64(m.hs.MessageIndex())
|
||||
return m.result
|
||||
}
|
||||
|
||||
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
||||
if len(msg) == 0 {
|
||||
if flags.expectsPayload || flags.expectsCert {
|
||||
m.failed = true
|
||||
return ErrMissingContent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
payload, err := UnmarshalPayload(msg)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("unmarshal handshake: %w", err)
|
||||
}
|
||||
|
||||
// Assert the payload contains exactly what we expect
|
||||
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
|
||||
if hasPayloadData != flags.expectsPayload {
|
||||
m.failed = true
|
||||
return ErrUnexpectedContent
|
||||
}
|
||||
|
||||
hasCertData := len(payload.Cert) > 0
|
||||
if hasCertData != flags.expectsCert {
|
||||
m.failed = true
|
||||
return ErrUnexpectedContent
|
||||
}
|
||||
|
||||
// Process payload
|
||||
if flags.expectsPayload {
|
||||
if m.result.Initiator {
|
||||
m.result.RemoteIndex = payload.ResponderIndex
|
||||
} else {
|
||||
m.result.RemoteIndex = payload.InitiatorIndex
|
||||
}
|
||||
m.result.HandshakeTime = payload.Time
|
||||
m.payloadSet = true
|
||||
}
|
||||
|
||||
// Process certificate
|
||||
if flags.expectsCert {
|
||||
if err := m.validateCert(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) validateCert(payload Payload) error {
|
||||
cred := m.getCred(m.myVersion)
|
||||
if cred == nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||
}
|
||||
rc, err := cert.Recombine(
|
||||
cert.Version(payload.CertVersion),
|
||||
payload.Cert,
|
||||
m.hs.PeerStatic(),
|
||||
cred.Cert.Curve(),
|
||||
)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("recombine cert: %w", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
|
||||
m.failed = true
|
||||
return ErrPublicKeyMismatch
|
||||
}
|
||||
|
||||
// Version negotiation, if the peer sent a different version and we have it, switch
|
||||
if rc.Version() != m.myVersion {
|
||||
if m.getCred(rc.Version()) != nil {
|
||||
m.myVersion = rc.Version()
|
||||
}
|
||||
}
|
||||
|
||||
verified, err := m.verifier(rc)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("verify cert: %w", err)
|
||||
}
|
||||
|
||||
m.result.RemoteCert = verified
|
||||
m.remoteCertSet = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
|
||||
if !flags.expectsPayload && !flags.expectsCert {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var p Payload
|
||||
if flags.expectsPayload {
|
||||
if !m.indexAllocated {
|
||||
index, err := m.allocIndex()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
|
||||
}
|
||||
m.result.LocalIndex = index
|
||||
m.indexAllocated = true
|
||||
}
|
||||
|
||||
if m.result.Initiator {
|
||||
p.InitiatorIndex = m.result.LocalIndex
|
||||
} else {
|
||||
p.ResponderIndex = m.result.LocalIndex
|
||||
p.InitiatorIndex = m.result.RemoteIndex
|
||||
}
|
||||
p.Time = uint64(time.Now().UnixNano())
|
||||
}
|
||||
if flags.expectsCert {
|
||||
cred := m.getCred(m.myVersion)
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||
}
|
||||
p.Cert = cred.Bytes
|
||||
p.CertVersion = uint32(cred.Cert.Version())
|
||||
m.result.MyCert = cred.Cert
|
||||
}
|
||||
|
||||
return MarshalPayload(nil, p), nil
|
||||
}
|
||||
|
||||
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
flags := m.myMsgFlags()
|
||||
hsBytes, err := m.marshalOutgoing(flags)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Extend out by header.Len to make room for the header. slices.Grow is a
|
||||
// no-op when the cap is already sufficient (the zero-copy case where the
|
||||
// caller passed a pre-sized buffer). header.Encode overwrites the new
|
||||
// bytes, so they don't need to be zeroed.
|
||||
start := len(out)
|
||||
out = slices.Grow(out, header.Len)[:start+header.Len]
|
||||
header.Encode(
|
||||
out[start:],
|
||||
header.Version, header.Handshake, m.subtype,
|
||||
m.result.RemoteIndex,
|
||||
uint64(m.hs.MessageIndex()+1),
|
||||
)
|
||||
|
||||
// noise.WriteMessage appends the encrypted handshake message to out,
|
||||
// reusing capacity when present.
|
||||
//
|
||||
// The (dKey, eKey) ordering here is correct for IX, where the responder
|
||||
// completes the handshake by writing the stage-2 message. noise returns
|
||||
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
|
||||
// responder's decrypt key). For 3-message patterns where an initiator
|
||||
// finishes by writing the final message, this ordering would be wrong;
|
||||
// revisit when XX/pqIX lands.
|
||||
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
|
||||
}
|
||||
|
||||
return out, dKey, eKey, nil
|
||||
}
|
||||
662
handshake/machine_test.go
Normal file
662
handshake/machine_test.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
ct "github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMachineIXHappyPath(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
|
||||
assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
|
||||
|
||||
assert.Equal(t, uint32(1000), initR.LocalIndex)
|
||||
assert.Equal(t, uint32(2000), initR.RemoteIndex)
|
||||
assert.Equal(t, uint32(2000), respR.LocalIndex)
|
||||
assert.Equal(t, uint32(1000), respR.RemoteIndex)
|
||||
|
||||
assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
|
||||
assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
|
||||
|
||||
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello"), pt1)
|
||||
|
||||
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
|
||||
require.NoError(t, err)
|
||||
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("world"), pt2)
|
||||
}
|
||||
|
||||
func TestMachineInitiateErrors(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("initiate on responder", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, err := m.Initiate(nil)
|
||||
require.ErrorIs(t, err, ErrInitiateOnResponder)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("initiate called twice", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, true, 100)
|
||||
_, err := m.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
_, err = m.Initiate(nil)
|
||||
require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("process packet before initiate on initiator", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, true, 100)
|
||||
_, _, err := m.ProcessPacket(nil, make([]byte, 100))
|
||||
require.ErrorIs(t, err, ErrInitiateNotCalled)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("calling failed machine", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, err := m.Initiate(nil) // fails: responder
|
||||
require.Error(t, err)
|
||||
_, err = m.Initiate(nil) // fails: already failed
|
||||
require.ErrorIs(t, err, ErrMachineFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineProcessPacketErrors(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("packet too short", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
|
||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||
assert.False(t, m.Failed(), "short packet should not kill machine")
|
||||
})
|
||||
|
||||
t.Run("noise decryption failure is recoverable", func(t *testing.T) {
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
resp, _, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
corrupted := make([]byte, len(resp))
|
||||
copy(corrupted, resp)
|
||||
for i := header.Len; i < len(corrupted); i++ {
|
||||
corrupted[i] ^= 0xff
|
||||
}
|
||||
_, _, err = initM.ProcessPacket(nil, corrupted)
|
||||
require.Error(t, err)
|
||||
assert.False(t, initM.Failed(), "noise failure should be recoverable")
|
||||
|
||||
// And the machine should still complete a real handshake afterward.
|
||||
_, result, err := initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result, "initiator should complete on the legitimate response")
|
||||
})
|
||||
|
||||
t.Run("invalid cert is fatal", func(t *testing.T) {
|
||||
otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
_, _, err = respM.ProcessPacket(nil, msg1)
|
||||
require.Error(t, err)
|
||||
assert.True(t, respM.Failed(), "cert validation failure should kill machine")
|
||||
})
|
||||
|
||||
t.Run("subtype mismatch is recoverable", func(t *testing.T) {
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mutate the subtype byte (offset 1 in the header) to a value the
|
||||
// responder Machine wasn't built for.
|
||||
bad := make([]byte, len(msg1))
|
||||
copy(bad, msg1)
|
||||
bad[1] = 0xff
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
_, _, err = respM.ProcessPacket(nil, bad)
|
||||
require.ErrorIs(t, err, ErrSubtypeMismatch)
|
||||
assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
|
||||
|
||||
// And the machine should still complete a real handshake afterward.
|
||||
resp, result, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
|
||||
assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMachineProcessPayload exercises processPayload's internal validation
|
||||
// directly. Most of these failure modes can't be reached black-box once the
|
||||
// subtype check at the top of ProcessPacket gates external callers, so we
|
||||
// drive them by hand here for coverage.
|
||||
func TestMachineProcessPayload(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("empty message with expects fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.ErrorIs(t, err, ErrMissingContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("empty message with no expects passes", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload(nil, msgFlags{})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("malformed protobuf is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.Error(t, err)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("unexpected payload data is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// A payload with index data when none was expected.
|
||||
bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("unexpected cert data is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// A payload with cert when none was expected.
|
||||
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("missing payload data when expected is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// Cert present, but no index/time fields.
|
||||
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
}
|
||||
|
||||
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
|
||||
// directly. Like processPayload above this isn't reachable from a normal IX
|
||||
// flow, so we drive it by hand.
|
||||
func TestMachineRequireComplete(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("missing both fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("payload only fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.payloadSet = true
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("cert only fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.remoteCertSet = true
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("both set passes", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.payloadSet = true
|
||||
m.remoteCertSet = true
|
||||
err := m.requireComplete()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, m.Failed())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineAESCipher(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
initCS := newTestCertStateWithCipher(
|
||||
t, ca, caKey, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
noiseutil.CipherAESGCM,
|
||||
)
|
||||
respCS := newTestCertStateWithCipher(
|
||||
t, ca, caKey, "resp",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
noiseutil.CipherAESGCM,
|
||||
)
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("works"), pt1)
|
||||
|
||||
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
|
||||
require.NoError(t, err)
|
||||
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("back"), pt2)
|
||||
}
|
||||
|
||||
func TestResultFields(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
assert.True(t, initR.Initiator)
|
||||
assert.False(t, respR.Initiator)
|
||||
assert.NotZero(t, initR.HandshakeTime)
|
||||
assert.NotZero(t, respR.HandshakeTime)
|
||||
assert.NotNil(t, initR.RemoteCert)
|
||||
assert.NotNil(t, respR.RemoteCert)
|
||||
}
|
||||
|
||||
func TestMachineBufferReuse(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("response writes into provided buffer", func(t *testing.T) {
|
||||
buf := make([]byte, 0, 4096)
|
||||
resp, result, err := respM.ProcessPacket(buf, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.NotEmpty(t, resp, "response should have content")
|
||||
assert.Equal(t, &buf[:1][0], &resp[:1][0],
|
||||
"response should reuse the provided buffer's backing array")
|
||||
})
|
||||
|
||||
t.Run("initiate writes into provided buffer", func(t *testing.T) {
|
||||
initM2 := newTestMachine(t, initCS, v, true, 3000)
|
||||
buf := make([]byte, 0, 4096)
|
||||
msg, err := initM2.Initiate(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, msg, "initiate should have content")
|
||||
assert.Equal(t, &buf[:1][0], &msg[:1][0],
|
||||
"initiate should reuse the provided buffer's backing array")
|
||||
})
|
||||
|
||||
t.Run("nil out still works", func(t *testing.T) {
|
||||
initM2 := newTestMachine(t, initCS, v, true, 4000)
|
||||
respM2 := newTestMachine(t, respCS, v, false, 5000)
|
||||
|
||||
msg1, err := initM2.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := respM2.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
out, result, err := initM2.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Nil(t, out, "initiator should have no response for IX msg2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineMsgIndexTracking(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
respM := newTestMachine(t, respCS, v, false, 200)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp1, result1, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result1)
|
||||
|
||||
_, result2, err := initM.ProcessPacket(nil, resp1)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result2)
|
||||
}
|
||||
|
||||
func TestMachineThreeMessagePattern(t *testing.T) {
|
||||
registerTestXXInfo(t)
|
||||
|
||||
// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
|
||||
// patterns correctly. XX flow:
|
||||
// msg1 (I->R): [E] - payload only, no cert
|
||||
// msg2 (R->I): [E, ee, S, es] - payload + cert
|
||||
// msg3 (I->R): [S, se] - cert only (no payload, not first two)
|
||||
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initM, err := NewMachine(
|
||||
cert.Version2,
|
||||
initCS.getCredential, v,
|
||||
func() (uint32, error) { return 1000, nil },
|
||||
true, header.HandshakeXXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM, err := NewMachine(
|
||||
cert.Version2,
|
||||
respCS.getCredential, v,
|
||||
func() (uint32, error) { return 2000, nil },
|
||||
false, header.HandshakeXXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// msg1: initiator -> responder (E only, no cert)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, msg1)
|
||||
|
||||
// Responder processes msg1, should not complete yet, should produce msg2
|
||||
msg2, result, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result, "XX should not complete on msg1")
|
||||
assert.NotEmpty(t, msg2, "responder should produce msg2")
|
||||
|
||||
// Initiator processes msg2: gets responder's cert, produces msg3, and
|
||||
// completes (WriteMessage for msg3 derives keys)
|
||||
msg3, initResult, err := initM.ProcessPacket(nil, msg2)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
|
||||
assert.NotEmpty(t, msg3, "initiator should produce msg3")
|
||||
assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
|
||||
|
||||
// Responder processes msg3: gets initiator's cert and completes
|
||||
_, respResult, err := respM.ProcessPacket(nil, msg3)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult, "XX responder should complete on msg3")
|
||||
assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
|
||||
|
||||
assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
|
||||
assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
|
||||
|
||||
// Verify keys work
|
||||
ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("three messages"), pt1)
|
||||
}
|
||||
|
||||
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
|
||||
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
|
||||
// should exercise the case where cert arrives in msg3 and verify that completing
|
||||
// without it fails.
|
||||
|
||||
func TestMachineExpiredCert(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519,
|
||||
time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
|
||||
nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
expCert, _, expKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||
"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
|
||||
)
|
||||
expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
|
||||
require.NoError(t, err)
|
||||
expHsBytes, err := expCert.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
expiredCS := &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
respCS := newTestCertState(
|
||||
t, ca, caKey, "responder",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, expiredCS, testVerifier(caPool),
|
||||
respCS, testVerifier(caPool),
|
||||
)
|
||||
require.ErrorContains(t, err, "verify cert")
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineNoCertNetworks(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
caHsBytes, err := ca.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
noNetCS := &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
respCS := newTestCertState(
|
||||
t, ca, caKey, "responder",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, noNetCS, testVerifier(caPool),
|
||||
respCS, testVerifier(caPool),
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineDifferentCAs(t *testing.T) {
|
||||
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
|
||||
initCS := newTestCertState(
|
||||
t, ca1, caKey1, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
respCS := newTestCertState(
|
||||
t, ca2, caKey2, "resp",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
|
||||
respCS, testVerifier(ct.NewTestCAPool(ca2)),
|
||||
)
|
||||
require.ErrorContains(t, err, "verify cert")
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineVersionNegotiation(t *testing.T) {
|
||||
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca1, ca2)
|
||||
|
||||
makeMultiVersionResp := func(t *testing.T) *testCertState {
|
||||
t.Helper()
|
||||
respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||
ca1.NotBefore(), ca1.NotAfter(),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||
)
|
||||
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||
respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
|
||||
respHsV1, _ := respCertV1.MarshalForHandshakes()
|
||||
respHsV2, _ := respCertV2.MarshalForHandshakes()
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
return &testCertState{
|
||||
version: cert.Version1,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
|
||||
cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("responder matches initiator version", func(t *testing.T) {
|
||||
initCS := newTestCertState(
|
||||
t, ca2, caKey2, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
respCS := makeMultiVersionResp(t)
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM, _, respResult, resp, err := initiateHandshake(
|
||||
t, initCS, v,
|
||||
respCS, v,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
|
||||
assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
|
||||
"responder should negotiate to initiator's version")
|
||||
|
||||
_, initResult, err := initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult)
|
||||
assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
|
||||
"initiator should see V2 cert from responder")
|
||||
})
|
||||
|
||||
t.Run("responder keeps version when no match available", func(t *testing.T) {
|
||||
initCS := newTestCertState(
|
||||
t, ca2, caKey2, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
|
||||
respCert, _, respKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||
ca1.NotBefore(), ca1.NotAfter(),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||
)
|
||||
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||
respHs, _ := respCert.MarshalForHandshakes()
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
respCS := &testCertState{
|
||||
version: cert.Version1,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
v := testVerifier(caPool)
|
||||
_, _, respResult, _, err := initiateHandshake(
|
||||
t, initCS, v,
|
||||
respCS, v,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
|
||||
assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
|
||||
"responder should keep V1 when V2 not available")
|
||||
})
|
||||
}
|
||||
54
handshake/patterns.go
Normal file
54
handshake/patterns.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// msgFlags tracks what application data a handshake message carries.
|
||||
type msgFlags struct {
|
||||
expectsPayload bool // message carries indexes and time
|
||||
expectsCert bool // message carries the certificate
|
||||
}
|
||||
|
||||
// subtypeInfo bundles the noise pattern with the per-message flags for a
|
||||
// given handshake subtype.
|
||||
type subtypeInfo struct {
|
||||
pattern noise.HandshakePattern
|
||||
msgs []msgFlags
|
||||
}
|
||||
|
||||
// subtypeInfos defines the noise pattern and message content layout for each
|
||||
// handshake subtype.
|
||||
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
|
||||
// IX: 2 messages, both carry payload and cert
|
||||
header.HandshakeIXPSK0: {
|
||||
pattern: noise.HandshakeIX,
|
||||
msgs: []msgFlags{
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
},
|
||||
},
|
||||
|
||||
// XX: 3 messages
|
||||
// msg1 (I->R): payload only
|
||||
// msg2 (R->I): payload + cert
|
||||
// msg3 (I->R): cert only
|
||||
//header.HandshakeXXPSK0: {
|
||||
// pattern: noise.HandshakeXX,
|
||||
// msgs: []msgFlags{
|
||||
// {expectsPayload: true, expectsCert: false},
|
||||
// {expectsPayload: true, expectsCert: true},
|
||||
// {expectsPayload: false, expectsCert: true},
|
||||
// },
|
||||
//},
|
||||
}
|
||||
|
||||
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
|
||||
if info, ok := subtypeInfos[subtype]; ok {
|
||||
return info, nil
|
||||
}
|
||||
return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
|
||||
}
|
||||
63
handshake/patterns_test.go
Normal file
63
handshake/patterns_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubtypeInfo(t *testing.T) {
|
||||
t.Run("IX", func(t *testing.T) {
|
||||
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
|
||||
require.Len(t, info.msgs, 2)
|
||||
// msg1: payload + cert
|
||||
assert.True(t, info.msgs[0].expectsPayload)
|
||||
assert.True(t, info.msgs[0].expectsCert)
|
||||
// msg2: payload + cert
|
||||
assert.True(t, info.msgs[1].expectsPayload)
|
||||
assert.True(t, info.msgs[1].expectsCert)
|
||||
})
|
||||
|
||||
t.Run("XX", func(t *testing.T) {
|
||||
registerTestXXInfo(t)
|
||||
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
|
||||
require.Len(t, info.msgs, 3)
|
||||
// msg1: payload only
|
||||
assert.True(t, info.msgs[0].expectsPayload)
|
||||
assert.False(t, info.msgs[0].expectsCert)
|
||||
// msg2: payload + cert
|
||||
assert.True(t, info.msgs[1].expectsPayload)
|
||||
assert.True(t, info.msgs[1].expectsCert)
|
||||
// msg3: cert only
|
||||
assert.False(t, info.msgs[2].expectsPayload)
|
||||
assert.True(t, info.msgs[2].expectsCert)
|
||||
})
|
||||
|
||||
t.Run("unknown subtype returns error", func(t *testing.T) {
|
||||
_, err := subtypeInfoFor(99)
|
||||
require.ErrorIs(t, err, ErrUnknownSubtype)
|
||||
})
|
||||
}
|
||||
|
||||
// registerTestXXInfo temporarily registers XX subtype info for testing.
|
||||
func registerTestXXInfo(t *testing.T) {
|
||||
t.Helper()
|
||||
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
|
||||
pattern: noise.HandshakeXX,
|
||||
msgs: []msgFlags{
|
||||
{expectsPayload: true, expectsCert: false},
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
{expectsPayload: false, expectsCert: true},
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
delete(subtypeInfos, header.HandshakeXXPSK0)
|
||||
})
|
||||
}
|
||||
173
handshake/payload.go
Normal file
173
handshake/payload.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidHandshakeMessage = errors.New("invalid handshake message")
|
||||
errInvalidHandshakeDetails = errors.New("invalid handshake details")
|
||||
)
|
||||
|
||||
// Payload represents the decoded fields of a handshake message.
|
||||
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||
type Payload struct {
|
||||
Cert []byte
|
||||
InitiatorIndex uint32
|
||||
ResponderIndex uint32
|
||||
Time uint64
|
||||
CertVersion uint32
|
||||
}
|
||||
|
||||
// Proto field numbers for NebulaHandshakeDetails
|
||||
const (
|
||||
fieldCert = 1 // bytes
|
||||
fieldInitiatorIndex = 2 // uint32
|
||||
fieldResponderIndex = 3 // uint32
|
||||
fieldTime = 5 // uint64
|
||||
fieldCertVersion = 8 // uint32
|
||||
)
|
||||
|
||||
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
|
||||
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||
// Returns out (which may be nil), with the marshalled Payload appended to it.
|
||||
func MarshalPayload(out []byte, p Payload) []byte {
|
||||
var details []byte
|
||||
|
||||
if len(p.Cert) > 0 {
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, p.Cert)
|
||||
}
|
||||
if p.InitiatorIndex != 0 {
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.InitiatorIndex))
|
||||
}
|
||||
if p.ResponderIndex != 0 {
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.ResponderIndex))
|
||||
}
|
||||
if p.Time != 0 {
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, p.Time)
|
||||
}
|
||||
if p.CertVersion != 0 {
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.CertVersion))
|
||||
}
|
||||
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
|
||||
func UnmarshalPayload(b []byte) (Payload, error) {
|
||||
var p Payload
|
||||
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
|
||||
switch {
|
||||
case num == 1 && typ == protowire.BytesType:
|
||||
details, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
if err := unmarshalPayloadDetails(&p, details); err != nil {
|
||||
return p, err
|
||||
}
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func unmarshalPayloadDetails(p *Payload, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
|
||||
// For known field numbers, reject any non-matching wire type as a
|
||||
// hard error rather than silently skipping. The caller will catch
|
||||
// missing-field cases downstream, but a wire-type mismatch on a tag
|
||||
// we know is a peer protocol violation worth flagging here.
|
||||
// Repeated occurrences of a singular field follow proto3 last-wins.
|
||||
switch num {
|
||||
case fieldCert:
|
||||
if typ != protowire.BytesType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.Cert = append([]byte(nil), v...)
|
||||
b = b[n:]
|
||||
case fieldInitiatorIndex:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.InitiatorIndex = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldResponderIndex:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.ResponderIndex = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldTime:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.Time = v
|
||||
b = b[n:]
|
||||
case fieldCertVersion:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.CertVersion = uint32(v)
|
||||
b = b[n:]
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
361
handshake/payload_test.go
Normal file
361
handshake/payload_test.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
func TestPayloadRoundTrip(t *testing.T) {
|
||||
t.Run("all fields set", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: []byte("test-cert-bytes"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 12345,
|
||||
ResponderIndex: 67890,
|
||||
Time: 1234567890,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
|
||||
assert.Equal(t, uint32(12345), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(67890), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(1234567890), got.Time)
|
||||
assert.Equal(t, uint32(2), got.CertVersion)
|
||||
})
|
||||
|
||||
t.Run("minimal fields", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(1), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(0), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(0), got.Time)
|
||||
assert.Nil(t, got.Cert)
|
||||
})
|
||||
|
||||
t.Run("empty payload", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("large cert bytes", func(t *testing.T) {
|
||||
bigCert := make([]byte, 4096)
|
||||
for i := range bigCert {
|
||||
bigCert[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: bigCert,
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 999,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, bigCert, got.Cert)
|
||||
assert.Equal(t, uint32(999), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("append to existing buffer", func(t *testing.T) {
|
||||
prefix := []byte("prefix")
|
||||
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
|
||||
|
||||
assert.Equal(t, []byte("prefix"), data[:6])
|
||||
|
||||
got, err := UnmarshalPayload(data[6:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadUnknownFields(t *testing.T) {
|
||||
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
|
||||
// Marshal a normal payload then append an unknown field (field 99, varint)
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
data = protowire.AppendVarint(data, 12345)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("unknown field in details is skipped", func(t *testing.T) {
|
||||
// Build details with a known field + unknown field
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 77)
|
||||
// Unknown field 50, varint
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 9999)
|
||||
// Another known field after the unknown one
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 88)
|
||||
|
||||
// Wrap in outer message
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(77), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(88), got.ResponderIndex)
|
||||
})
|
||||
|
||||
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
|
||||
// Fields 6 and 7 are reserved in the proto definition
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 100)
|
||||
details = protowire.AppendTag(details, 6, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, 7, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 2)
|
||||
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(100), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadBytesConsumed(t *testing.T) {
|
||||
t.Run("all bytes consumed on valid input", func(t *testing.T) {
|
||||
original := Payload{
|
||||
Cert: []byte("cert"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 100,
|
||||
ResponderIndex: 200,
|
||||
Time: 999,
|
||||
}
|
||||
data := MarshalPayload(nil, original)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-marshal and compare — proves we consumed and reproduced all fields
|
||||
remarshaled := MarshalPayload(nil, got)
|
||||
assert.Equal(t, data, remarshaled)
|
||||
})
|
||||
}
|
||||
|
||||
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
|
||||
// so UnmarshalPayload can reach unmarshalPayloadDetails.
|
||||
func wrapDetails(details []byte) []byte {
|
||||
var out []byte
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestPayloadUnmarshalErrors(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
got, err := UnmarshalPayload(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("truncated outer tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x80})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer details field", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer unknown field", func(t *testing.T) {
|
||||
// Valid tag for unknown field 99 varint, but no value follows
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
_, err := UnmarshalPayload(data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated details tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert bytes", func(t *testing.T) {
|
||||
// Field 1 (cert), bytes type, length 10 but only 2 bytes
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated initiator index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated responder index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated time varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert version varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated unknown field in details", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldCert as Varint instead of Bytes.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldInitiatorIndex as Bytes instead of Varint.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("time with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
|
||||
// Per proto3, multiple instances of a singular field are accepted and
|
||||
// the last value wins. We keep this behavior so that peers using
|
||||
// alternative encoders aren't rejected.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
got, err := UnmarshalPayload(wrapDetails(details))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
|
||||
// never panics, and for any input that parses cleanly, that re-marshal +
|
||||
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
|
||||
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
|
||||
// not "unauthenticated injection."
|
||||
func FuzzPayload(f *testing.F) {
|
||||
// Seed corpus with a handful of known-good shapes.
|
||||
f.Add(MarshalPayload(nil, Payload{}))
|
||||
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
|
||||
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
|
||||
f.Add(MarshalPayload(nil, Payload{
|
||||
Cert: []byte("seed-cert"),
|
||||
InitiatorIndex: 1,
|
||||
ResponderIndex: 2,
|
||||
Time: 3,
|
||||
CertVersion: 2,
|
||||
}))
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0xff})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
p1, err := UnmarshalPayload(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For any input that parses, re-marshaling and re-parsing must
|
||||
// yield an equivalent Payload. This catches dispatch bugs (e.g.
|
||||
// emitting a field on marshal that we don't accept on parse) and
|
||||
// any non-idempotent parsing behavior.
|
||||
b2 := MarshalPayload(nil, p1)
|
||||
p2, err := UnmarshalPayload(b2)
|
||||
if err != nil {
|
||||
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
|
||||
}
|
||||
if !payloadsEqual(p1, p2) {
|
||||
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func payloadsEqual(a, b Payload) bool {
|
||||
return bytes.Equal(a.Cert, b.Cert) &&
|
||||
a.InitiatorIndex == b.InitiatorIndex &&
|
||||
a.ResponderIndex == b.ResponderIndex &&
|
||||
a.Time == b.Time &&
|
||||
a.CertVersion == b.CertVersion
|
||||
}
|
||||
Reference in New Issue
Block a user