Merge remote-tracking branch 'origin/master' into multiport

This commit is contained in:
Wade Simmons
2026-05-06 14:26:49 -04:00
138 changed files with 10562 additions and 4541 deletions

57
handshake/credential.go Normal file
View File

@@ -0,0 +1,57 @@
package handshake
import (
"crypto/rand"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
)
// Credential holds everything needed to participate in a handshake
// at a given cert version. Version and Curve are read from Cert; the public
// half of the static keypair likewise comes from Cert.PublicKey().
type Credential struct {
Cert cert.Certificate // the certificate
Bytes []byte // pre-marshaled certificate bytes
privateKey []byte // static private key (public half lives in Cert)
cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash)
}
// NewCredential creates a Credential with all material needed for handshake
// participation. The cipherSuite should be pre-built by the caller with the
// appropriate DH function, cipher, and hash.
func NewCredential(
c cert.Certificate,
hsBytes []byte,
privateKey []byte,
cipherSuite noise.CipherSuite,
) *Credential {
return &Credential{
Cert: c,
Bytes: hsBytes,
privateKey: privateKey,
cipherSuite: cipherSuite,
}
}
// buildHandshakeState creates a noise.HandshakeState from this credential.
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
return noise.NewHandshakeState(noise.Config{
CipherSuite: hc.cipherSuite,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
}
// GetCredentialFunc returns the handshake credential for the given version,
// or nil if that version is not available.
//
// Implementations must return credentials drawn from a snapshot stable for
// the lifetime of any single Machine. The Machine may call this multiple
// times during a handshake (e.g. when negotiating to the peer's version)
// and assumes the underlying static keypair is consistent across calls.
type GetCredentialFunc func(v cert.Version) *Credential

21
handshake/errors.go Normal file
View File

@@ -0,0 +1,21 @@
package handshake
import "errors"
var (
ErrInitiateOnResponder = errors.New("initiate called on responder")
ErrInitiateAlreadyCalled = errors.New("initiate already called")
ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators")
ErrPacketTooShort = errors.New("packet too short")
ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake")
ErrIncompleteHandshake = errors.New("handshake completed without receiving required content")
ErrMachineFailed = errors.New("handshake machine has failed")
ErrUnknownSubtype = errors.New("unknown handshake subtype")
ErrMissingContent = errors.New("expected handshake content but message was empty")
ErrUnexpectedContent = errors.New("received unexpected handshake content")
ErrIndexAllocation = errors.New("failed to allocate local index")
ErrNoCredential = errors.New("no handshake credential available for cert version")
ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key")
ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype")
)

37
handshake/handshake.proto Normal file
View File

@@ -0,0 +1,37 @@
// This file documents the wire format the nebula handshake speaks. It is
// not run through protoc; the encoder/decoder in payload.go is hand-written
// against this shape directly to keep the parser narrow and panic-free.
//
// Any change to the wire format must be reflected here, and adding a new
// field requires updating MarshalPayload / unmarshalPayloadDetails together
// with the field-uniqueness and wire-type checks in those functions.
syntax = "proto3";
package nebula.handshake;
message NebulaHandshake {
NebulaHandshakeDetails Details = 1;
bytes Hmac = 2;
}
message NebulaHandshakeDetails {
bytes Cert = 1;
uint32 InitiatorIndex = 2;
uint32 ResponderIndex = 3;
// Cookie was reserved for an anti-DoS mechanism that was never
// implemented. No released version of nebula has ever populated it; the
// hand-written parser silently skips it on read.
uint64 Cookie = 4 [deprecated = true];
uint64 Time = 5;
uint32 CertVersion = 8;
MultiPortDetails InitiatorMultiPort = 6;
MultiPortDetails ResponderMultiPort = 7;
}
message MultiPortDetails {
bool RxSupported = 1;
bool TxSupported = 2;
uint32 BasePort = 3;
uint32 TotalPorts = 4;
}

116
handshake/helpers_test.go Normal file
View File

@@ -0,0 +1,116 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/require"
)
// testCertState holds cert material for a test peer.
type testCertState struct {
version cert.Version
creds map[cert.Version]*Credential
}
func (s *testCertState) getCredential(v cert.Version) *Credential {
return s.creds[v]
}
func newTestCertState(
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
) *testCertState {
return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
}
func newTestCertStateWithCipher(
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
cipher noise.CipherFunc,
) *testCertState {
t.Helper()
c, _, rawPrivKey, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
)
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey)
require.NoError(t, err)
hsBytes, err := c.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
return &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(c, hsBytes, priv, ncs),
},
}
}
func testVerifier(pool *cert.CAPool) CertVerifier {
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
return pool.VerifyCertificate(time.Now(), c)
}
}
func newTestMachine(
t *testing.T,
cs *testCertState,
verifier CertVerifier,
initiator bool,
localIndex uint32,
) *Machine {
t.Helper()
m, err := NewMachine(
cs.version, cs.getCredential,
verifier, func() (uint32, error) { return localIndex, nil },
initiator, header.HandshakeIXPSK0,
)
require.NoError(t, err)
return m
}
func initiateHandshake(
t *testing.T,
initCS *testCertState, initVerifier CertVerifier,
respCS *testCertState, respVerifier CertVerifier,
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
t.Helper()
initM = newTestMachine(t, initCS, initVerifier, true, 100)
msg1, merr := initM.Initiate(nil)
require.NoError(t, merr)
respM = newTestMachine(t, respCS, respVerifier, false, 200)
resp, respResult, err = respM.ProcessPacket(nil, msg1)
return
}
func doFullHandshake(
t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
) (initResult, respResult *Result) {
t.Helper()
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 1000)
respM := newTestMachine(t, respCS, v, false, 2000)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp, respResult, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, respResult)
require.NotEmpty(t, resp)
_, initResult, err = initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initResult)
return initResult, respResult
}

444
handshake/machine.go Normal file
View File

@@ -0,0 +1,444 @@
package handshake
import (
"bytes"
"fmt"
"slices"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
// IndexAllocator is called by the Machine to allocate a local index for the
// handshake. It is called at most once, when the first outgoing message that
// carries a payload is built.
//
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
// "no index assigned" on the wire and in the payload-presence checks. If an
// allocator ever returned 0, a legitimate handshake's payload could be
// indistinguishable from an empty one and would be rejected.
type IndexAllocator func() (uint32, error)
// CertVerifier is called by the Machine after reconstructing the peer's
// certificate from the handshake. The verifier performs all validation
// (CA trust, expiry, policy checks, allow lists).
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
// Result contains the results of a successful handshake.
// Returned by ProcessPacket when the handshake is complete.
type Result struct {
EKey *noise.CipherState
DKey *noise.CipherState
MyCert cert.Certificate
RemoteCert *cert.CachedCertificate
RemoteIndex uint32
LocalIndex uint32
HandshakeTime uint64
MessageIndex uint64 // number of messages exchanged during the handshake
Initiator bool
}
// Machine drives a Noise handshake through N messages. It handles Noise
// protocol operations, certificate reconstruction, and payload encoding.
// Certificate validation is delegated to the caller via CertVerifier.
//
// A Machine is not safe for concurrent use. The caller must ensure that
// Initiate and ProcessPacket are not called concurrently.
//
// Error contract: when ProcessPacket or Initiate returns an error, callers
// must check Failed() to decide what to do next. If Failed() is false the
// underlying noise state was not advanced (the packet was rejected before
// ReadMessage took effect, or the rejection is non-fatal like a stale
// retransmit) and the Machine can accept another packet. If Failed() is
// true the Machine is unrecoverable and the caller must abandon it.
type Machine struct {
hs *noise.HandshakeState
getCred GetCredentialFunc
allocIndex IndexAllocator
verifier CertVerifier
result *Result
msgs []msgFlags
myVersion cert.Version
subtype header.MessageSubType
indexAllocated bool
remoteCertSet bool
payloadSet bool
failed bool
}
// NewMachine creates a handshake state machine. The subtype determines both
// the noise pattern and the per-message content layout. The credential for
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
// IndexAllocator is called lazily when the first outgoing payload is built.
func NewMachine(
version cert.Version,
getCred GetCredentialFunc,
verifier CertVerifier,
allocIndex IndexAllocator,
initiator bool,
subtype header.MessageSubType,
) (*Machine, error) {
info, err := subtypeInfoFor(subtype)
if err != nil {
return nil, err
}
cred := getCred(version)
if cred == nil {
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
}
hs, err := cred.buildHandshakeState(initiator, info.pattern)
if err != nil {
return nil, fmt.Errorf("build noise state: %w", err)
}
return &Machine{
hs: hs,
subtype: subtype,
msgs: info.msgs,
getCred: getCred,
allocIndex: allocIndex,
verifier: verifier,
myVersion: version,
result: &Result{
Initiator: initiator,
},
}, nil
}
// Failed returns true if the Machine is in an unrecoverable state.
func (m *Machine) Failed() bool {
return m.failed
}
// Subtype returns the handshake subtype this Machine was built for.
func (m *Machine) Subtype() header.MessageSubType {
return m.subtype
}
// MessageIndex returns the noise handshake message index, which equals the
// wire counter of the most recently sent or received message.
func (m *Machine) MessageIndex() int {
return m.hs.MessageIndex()
}
// requireComplete checks that both a peer cert and payload have been received.
// Marks the machine as failed if not.
func (m *Machine) requireComplete() error {
if !m.payloadSet || !m.remoteCertSet {
m.failed = true
return ErrIncompleteHandshake
}
return nil
}
// myMsgFlags returns the flags for the current outgoing message.
func (m *Machine) myMsgFlags() msgFlags {
idx := m.hs.MessageIndex()
if idx < len(m.msgs) {
return m.msgs[idx]
}
return msgFlags{}
}
// peerMsgFlags returns the flags for the message we just read.
func (m *Machine) peerMsgFlags() msgFlags {
idx := m.hs.MessageIndex() - 1
if idx >= 0 && idx < len(m.msgs) {
return m.msgs[idx]
}
return msgFlags{}
}
// Initiate produces the first handshake message. Only valid for initiators,
// and must be called exactly once before ProcessPacket.
//
// out is a destination buffer the message is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation.
//
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) Initiate(out []byte) ([]byte, error) {
if m.failed {
return nil, ErrMachineFailed
}
if !m.result.Initiator {
m.failed = true
return nil, ErrInitiateOnResponder
}
if m.hs.MessageIndex() != 0 {
m.failed = true
return nil, ErrInitiateAlreadyCalled
}
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
// header counter 1 and remote index 0, which is what the initial message needs.
out, _, _, err := m.buildResponse(out)
if err != nil {
m.failed = true
return nil, err
}
return out, nil
}
// ProcessPacket handles an incoming handshake message. It advances the Noise
// state, validates the peer certificate via the verifier, and optionally
// produces a response.
//
// out is a destination buffer the response is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
// is nil when no outgoing message is produced (handshake complete on this
// side, or final message of a multi-message pattern).
//
// Returns a non-nil Result when the handshake is complete.
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
if m.failed {
return nil, nil, ErrMachineFailed
}
if len(packet) < header.Len {
return nil, nil, ErrPacketTooShort
}
// Reject packets whose subtype doesn't match the one this Machine was
// built for. A pending handshake that suddenly receives a different
// subtype on its index is either a stray packet that matched by chance
// or a peer protocol violation; drop it without failing the Machine so
// the legitimate retransmit can still complete.
if header.MessageSubType(packet[1]) != m.subtype {
return nil, nil, ErrSubtypeMismatch
}
if m.result.Initiator && m.hs.MessageIndex() == 0 {
m.failed = true
return nil, nil, ErrInitiateNotCalled
}
// The (eKey, dKey) ordering here is correct for IX, where the initiator
// completes the handshake by reading the responder's stage-2 message.
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
// For 3-message patterns where a responder finishes by reading the final
// message, this ordering would be wrong; revisit when XX/pqIX lands.
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Noise ReadMessage failed. The noise library checkpoints and rolls back
// on failure, so the Machine is still alive. The caller can retry with
// a different packet.
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
}
// From here on, noise state has advanced. Any error is fatal.
flags := m.peerMsgFlags()
if err := m.processPayload(msg, flags); err != nil {
return nil, nil, err
}
// If ReadMessage derived keys, the handshake is complete. Noise should
// always produce both keys together; asymmetry is a protocol invariant
// violation.
if eKey != nil || dKey != nil {
if eKey == nil || dKey == nil {
m.failed = true
return nil, nil, ErrAsymmetricCipherKeys
}
if err := m.requireComplete(); err != nil {
return nil, nil, err
}
return nil, m.completed(eKey, dKey), nil
}
// ReadMessage didn't complete, produce the next outgoing message
out, dk, ek, err := m.buildResponse(out)
if err != nil {
m.failed = true
return nil, nil, err
}
if ek != nil || dk != nil {
if ek == nil || dk == nil {
m.failed = true
return nil, nil, ErrAsymmetricCipherKeys
}
if err := m.requireComplete(); err != nil {
return nil, nil, err
}
return out, m.completed(ek, dk), nil
}
return out, nil, nil
}
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
m.result.EKey = eKey
m.result.DKey = dKey
m.result.MessageIndex = uint64(m.hs.MessageIndex())
return m.result
}
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
if len(msg) == 0 {
if flags.expectsPayload || flags.expectsCert {
m.failed = true
return ErrMissingContent
}
return nil
}
payload, err := UnmarshalPayload(msg)
if err != nil {
m.failed = true
return fmt.Errorf("unmarshal handshake: %w", err)
}
// Assert the payload contains exactly what we expect
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
if hasPayloadData != flags.expectsPayload {
m.failed = true
return ErrUnexpectedContent
}
hasCertData := len(payload.Cert) > 0
if hasCertData != flags.expectsCert {
m.failed = true
return ErrUnexpectedContent
}
// Process payload
if flags.expectsPayload {
if m.result.Initiator {
m.result.RemoteIndex = payload.ResponderIndex
} else {
m.result.RemoteIndex = payload.InitiatorIndex
}
m.result.HandshakeTime = payload.Time
m.payloadSet = true
}
// Process certificate
if flags.expectsCert {
if err := m.validateCert(payload); err != nil {
return err
}
}
return nil
}
func (m *Machine) validateCert(payload Payload) error {
cred := m.getCred(m.myVersion)
if cred == nil {
m.failed = true
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
}
rc, err := cert.Recombine(
cert.Version(payload.CertVersion),
payload.Cert,
m.hs.PeerStatic(),
cred.Cert.Curve(),
)
if err != nil {
m.failed = true
return fmt.Errorf("recombine cert: %w", err)
}
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
m.failed = true
return ErrPublicKeyMismatch
}
// Version negotiation, if the peer sent a different version and we have it, switch
if rc.Version() != m.myVersion {
if m.getCred(rc.Version()) != nil {
m.myVersion = rc.Version()
}
}
verified, err := m.verifier(rc)
if err != nil {
m.failed = true
return fmt.Errorf("verify cert: %w", err)
}
m.result.RemoteCert = verified
m.remoteCertSet = true
return nil
}
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
if !flags.expectsPayload && !flags.expectsCert {
return nil, nil
}
var p Payload
if flags.expectsPayload {
if !m.indexAllocated {
index, err := m.allocIndex()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
}
m.result.LocalIndex = index
m.indexAllocated = true
}
if m.result.Initiator {
p.InitiatorIndex = m.result.LocalIndex
} else {
p.ResponderIndex = m.result.LocalIndex
p.InitiatorIndex = m.result.RemoteIndex
}
p.Time = uint64(time.Now().UnixNano())
}
if flags.expectsCert {
cred := m.getCred(m.myVersion)
if cred == nil {
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
}
p.Cert = cred.Bytes
p.CertVersion = uint32(cred.Cert.Version())
m.result.MyCert = cred.Cert
}
return MarshalPayload(nil, p), nil
}
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
flags := m.myMsgFlags()
hsBytes, err := m.marshalOutgoing(flags)
if err != nil {
return nil, nil, nil, err
}
// Extend out by header.Len to make room for the header. slices.Grow is a
// no-op when the cap is already sufficient (the zero-copy case where the
// caller passed a pre-sized buffer). header.Encode overwrites the new
// bytes, so they don't need to be zeroed.
start := len(out)
out = slices.Grow(out, header.Len)[:start+header.Len]
header.Encode(
out[start:],
header.Version, header.Handshake, m.subtype,
m.result.RemoteIndex,
uint64(m.hs.MessageIndex()+1),
)
// noise.WriteMessage appends the encrypted handshake message to out,
// reusing capacity when present.
//
// The (dKey, eKey) ordering here is correct for IX, where the responder
// completes the handshake by writing the stage-2 message. noise returns
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
// responder's decrypt key). For 3-message patterns where an initiator
// finishes by writing the final message, this ordering would be wrong;
// revisit when XX/pqIX lands.
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
if err != nil {
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
}
return out, dKey, eKey, nil
}

662
handshake/machine_test.go Normal file
View File

@@ -0,0 +1,662 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/noiseutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMachineIXHappyPath(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
assert.Equal(t, uint32(1000), initR.LocalIndex)
assert.Equal(t, uint32(2000), initR.RemoteIndex)
assert.Equal(t, uint32(2000), respR.LocalIndex)
assert.Equal(t, uint32(1000), respR.RemoteIndex)
assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
require.NoError(t, err)
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("hello"), pt1)
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
require.NoError(t, err)
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
require.NoError(t, err)
assert.Equal(t, []byte("world"), pt2)
}
func TestMachineInitiateErrors(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("initiate on responder", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, err := m.Initiate(nil)
require.ErrorIs(t, err, ErrInitiateOnResponder)
assert.True(t, m.Failed())
})
t.Run("initiate called twice", func(t *testing.T) {
m := newTestMachine(t, cs, v, true, 100)
_, err := m.Initiate(nil)
require.NoError(t, err)
_, err = m.Initiate(nil)
require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
assert.True(t, m.Failed())
})
t.Run("process packet before initiate on initiator", func(t *testing.T) {
m := newTestMachine(t, cs, v, true, 100)
_, _, err := m.ProcessPacket(nil, make([]byte, 100))
require.ErrorIs(t, err, ErrInitiateNotCalled)
assert.True(t, m.Failed())
})
t.Run("calling failed machine", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, err := m.Initiate(nil) // fails: responder
require.Error(t, err)
_, err = m.Initiate(nil) // fails: already failed
require.ErrorIs(t, err, ErrMachineFailed)
})
}
func TestMachineProcessPacketErrors(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("packet too short", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
require.ErrorIs(t, err, ErrPacketTooShort)
assert.False(t, m.Failed(), "short packet should not kill machine")
})
t.Run("noise decryption failure is recoverable", func(t *testing.T) {
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
initM := newTestMachine(t, initCS, v, true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
respM := newTestMachine(t, cs, v, false, 200)
resp, _, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
corrupted := make([]byte, len(resp))
copy(corrupted, resp)
for i := header.Len; i < len(corrupted); i++ {
corrupted[i] ^= 0xff
}
_, _, err = initM.ProcessPacket(nil, corrupted)
require.Error(t, err)
assert.False(t, initM.Failed(), "noise failure should be recoverable")
// And the machine should still complete a real handshake afterward.
_, result, err := initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, result, "initiator should complete on the legitimate response")
})
t.Run("invalid cert is fatal", func(t *testing.T) {
otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
respM := newTestMachine(t, cs, v, false, 200)
_, _, err = respM.ProcessPacket(nil, msg1)
require.Error(t, err)
assert.True(t, respM.Failed(), "cert validation failure should kill machine")
})
t.Run("subtype mismatch is recoverable", func(t *testing.T) {
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
initM := newTestMachine(t, initCS, v, true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
// Mutate the subtype byte (offset 1 in the header) to a value the
// responder Machine wasn't built for.
bad := make([]byte, len(msg1))
copy(bad, msg1)
bad[1] = 0xff
respM := newTestMachine(t, cs, v, false, 200)
_, _, err = respM.ProcessPacket(nil, bad)
require.ErrorIs(t, err, ErrSubtypeMismatch)
assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
// And the machine should still complete a real handshake afterward.
resp, result, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
})
}
// TestMachineProcessPayload exercises processPayload's internal validation
// directly. Most of these failure modes can't be reached black-box once the
// subtype check at the top of ProcessPacket gates external callers, so we
// drive them by hand here for coverage.
func TestMachineProcessPayload(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("empty message with expects fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
require.ErrorIs(t, err, ErrMissingContent)
assert.True(t, m.Failed())
})
t.Run("empty message with no expects passes", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload(nil, msgFlags{})
require.NoError(t, err)
assert.False(t, m.Failed())
})
t.Run("malformed protobuf is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
require.Error(t, err)
assert.True(t, m.Failed())
})
t.Run("unexpected payload data is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// A payload with index data when none was expected.
bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
t.Run("unexpected cert data is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// A payload with cert when none was expected.
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
t.Run("missing payload data when expected is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// Cert present, but no index/time fields.
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
}
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
// directly. Like processPayload above this isn't reachable from a normal IX
// flow, so we drive it by hand.
func TestMachineRequireComplete(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("missing both fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("payload only fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.payloadSet = true
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("cert only fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.remoteCertSet = true
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("both set passes", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.payloadSet = true
m.remoteCertSet = true
err := m.requireComplete()
require.NoError(t, err)
assert.False(t, m.Failed())
})
}
func TestMachineAESCipher(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertStateWithCipher(
t, ca, caKey, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
noiseutil.CipherAESGCM,
)
respCS := newTestCertStateWithCipher(
t, ca, caKey, "resp",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
noiseutil.CipherAESGCM,
)
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
require.NoError(t, err)
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("works"), pt1)
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
require.NoError(t, err)
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
require.NoError(t, err)
assert.Equal(t, []byte("back"), pt2)
}
func TestResultFields(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
assert.True(t, initR.Initiator)
assert.False(t, respR.Initiator)
assert.NotZero(t, initR.HandshakeTime)
assert.NotZero(t, respR.HandshakeTime)
assert.NotNil(t, initR.RemoteCert)
assert.NotNil(t, respR.RemoteCert)
}
func TestMachineBufferReuse(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 1000)
respM := newTestMachine(t, respCS, v, false, 2000)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
t.Run("response writes into provided buffer", func(t *testing.T) {
buf := make([]byte, 0, 4096)
resp, result, err := respM.ProcessPacket(buf, msg1)
require.NoError(t, err)
require.NotNil(t, result)
assert.NotEmpty(t, resp, "response should have content")
assert.Equal(t, &buf[:1][0], &resp[:1][0],
"response should reuse the provided buffer's backing array")
})
t.Run("initiate writes into provided buffer", func(t *testing.T) {
initM2 := newTestMachine(t, initCS, v, true, 3000)
buf := make([]byte, 0, 4096)
msg, err := initM2.Initiate(buf)
require.NoError(t, err)
assert.NotEmpty(t, msg, "initiate should have content")
assert.Equal(t, &buf[:1][0], &msg[:1][0],
"initiate should reuse the provided buffer's backing array")
})
t.Run("nil out still works", func(t *testing.T) {
initM2 := newTestMachine(t, initCS, v, true, 4000)
respM2 := newTestMachine(t, respCS, v, false, 5000)
msg1, err := initM2.Initiate(nil)
require.NoError(t, err)
resp, _, err := respM2.ProcessPacket(nil, msg1)
require.NoError(t, err)
out, result, err := initM2.ProcessPacket(nil, resp)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Nil(t, out, "initiator should have no response for IX msg2")
})
}
func TestMachineMsgIndexTracking(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 100)
respM := newTestMachine(t, respCS, v, false, 200)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp1, result1, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
assert.NotNil(t, result1)
_, result2, err := initM.ProcessPacket(nil, resp1)
require.NoError(t, err)
assert.NotNil(t, result2)
}
func TestMachineThreeMessagePattern(t *testing.T) {
registerTestXXInfo(t)
// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
// patterns correctly. XX flow:
// msg1 (I->R): [E] - payload only, no cert
// msg2 (R->I): [E, ee, S, es] - payload + cert
// msg3 (I->R): [S, se] - cert only (no payload, not first two)
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
v := testVerifier(caPool)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM, err := NewMachine(
cert.Version2,
initCS.getCredential, v,
func() (uint32, error) { return 1000, nil },
true, header.HandshakeXXPSK0,
)
require.NoError(t, err)
respM, err := NewMachine(
cert.Version2,
respCS.getCredential, v,
func() (uint32, error) { return 2000, nil },
false, header.HandshakeXXPSK0,
)
require.NoError(t, err)
// msg1: initiator -> responder (E only, no cert)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
assert.NotEmpty(t, msg1)
// Responder processes msg1, should not complete yet, should produce msg2
msg2, result, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
assert.Nil(t, result, "XX should not complete on msg1")
assert.NotEmpty(t, msg2, "responder should produce msg2")
// Initiator processes msg2: gets responder's cert, produces msg3, and
// completes (WriteMessage for msg3 derives keys)
msg3, initResult, err := initM.ProcessPacket(nil, msg2)
require.NoError(t, err)
require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
assert.NotEmpty(t, msg3, "initiator should produce msg3")
assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
// Responder processes msg3: gets initiator's cert and completes
_, respResult, err := respM.ProcessPacket(nil, msg3)
require.NoError(t, err)
require.NotNil(t, respResult, "XX responder should complete on msg3")
assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
// Verify keys work
ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
require.NoError(t, err)
pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("three messages"), pt1)
}
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
// should exercise the case where cert arrives in msg3 and verify that completing
// without it fails.
func TestMachineExpiredCert(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519,
time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
expCert, _, expKeyPEM, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
)
expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
require.NoError(t, err)
expHsBytes, err := expCert.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
expiredCS := &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
},
}
respCS := newTestCertState(
t, ca, caKey, "responder",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, expiredCS, testVerifier(caPool),
respCS, testVerifier(caPool),
)
require.ErrorContains(t, err, "verify cert")
assert.True(t, respM.Failed())
}
func TestMachineNoCertNetworks(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
caHsBytes, err := ca.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
noNetCS := &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
},
}
respCS := newTestCertState(
t, ca, caKey, "responder",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, noNetCS, testVerifier(caPool),
respCS, testVerifier(caPool),
)
require.Error(t, err)
assert.True(t, respM.Failed())
}
func TestMachineDifferentCAs(t *testing.T) {
ca1, _, caKey1, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
ca2, _, caKey2, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
initCS := newTestCertState(
t, ca1, caKey1, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCS := newTestCertState(
t, ca2, caKey2, "resp",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
respCS, testVerifier(ct.NewTestCAPool(ca2)),
)
require.ErrorContains(t, err, "verify cert")
assert.True(t, respM.Failed())
}
func TestMachineVersionNegotiation(t *testing.T) {
ca1, _, caKey1, _ := ct.NewTestCaCert(
cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
ca2, _, caKey2, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca1, ca2)
makeMultiVersionResp := func(t *testing.T) *testCertState {
t.Helper()
respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
ca1.NotBefore(), ca1.NotAfter(),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
)
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
respHsV1, _ := respCertV1.MarshalForHandshakes()
respHsV2, _ := respCertV2.MarshalForHandshakes()
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
return &testCertState{
version: cert.Version1,
creds: map[cert.Version]*Credential{
cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
},
}
}
t.Run("responder matches initiator version", func(t *testing.T) {
initCS := newTestCertState(
t, ca2, caKey2, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCS := makeMultiVersionResp(t)
v := testVerifier(caPool)
initM, _, respResult, resp, err := initiateHandshake(
t, initCS, v,
respCS, v,
)
require.NoError(t, err)
require.NotNil(t, respResult)
assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
"responder should negotiate to initiator's version")
_, initResult, err := initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initResult)
assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
"initiator should see V2 cert from responder")
})
t.Run("responder keeps version when no match available", func(t *testing.T) {
initCS := newTestCertState(
t, ca2, caKey2, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCert, _, respKeyPEM, _ := ct.NewTestCert(
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
ca1.NotBefore(), ca1.NotAfter(),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
)
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
respHs, _ := respCert.MarshalForHandshakes()
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
respCS := &testCertState{
version: cert.Version1,
creds: map[cert.Version]*Credential{
cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
},
}
v := testVerifier(caPool)
_, _, respResult, _, err := initiateHandshake(
t, initCS, v,
respCS, v,
)
require.NoError(t, err)
require.NotNil(t, respResult)
assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
"responder should keep V1 when V2 not available")
})
}

54
handshake/patterns.go Normal file
View File

@@ -0,0 +1,54 @@
package handshake
import (
"fmt"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
)
// msgFlags tracks what application data a handshake message carries.
type msgFlags struct {
expectsPayload bool // message carries indexes and time
expectsCert bool // message carries the certificate
}
// subtypeInfo bundles the noise pattern with the per-message flags for a
// given handshake subtype.
type subtypeInfo struct {
pattern noise.HandshakePattern
msgs []msgFlags
}
// subtypeInfos defines the noise pattern and message content layout for each
// handshake subtype.
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
// IX: 2 messages, both carry payload and cert
header.HandshakeIXPSK0: {
pattern: noise.HandshakeIX,
msgs: []msgFlags{
{expectsPayload: true, expectsCert: true},
{expectsPayload: true, expectsCert: true},
},
},
// XX: 3 messages
// msg1 (I->R): payload only
// msg2 (R->I): payload + cert
// msg3 (I->R): cert only
//header.HandshakeXXPSK0: {
// pattern: noise.HandshakeXX,
// msgs: []msgFlags{
// {expectsPayload: true, expectsCert: false},
// {expectsPayload: true, expectsCert: true},
// {expectsPayload: false, expectsCert: true},
// },
//},
}
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
if info, ok := subtypeInfos[subtype]; ok {
return info, nil
}
return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
}

View File

@@ -0,0 +1,63 @@
package handshake
import (
"testing"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSubtypeInfo(t *testing.T) {
t.Run("IX", func(t *testing.T) {
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
require.Len(t, info.msgs, 2)
// msg1: payload + cert
assert.True(t, info.msgs[0].expectsPayload)
assert.True(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
})
t.Run("XX", func(t *testing.T) {
registerTestXXInfo(t)
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
require.Len(t, info.msgs, 3)
// msg1: payload only
assert.True(t, info.msgs[0].expectsPayload)
assert.False(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
// msg3: cert only
assert.False(t, info.msgs[2].expectsPayload)
assert.True(t, info.msgs[2].expectsCert)
})
t.Run("unknown subtype returns error", func(t *testing.T) {
_, err := subtypeInfoFor(99)
require.ErrorIs(t, err, ErrUnknownSubtype)
})
}
// registerTestXXInfo temporarily registers XX subtype info for testing.
func registerTestXXInfo(t *testing.T) {
t.Helper()
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
pattern: noise.HandshakeXX,
msgs: []msgFlags{
{expectsPayload: true, expectsCert: false},
{expectsPayload: true, expectsCert: true},
{expectsPayload: false, expectsCert: true},
},
}
t.Cleanup(func() {
delete(subtypeInfos, header.HandshakeXXPSK0)
})
}

173
handshake/payload.go Normal file
View File

@@ -0,0 +1,173 @@
package handshake
import (
"errors"
"math"
"google.golang.org/protobuf/encoding/protowire"
)
var (
errInvalidHandshakeMessage = errors.New("invalid handshake message")
errInvalidHandshakeDetails = errors.New("invalid handshake details")
)
// Payload represents the decoded fields of a handshake message.
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
type Payload struct {
Cert []byte
InitiatorIndex uint32
ResponderIndex uint32
Time uint64
CertVersion uint32
}
// Proto field numbers for NebulaHandshakeDetails
const (
fieldCert = 1 // bytes
fieldInitiatorIndex = 2 // uint32
fieldResponderIndex = 3 // uint32
fieldTime = 5 // uint64
fieldCertVersion = 8 // uint32
)
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
// Returns out (which may be nil), with the marshalled Payload appended to it.
func MarshalPayload(out []byte, p Payload) []byte {
var details []byte
if len(p.Cert) > 0 {
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
details = protowire.AppendBytes(details, p.Cert)
}
if p.InitiatorIndex != 0 {
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.InitiatorIndex))
}
if p.ResponderIndex != 0 {
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.ResponderIndex))
}
if p.Time != 0 {
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
details = protowire.AppendVarint(details, p.Time)
}
if p.CertVersion != 0 {
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.CertVersion))
}
out = protowire.AppendTag(out, 1, protowire.BytesType)
out = protowire.AppendBytes(out, details)
return out
}
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
func UnmarshalPayload(b []byte) (Payload, error) {
var p Payload
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
switch {
case num == 1 && typ == protowire.BytesType:
details, n := protowire.ConsumeBytes(b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
if err := unmarshalPayloadDetails(&p, details); err != nil {
return p, err
}
default:
n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
}
}
return p, nil
}
func unmarshalPayloadDetails(p *Payload, b []byte) error {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
// For known field numbers, reject any non-matching wire type as a
// hard error rather than silently skipping. The caller will catch
// missing-field cases downstream, but a wire-type mismatch on a tag
// we know is a peer protocol violation worth flagging here.
// Repeated occurrences of a singular field follow proto3 last-wins.
switch num {
case fieldCert:
if typ != protowire.BytesType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return errInvalidHandshakeDetails
}
p.Cert = append([]byte(nil), v...)
b = b[n:]
case fieldInitiatorIndex:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.InitiatorIndex = uint32(v)
b = b[n:]
case fieldResponderIndex:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.ResponderIndex = uint32(v)
b = b[n:]
case fieldTime:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 {
return errInvalidHandshakeDetails
}
p.Time = v
b = b[n:]
case fieldCertVersion:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.CertVersion = uint32(v)
b = b[n:]
default:
n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
}
}
return nil
}

361
handshake/payload_test.go Normal file
View File

@@ -0,0 +1,361 @@
package handshake
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"
)
func TestPayloadRoundTrip(t *testing.T) {
t.Run("all fields set", func(t *testing.T) {
data := MarshalPayload(nil, Payload{
Cert: []byte("test-cert-bytes"),
CertVersion: 2,
InitiatorIndex: 12345,
ResponderIndex: 67890,
Time: 1234567890,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
assert.Equal(t, uint32(12345), got.InitiatorIndex)
assert.Equal(t, uint32(67890), got.ResponderIndex)
assert.Equal(t, uint64(1234567890), got.Time)
assert.Equal(t, uint32(2), got.CertVersion)
})
t.Run("minimal fields", func(t *testing.T) {
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(1), got.InitiatorIndex)
assert.Equal(t, uint32(0), got.ResponderIndex)
assert.Equal(t, uint64(0), got.Time)
assert.Nil(t, got.Cert)
})
t.Run("empty payload", func(t *testing.T) {
data := MarshalPayload(nil, Payload{})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(0), got.InitiatorIndex)
})
t.Run("large cert bytes", func(t *testing.T) {
bigCert := make([]byte, 4096)
for i := range bigCert {
bigCert[i] = byte(i % 256)
}
data := MarshalPayload(nil, Payload{
Cert: bigCert,
CertVersion: 2,
InitiatorIndex: 999,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, bigCert, got.Cert)
assert.Equal(t, uint32(999), got.InitiatorIndex)
})
t.Run("append to existing buffer", func(t *testing.T) {
prefix := []byte("prefix")
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
assert.Equal(t, []byte("prefix"), data[:6])
got, err := UnmarshalPayload(data[6:])
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
}
func TestPayloadUnknownFields(t *testing.T) {
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
// Marshal a normal payload then append an unknown field (field 99, varint)
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
data = protowire.AppendTag(data, 99, protowire.VarintType)
data = protowire.AppendVarint(data, 12345)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
t.Run("unknown field in details is skipped", func(t *testing.T) {
// Build details with a known field + unknown field
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 77)
// Unknown field 50, varint
details = protowire.AppendTag(details, 50, protowire.VarintType)
details = protowire.AppendVarint(details, 9999)
// Another known field after the unknown one
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 88)
// Wrap in outer message
var data []byte
data = protowire.AppendTag(data, 1, protowire.BytesType)
data = protowire.AppendBytes(data, details)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(77), got.InitiatorIndex)
assert.Equal(t, uint32(88), got.ResponderIndex)
})
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
// Fields 6 and 7 are reserved in the proto definition
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 100)
details = protowire.AppendTag(details, 6, protowire.VarintType)
details = protowire.AppendVarint(details, 1)
details = protowire.AppendTag(details, 7, protowire.VarintType)
details = protowire.AppendVarint(details, 2)
var data []byte
data = protowire.AppendTag(data, 1, protowire.BytesType)
data = protowire.AppendBytes(data, details)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(100), got.InitiatorIndex)
})
}
func TestPayloadBytesConsumed(t *testing.T) {
t.Run("all bytes consumed on valid input", func(t *testing.T) {
original := Payload{
Cert: []byte("cert"),
CertVersion: 2,
InitiatorIndex: 100,
ResponderIndex: 200,
Time: 999,
}
data := MarshalPayload(nil, original)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
// Re-marshal and compare — proves we consumed and reproduced all fields
remarshaled := MarshalPayload(nil, got)
assert.Equal(t, data, remarshaled)
})
}
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
// so UnmarshalPayload can reach unmarshalPayloadDetails.
func wrapDetails(details []byte) []byte {
var out []byte
out = protowire.AppendTag(out, 1, protowire.BytesType)
out = protowire.AppendBytes(out, details)
return out
}
func TestPayloadUnmarshalErrors(t *testing.T) {
t.Run("nil input", func(t *testing.T) {
got, err := UnmarshalPayload(nil)
require.NoError(t, err)
assert.Equal(t, uint32(0), got.InitiatorIndex)
})
t.Run("truncated outer tag", func(t *testing.T) {
_, err := UnmarshalPayload([]byte{0x80})
assert.Error(t, err)
})
t.Run("truncated outer details field", func(t *testing.T) {
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
assert.Error(t, err)
})
t.Run("truncated outer unknown field", func(t *testing.T) {
// Valid tag for unknown field 99 varint, but no value follows
var data []byte
data = protowire.AppendTag(data, 99, protowire.VarintType)
_, err := UnmarshalPayload(data)
assert.Error(t, err)
})
t.Run("truncated details tag", func(t *testing.T) {
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
assert.Error(t, err)
})
t.Run("truncated cert bytes", func(t *testing.T) {
// Field 1 (cert), bytes type, length 10 but only 2 bytes
var details []byte
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated initiator index varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = append(details, 0x80) // incomplete varint
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated responder index varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated time varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated cert version varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated unknown field in details", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, 50, protowire.VarintType)
details = append(details, 0x80) // incomplete varint
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
// fieldCert as Varint instead of Bytes.
var details []byte
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
details = protowire.AppendVarint(details, 42)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
// fieldInitiatorIndex as Bytes instead of Varint.
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("time with wrong wire type rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
// Per proto3, multiple instances of a singular field are accepted and
// the last value wins. We keep this behavior so that peers using
// alternative encoders aren't rejected.
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 1)
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 42)
got, err := UnmarshalPayload(wrapDetails(details))
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, math.MaxUint32+1)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert version varint overflow rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = protowire.AppendVarint(details, math.MaxUint32+1)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
}
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
// never panics, and for any input that parses cleanly, that re-marshal +
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
// not "unauthenticated injection."
func FuzzPayload(f *testing.F) {
// Seed corpus with a handful of known-good shapes.
f.Add(MarshalPayload(nil, Payload{}))
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
f.Add(MarshalPayload(nil, Payload{
Cert: []byte("seed-cert"),
InitiatorIndex: 1,
ResponderIndex: 2,
Time: 3,
CertVersion: 2,
}))
f.Add([]byte{})
f.Add([]byte{0xff})
f.Fuzz(func(t *testing.T, data []byte) {
p1, err := UnmarshalPayload(data)
if err != nil {
return
}
// For any input that parses, re-marshaling and re-parsing must
// yield an equivalent Payload. This catches dispatch bugs (e.g.
// emitting a field on marshal that we don't accept on parse) and
// any non-idempotent parsing behavior.
b2 := MarshalPayload(nil, p1)
p2, err := UnmarshalPayload(b2)
if err != nil {
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
}
if !payloadsEqual(p1, p2) {
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
}
})
}
func payloadsEqual(a, b Payload) bool {
return bytes.Equal(a.Cert, b.Cert) &&
a.InitiatorIndex == b.InitiatorIndex &&
a.ResponderIndex == b.ResponderIndex &&
a.Time == b.Time &&
a.CertVersion == b.CertVersion
}