Handshake state machine (#1656)

This commit is contained in:
Nate Brown
2026-04-30 21:30:27 -05:00
committed by GitHub
parent 1ab1f71dba
commit 9ec8cf10f3
21 changed files with 3036 additions and 1593 deletions

View File

@@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) {
pubkey := privkey.PublicKey() pubkey := privkey.PublicKey()
return pubkey.Bytes(), privkey.Bytes() return pubkey.Bytes(), privkey.Bytes()
} }
// DummyCert is a minimal cert.Certificate implementation for testing error paths.
type DummyCert struct {
Version_ cert.Version
Curve_ cert.Curve
Groups_ []string
IsCA_ bool
Issuer_ string
Name_ string
Networks_ []netip.Prefix
NotAfter_ time.Time
NotBefore_ time.Time
PublicKey_ []byte
Signature_ []byte
UnsafeNetworks_ []netip.Prefix
}
func (d *DummyCert) Version() cert.Version { return d.Version_ }
func (d *DummyCert) Curve() cert.Curve { return d.Curve_ }
func (d *DummyCert) Groups() []string { return d.Groups_ }
func (d *DummyCert) IsCA() bool { return d.IsCA_ }
func (d *DummyCert) Issuer() string { return d.Issuer_ }
func (d *DummyCert) Name() string { return d.Name_ }
func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ }
func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ }
func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ }
func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ }
func (d *DummyCert) Signature() []byte { return d.Signature_ }
func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ }
func (d *DummyCert) Fingerprint() (string, error) { return "", nil }
func (d *DummyCert) CheckSignature(key []byte) bool { return false }
func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil }
func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil }
func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil }
func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil }
func (d *DummyCert) String() string { return "dummy" }
func (d *DummyCert) Copy() cert.Certificate { return d }
func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil }
func (d *DummyCert) Expired(time.Time) bool { return false }
func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil }
func (d *DummyCert) PublicKeyPEM() []byte { return nil }
// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error.
func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool {
pool := cert.NewCAPool()
for _, ca := range cas {
if err := pool.AddCA(ca); err != nil {
panic(err)
}
}
return pool
}

View File

@@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay/overlaytest" "github.com/slackhq/nebula/overlay/overlaytest"
@@ -47,7 +46,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
initiatingVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1Credential: nil,
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -80,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) {
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -130,7 +128,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
initiatingVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1Credential: nil,
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -163,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -215,7 +212,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
initiatingVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1Credential: nil,
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -249,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
} }
hostinfo.ConnectionState = &ConnectionState{ hostinfo.ConnectionState = &ConnectionState{
myCert: &dummyCert{version: cert.Version1}, myCert: &dummyCert{version: cert.Version1},
H: &noise.HandshakeState{},
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
@@ -342,7 +338,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
cs := &CertState{ cs := &CertState{
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{}, v1Cert: &dummyCert{},
v1HandshakeBytes: []byte{}, v1Credential: nil,
} }
lh := newTestLighthouse() lh := newTestLighthouse()
@@ -372,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
ConnectionState: &ConnectionState{ ConnectionState: &ConnectionState{
myCert: &dummyCert{}, myCert: &dummyCert{},
peerCert: cachedPeerCert, peerCert: cachedPeerCert,
H: &noise.HandshakeState{},
}, },
} }
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

View File

@@ -1,15 +1,12 @@
package nebula package nebula
import ( import (
"crypto/rand"
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/noiseutil" "github.com/slackhq/nebula/handshake"
) )
const ReplayWindow = 1024 const ReplayWindow = 1024
@@ -17,7 +14,6 @@ const ReplayWindow = 1024
type ConnectionState struct { type ConnectionState struct {
eKey *NebulaCipherState eKey *NebulaCipherState
dKey *NebulaCipherState dKey *NebulaCipherState
H *noise.HandshakeState
myCert cert.Certificate myCert cert.Certificate
peerCert *cert.CachedCertificate peerCert *cert.CachedCertificate
initiator bool initiator bool
@@ -26,55 +22,24 @@ type ConnectionState struct {
writeLock sync.Mutex writeLock sync.Mutex
} }
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { // newConnectionStateFromResult builds a fully-populated ConnectionState from a
var dhFunc noise.DHFunc // completed handshake.Result. It seeds messageCounter and the replay window so
switch crt.Curve() { // that the post-handshake message indices already used on the wire don't count
case cert.Curve_CURVE25519: // as missed traffic in the data plane.
dhFunc = noise.DH25519 func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
case cert.Curve_P256:
if cs.pkcs11Backed {
dhFunc = noiseutil.DHP256PKCS11
} else {
dhFunc = noiseutil.DHP256
}
default:
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}
var ncs noise.CipherSuite
if cs.cipher == "chachapoly" {
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
} else {
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
}
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
hs, err := noise.NewHandshakeState(noise.Config{
CipherSuite: ncs,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: static,
//NOTE: These should come from CertState (pki.go) when we finally implement it
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
if err != nil {
return nil, fmt.Errorf("NewConnectionState: %s", err)
}
// The queue and ready params prevent a counter race that would happen when
// sending stored packets and simultaneously accepting new traffic.
ci := &ConnectionState{ ci := &ConnectionState{
H: hs, myCert: r.MyCert,
initiator: initiator, initiator: r.Initiator,
peerCert: r.RemoteCert,
eKey: NewNebulaCipherState(r.EKey),
dKey: NewNebulaCipherState(r.DKey),
window: NewBits(ReplayWindow), window: NewBits(ReplayWindow),
myCert: crt,
} }
// always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(r.MessageIndex)
ci.messageCounter.Add(2) for i := uint64(1); i <= r.MessageIndex; i++ {
ci.window.Update(nil, i)
return ci, nil }
return ci
} }
func (cs *ConnectionState) MarshalJSON() ([]byte, error) { func (cs *ConnectionState) MarshalJSON() ([]byte, error) {

114
connection_state_test.go Normal file
View File

@@ -0,0 +1,114 @@
package nebula
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// runTestHandshake runs a complete IX handshake between two freshly-built
// peers and returns the initiator and responder Results. Used to produce
// real cipher states for tests that need to exercise post-handshake glue.
func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) {
t.Helper()
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc {
c, _, rawKey, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
)
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey)
require.NoError(t, err)
hsBytes, err := c.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
cred := handshake.NewCredential(c, hsBytes, priv, ncs)
return func(v cert.Version) *handshake.Credential {
if v == cert.Version2 {
return cred
}
return nil
}
}
verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) {
return caPool.VerifyCertificate(time.Now(), c)
}
initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM, err := handshake.NewMachine(
cert.Version2, initCreds, verifier,
func() (uint32, error) { return 1000, nil },
true, header.HandshakeIXPSK0,
)
require.NoError(t, err)
respM, err := handshake.NewMachine(
cert.Version2, respCreds, verifier,
func() (uint32, error) { return 2000, nil },
false, header.HandshakeIXPSK0,
)
require.NoError(t, err)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp, respR, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, respR)
_, initR, err = initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initR)
return initR, respR
}
func TestNewConnectionStateFromResult(t *testing.T) {
initR, respR := runTestHandshake(t)
t.Run("initiator", func(t *testing.T) {
ci := newConnectionStateFromResult(initR)
assert.True(t, ci.initiator)
assert.Equal(t, initR.MyCert, ci.myCert)
assert.Equal(t, initR.RemoteCert, ci.peerCert)
assert.NotNil(t, ci.eKey)
assert.NotNil(t, ci.dKey)
// IX has 2 handshake messages; the next data-plane send is counter=3.
assert.Equal(t, uint64(2), ci.messageCounter.Load(),
"messageCounter must equal Result.MessageIndex so the next send is N+1")
// Both handshake counters must be marked seen so they don't appear lost.
// Check returns false if an index has already been recorded.
assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen")
assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen")
// Counter 3 is the next data-plane message and must NOT be pre-marked.
assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded")
})
t.Run("responder", func(t *testing.T) {
ci := newConnectionStateFromResult(respR)
assert.False(t, ci.initiator)
assert.Equal(t, respR.MyCert, ci.myCert)
assert.Equal(t, respR.RemoteCert, ci.peerCert)
assert.NotNil(t, ci.eKey)
assert.NotNil(t, ci.dKey)
assert.Equal(t, uint64(2), ci.messageCounter.Load())
})
}

View File

@@ -1033,7 +1033,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
// Test a bad rule definition // Test a bad rule definition
c := &dummyCert{} c := &dummyCert{}
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes")
require.NoError(t, err) require.NoError(t, err)
conf := config.NewC(test.NewLogger()) conf := config.NewC(test.NewLogger())

57
handshake/credential.go Normal file
View File

@@ -0,0 +1,57 @@
package handshake
import (
"crypto/rand"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
)
// Credential holds everything needed to participate in a handshake
// at a given cert version. Version and Curve are read from Cert; the public
// half of the static keypair likewise comes from Cert.PublicKey().
type Credential struct {
Cert cert.Certificate // the certificate
Bytes []byte // pre-marshaled certificate bytes
privateKey []byte // static private key (public half lives in Cert)
cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash)
}
// NewCredential creates a Credential with all material needed for handshake
// participation. The cipherSuite should be pre-built by the caller with the
// appropriate DH function, cipher, and hash.
func NewCredential(
c cert.Certificate,
hsBytes []byte,
privateKey []byte,
cipherSuite noise.CipherSuite,
) *Credential {
return &Credential{
Cert: c,
Bytes: hsBytes,
privateKey: privateKey,
cipherSuite: cipherSuite,
}
}
// buildHandshakeState creates a noise.HandshakeState from this credential.
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
return noise.NewHandshakeState(noise.Config{
CipherSuite: hc.cipherSuite,
Random: rand.Reader,
Pattern: pattern,
Initiator: initiator,
StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
PresharedKey: []byte{},
PresharedKeyPlacement: 0,
})
}
// GetCredentialFunc returns the handshake credential for the given version,
// or nil if that version is not available.
//
// Implementations must return credentials drawn from a snapshot stable for
// the lifetime of any single Machine. The Machine may call this multiple
// times during a handshake (e.g. when negotiating to the peer's version)
// and assumes the underlying static keypair is consistent across calls.
type GetCredentialFunc func(v cert.Version) *Credential

21
handshake/errors.go Normal file
View File

@@ -0,0 +1,21 @@
package handshake
import "errors"
var (
ErrInitiateOnResponder = errors.New("initiate called on responder")
ErrInitiateAlreadyCalled = errors.New("initiate already called")
ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators")
ErrPacketTooShort = errors.New("packet too short")
ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake")
ErrIncompleteHandshake = errors.New("handshake completed without receiving required content")
ErrMachineFailed = errors.New("handshake machine has failed")
ErrUnknownSubtype = errors.New("unknown handshake subtype")
ErrMissingContent = errors.New("expected handshake content but message was empty")
ErrUnexpectedContent = errors.New("received unexpected handshake content")
ErrIndexAllocation = errors.New("failed to allocate local index")
ErrNoCredential = errors.New("no handshake credential available for cert version")
ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key")
ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype")
)

29
handshake/handshake.proto Normal file
View File

@@ -0,0 +1,29 @@
// This file documents the wire format the nebula handshake speaks. It is
// not run through protoc; the encoder/decoder in payload.go is hand-written
// against this shape directly to keep the parser narrow and panic-free.
//
// Any change to the wire format must be reflected here, and adding a new
// field requires updating MarshalPayload / unmarshalPayloadDetails together
// with the field-uniqueness and wire-type checks in those functions.
syntax = "proto3";
package nebula.handshake;
message NebulaHandshake {
NebulaHandshakeDetails Details = 1;
bytes Hmac = 2;
}
message NebulaHandshakeDetails {
bytes Cert = 1;
uint32 InitiatorIndex = 2;
uint32 ResponderIndex = 3;
// Cookie was reserved for an anti-DoS mechanism that was never
// implemented. No released version of nebula has ever populated it; the
// hand-written parser silently skips it on read.
uint64 Cookie = 4 [deprecated = true];
uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport
reserved 6, 7;
}

116
handshake/helpers_test.go Normal file
View File

@@ -0,0 +1,116 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/require"
)
// testCertState holds cert material for a test peer.
type testCertState struct {
version cert.Version
creds map[cert.Version]*Credential
}
func (s *testCertState) getCredential(v cert.Version) *Credential {
return s.creds[v]
}
func newTestCertState(
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
) *testCertState {
return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
}
func newTestCertStateWithCipher(
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
cipher noise.CipherFunc,
) *testCertState {
t.Helper()
c, _, rawPrivKey, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
)
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey)
require.NoError(t, err)
hsBytes, err := c.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
return &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(c, hsBytes, priv, ncs),
},
}
}
func testVerifier(pool *cert.CAPool) CertVerifier {
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
return pool.VerifyCertificate(time.Now(), c)
}
}
func newTestMachine(
t *testing.T,
cs *testCertState,
verifier CertVerifier,
initiator bool,
localIndex uint32,
) *Machine {
t.Helper()
m, err := NewMachine(
cs.version, cs.getCredential,
verifier, func() (uint32, error) { return localIndex, nil },
initiator, header.HandshakeIXPSK0,
)
require.NoError(t, err)
return m
}
func initiateHandshake(
t *testing.T,
initCS *testCertState, initVerifier CertVerifier,
respCS *testCertState, respVerifier CertVerifier,
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
t.Helper()
initM = newTestMachine(t, initCS, initVerifier, true, 100)
msg1, merr := initM.Initiate(nil)
require.NoError(t, merr)
respM = newTestMachine(t, respCS, respVerifier, false, 200)
resp, respResult, err = respM.ProcessPacket(nil, msg1)
return
}
func doFullHandshake(
t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
) (initResult, respResult *Result) {
t.Helper()
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 1000)
respM := newTestMachine(t, respCS, v, false, 2000)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp, respResult, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, respResult)
require.NotEmpty(t, resp)
_, initResult, err = initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initResult)
return initResult, respResult
}

444
handshake/machine.go Normal file
View File

@@ -0,0 +1,444 @@
package handshake
import (
"bytes"
"fmt"
"slices"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
// IndexAllocator is called by the Machine to allocate a local index for the
// handshake. It is called at most once, when the first outgoing message that
// carries a payload is built.
//
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
// "no index assigned" on the wire and in the payload-presence checks. If an
// allocator ever returned 0, a legitimate handshake's payload could be
// indistinguishable from an empty one and would be rejected.
type IndexAllocator func() (uint32, error)
// CertVerifier is called by the Machine after reconstructing the peer's
// certificate from the handshake. The verifier performs all validation
// (CA trust, expiry, policy checks, allow lists).
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
// Result contains the results of a successful handshake.
// Returned by ProcessPacket when the handshake is complete.
type Result struct {
EKey *noise.CipherState
DKey *noise.CipherState
MyCert cert.Certificate
RemoteCert *cert.CachedCertificate
RemoteIndex uint32
LocalIndex uint32
HandshakeTime uint64
MessageIndex uint64 // number of messages exchanged during the handshake
Initiator bool
}
// Machine drives a Noise handshake through N messages. It handles Noise
// protocol operations, certificate reconstruction, and payload encoding.
// Certificate validation is delegated to the caller via CertVerifier.
//
// A Machine is not safe for concurrent use. The caller must ensure that
// Initiate and ProcessPacket are not called concurrently.
//
// Error contract: when ProcessPacket or Initiate returns an error, callers
// must check Failed() to decide what to do next. If Failed() is false the
// underlying noise state was not advanced (the packet was rejected before
// ReadMessage took effect, or the rejection is non-fatal like a stale
// retransmit) and the Machine can accept another packet. If Failed() is
// true the Machine is unrecoverable and the caller must abandon it.
type Machine struct {
hs *noise.HandshakeState
getCred GetCredentialFunc
allocIndex IndexAllocator
verifier CertVerifier
result *Result
msgs []msgFlags
myVersion cert.Version
subtype header.MessageSubType
indexAllocated bool
remoteCertSet bool
payloadSet bool
failed bool
}
// NewMachine creates a handshake state machine. The subtype determines both
// the noise pattern and the per-message content layout. The credential for
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
// IndexAllocator is called lazily when the first outgoing payload is built.
func NewMachine(
version cert.Version,
getCred GetCredentialFunc,
verifier CertVerifier,
allocIndex IndexAllocator,
initiator bool,
subtype header.MessageSubType,
) (*Machine, error) {
info, err := subtypeInfoFor(subtype)
if err != nil {
return nil, err
}
cred := getCred(version)
if cred == nil {
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
}
hs, err := cred.buildHandshakeState(initiator, info.pattern)
if err != nil {
return nil, fmt.Errorf("build noise state: %w", err)
}
return &Machine{
hs: hs,
subtype: subtype,
msgs: info.msgs,
getCred: getCred,
allocIndex: allocIndex,
verifier: verifier,
myVersion: version,
result: &Result{
Initiator: initiator,
},
}, nil
}
// Failed returns true if the Machine is in an unrecoverable state.
func (m *Machine) Failed() bool {
return m.failed
}
// Subtype returns the handshake subtype this Machine was built for.
func (m *Machine) Subtype() header.MessageSubType {
return m.subtype
}
// MessageIndex returns the noise handshake message index, which equals the
// wire counter of the most recently sent or received message.
func (m *Machine) MessageIndex() int {
return m.hs.MessageIndex()
}
// requireComplete checks that both a peer cert and payload have been received.
// Marks the machine as failed if not.
func (m *Machine) requireComplete() error {
if !m.payloadSet || !m.remoteCertSet {
m.failed = true
return ErrIncompleteHandshake
}
return nil
}
// myMsgFlags returns the flags for the current outgoing message.
func (m *Machine) myMsgFlags() msgFlags {
idx := m.hs.MessageIndex()
if idx < len(m.msgs) {
return m.msgs[idx]
}
return msgFlags{}
}
// peerMsgFlags returns the flags for the message we just read.
func (m *Machine) peerMsgFlags() msgFlags {
idx := m.hs.MessageIndex() - 1
if idx >= 0 && idx < len(m.msgs) {
return m.msgs[idx]
}
return msgFlags{}
}
// Initiate produces the first handshake message. Only valid for initiators,
// and must be called exactly once before ProcessPacket.
//
// out is a destination buffer the message is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation.
//
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) Initiate(out []byte) ([]byte, error) {
if m.failed {
return nil, ErrMachineFailed
}
if !m.result.Initiator {
m.failed = true
return nil, ErrInitiateOnResponder
}
if m.hs.MessageIndex() != 0 {
m.failed = true
return nil, ErrInitiateAlreadyCalled
}
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
// header counter 1 and remote index 0, which is what the initial message needs.
out, _, _, err := m.buildResponse(out)
if err != nil {
m.failed = true
return nil, err
}
return out, nil
}
// ProcessPacket handles an incoming handshake message. It advances the Noise
// state, validates the peer certificate via the verifier, and optionally
// produces a response.
//
// out is a destination buffer the response is appended to and returned. Pass
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
// is nil when no outgoing message is produced (handshake complete on this
// side, or final message of a multi-message pattern).
//
// Returns a non-nil Result when the handshake is complete.
// An error return may not indicate a fatal condition, check Failed() to
// determine if the Machine can still be used.
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
if m.failed {
return nil, nil, ErrMachineFailed
}
if len(packet) < header.Len {
return nil, nil, ErrPacketTooShort
}
// Reject packets whose subtype doesn't match the one this Machine was
// built for. A pending handshake that suddenly receives a different
// subtype on its index is either a stray packet that matched by chance
// or a peer protocol violation; drop it without failing the Machine so
// the legitimate retransmit can still complete.
if header.MessageSubType(packet[1]) != m.subtype {
return nil, nil, ErrSubtypeMismatch
}
if m.result.Initiator && m.hs.MessageIndex() == 0 {
m.failed = true
return nil, nil, ErrInitiateNotCalled
}
// The (eKey, dKey) ordering here is correct for IX, where the initiator
// completes the handshake by reading the responder's stage-2 message.
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
// For 3-message patterns where a responder finishes by reading the final
// message, this ordering would be wrong; revisit when XX/pqIX lands.
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
if err != nil {
// Noise ReadMessage failed. The noise library checkpoints and rolls back
// on failure, so the Machine is still alive. The caller can retry with
// a different packet.
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
}
// From here on, noise state has advanced. Any error is fatal.
flags := m.peerMsgFlags()
if err := m.processPayload(msg, flags); err != nil {
return nil, nil, err
}
// If ReadMessage derived keys, the handshake is complete. Noise should
// always produce both keys together; asymmetry is a protocol invariant
// violation.
if eKey != nil || dKey != nil {
if eKey == nil || dKey == nil {
m.failed = true
return nil, nil, ErrAsymmetricCipherKeys
}
if err := m.requireComplete(); err != nil {
return nil, nil, err
}
return nil, m.completed(eKey, dKey), nil
}
// ReadMessage didn't complete, produce the next outgoing message
out, dk, ek, err := m.buildResponse(out)
if err != nil {
m.failed = true
return nil, nil, err
}
if ek != nil || dk != nil {
if ek == nil || dk == nil {
m.failed = true
return nil, nil, ErrAsymmetricCipherKeys
}
if err := m.requireComplete(); err != nil {
return nil, nil, err
}
return out, m.completed(ek, dk), nil
}
return out, nil, nil
}
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
m.result.EKey = eKey
m.result.DKey = dKey
m.result.MessageIndex = uint64(m.hs.MessageIndex())
return m.result
}
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
if len(msg) == 0 {
if flags.expectsPayload || flags.expectsCert {
m.failed = true
return ErrMissingContent
}
return nil
}
payload, err := UnmarshalPayload(msg)
if err != nil {
m.failed = true
return fmt.Errorf("unmarshal handshake: %w", err)
}
// Assert the payload contains exactly what we expect
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
if hasPayloadData != flags.expectsPayload {
m.failed = true
return ErrUnexpectedContent
}
hasCertData := len(payload.Cert) > 0
if hasCertData != flags.expectsCert {
m.failed = true
return ErrUnexpectedContent
}
// Process payload
if flags.expectsPayload {
if m.result.Initiator {
m.result.RemoteIndex = payload.ResponderIndex
} else {
m.result.RemoteIndex = payload.InitiatorIndex
}
m.result.HandshakeTime = payload.Time
m.payloadSet = true
}
// Process certificate
if flags.expectsCert {
if err := m.validateCert(payload); err != nil {
return err
}
}
return nil
}
func (m *Machine) validateCert(payload Payload) error {
cred := m.getCred(m.myVersion)
if cred == nil {
m.failed = true
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
}
rc, err := cert.Recombine(
cert.Version(payload.CertVersion),
payload.Cert,
m.hs.PeerStatic(),
cred.Cert.Curve(),
)
if err != nil {
m.failed = true
return fmt.Errorf("recombine cert: %w", err)
}
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
m.failed = true
return ErrPublicKeyMismatch
}
// Version negotiation, if the peer sent a different version and we have it, switch
if rc.Version() != m.myVersion {
if m.getCred(rc.Version()) != nil {
m.myVersion = rc.Version()
}
}
verified, err := m.verifier(rc)
if err != nil {
m.failed = true
return fmt.Errorf("verify cert: %w", err)
}
m.result.RemoteCert = verified
m.remoteCertSet = true
return nil
}
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
if !flags.expectsPayload && !flags.expectsCert {
return nil, nil
}
var p Payload
if flags.expectsPayload {
if !m.indexAllocated {
index, err := m.allocIndex()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
}
m.result.LocalIndex = index
m.indexAllocated = true
}
if m.result.Initiator {
p.InitiatorIndex = m.result.LocalIndex
} else {
p.ResponderIndex = m.result.LocalIndex
p.InitiatorIndex = m.result.RemoteIndex
}
p.Time = uint64(time.Now().UnixNano())
}
if flags.expectsCert {
cred := m.getCred(m.myVersion)
if cred == nil {
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
}
p.Cert = cred.Bytes
p.CertVersion = uint32(cred.Cert.Version())
m.result.MyCert = cred.Cert
}
return MarshalPayload(nil, p), nil
}
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
flags := m.myMsgFlags()
hsBytes, err := m.marshalOutgoing(flags)
if err != nil {
return nil, nil, nil, err
}
// Extend out by header.Len to make room for the header. slices.Grow is a
// no-op when the cap is already sufficient (the zero-copy case where the
// caller passed a pre-sized buffer). header.Encode overwrites the new
// bytes, so they don't need to be zeroed.
start := len(out)
out = slices.Grow(out, header.Len)[:start+header.Len]
header.Encode(
out[start:],
header.Version, header.Handshake, m.subtype,
m.result.RemoteIndex,
uint64(m.hs.MessageIndex()+1),
)
// noise.WriteMessage appends the encrypted handshake message to out,
// reusing capacity when present.
//
// The (dKey, eKey) ordering here is correct for IX, where the responder
// completes the handshake by writing the stage-2 message. noise returns
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
// responder's decrypt key). For 3-message patterns where an initiator
// finishes by writing the final message, this ordering would be wrong;
// revisit when XX/pqIX lands.
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
if err != nil {
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
}
return out, dKey, eKey, nil
}

662
handshake/machine_test.go Normal file
View File

@@ -0,0 +1,662 @@
package handshake
import (
"net/netip"
"testing"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
ct "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/noiseutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMachineIXHappyPath(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
assert.Equal(t, uint32(1000), initR.LocalIndex)
assert.Equal(t, uint32(2000), initR.RemoteIndex)
assert.Equal(t, uint32(2000), respR.LocalIndex)
assert.Equal(t, uint32(1000), respR.RemoteIndex)
assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
require.NoError(t, err)
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("hello"), pt1)
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
require.NoError(t, err)
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
require.NoError(t, err)
assert.Equal(t, []byte("world"), pt2)
}
func TestMachineInitiateErrors(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("initiate on responder", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, err := m.Initiate(nil)
require.ErrorIs(t, err, ErrInitiateOnResponder)
assert.True(t, m.Failed())
})
t.Run("initiate called twice", func(t *testing.T) {
m := newTestMachine(t, cs, v, true, 100)
_, err := m.Initiate(nil)
require.NoError(t, err)
_, err = m.Initiate(nil)
require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
assert.True(t, m.Failed())
})
t.Run("process packet before initiate on initiator", func(t *testing.T) {
m := newTestMachine(t, cs, v, true, 100)
_, _, err := m.ProcessPacket(nil, make([]byte, 100))
require.ErrorIs(t, err, ErrInitiateNotCalled)
assert.True(t, m.Failed())
})
t.Run("calling failed machine", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, err := m.Initiate(nil) // fails: responder
require.Error(t, err)
_, err = m.Initiate(nil) // fails: already failed
require.ErrorIs(t, err, ErrMachineFailed)
})
}
func TestMachineProcessPacketErrors(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("packet too short", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
require.ErrorIs(t, err, ErrPacketTooShort)
assert.False(t, m.Failed(), "short packet should not kill machine")
})
t.Run("noise decryption failure is recoverable", func(t *testing.T) {
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
initM := newTestMachine(t, initCS, v, true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
respM := newTestMachine(t, cs, v, false, 200)
resp, _, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
corrupted := make([]byte, len(resp))
copy(corrupted, resp)
for i := header.Len; i < len(corrupted); i++ {
corrupted[i] ^= 0xff
}
_, _, err = initM.ProcessPacket(nil, corrupted)
require.Error(t, err)
assert.False(t, initM.Failed(), "noise failure should be recoverable")
// And the machine should still complete a real handshake afterward.
_, result, err := initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, result, "initiator should complete on the legitimate response")
})
t.Run("invalid cert is fatal", func(t *testing.T) {
otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
respM := newTestMachine(t, cs, v, false, 200)
_, _, err = respM.ProcessPacket(nil, msg1)
require.Error(t, err)
assert.True(t, respM.Failed(), "cert validation failure should kill machine")
})
t.Run("subtype mismatch is recoverable", func(t *testing.T) {
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
initM := newTestMachine(t, initCS, v, true, 100)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
// Mutate the subtype byte (offset 1 in the header) to a value the
// responder Machine wasn't built for.
bad := make([]byte, len(msg1))
copy(bad, msg1)
bad[1] = 0xff
respM := newTestMachine(t, cs, v, false, 200)
_, _, err = respM.ProcessPacket(nil, bad)
require.ErrorIs(t, err, ErrSubtypeMismatch)
assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
// And the machine should still complete a real handshake afterward.
resp, result, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
})
}
// TestMachineProcessPayload exercises processPayload's internal validation
// directly. Most of these failure modes can't be reached black-box once the
// subtype check at the top of ProcessPacket gates external callers, so we
// drive them by hand here for coverage.
func TestMachineProcessPayload(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("empty message with expects fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
require.ErrorIs(t, err, ErrMissingContent)
assert.True(t, m.Failed())
})
t.Run("empty message with no expects passes", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload(nil, msgFlags{})
require.NoError(t, err)
assert.False(t, m.Failed())
})
t.Run("malformed protobuf is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
require.Error(t, err)
assert.True(t, m.Failed())
})
t.Run("unexpected payload data is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// A payload with index data when none was expected.
bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
t.Run("unexpected cert data is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// A payload with cert when none was expected.
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
t.Run("missing payload data when expected is fatal", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
// Cert present, but no index/time fields.
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
require.ErrorIs(t, err, ErrUnexpectedContent)
assert.True(t, m.Failed())
})
}
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
// directly. Like processPayload above this isn't reachable from a normal IX
// flow, so we drive it by hand.
func TestMachineRequireComplete(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
v := testVerifier(caPool)
t.Run("missing both fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("payload only fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.payloadSet = true
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("cert only fails", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.remoteCertSet = true
err := m.requireComplete()
require.ErrorIs(t, err, ErrIncompleteHandshake)
assert.True(t, m.Failed())
})
t.Run("both set passes", func(t *testing.T) {
m := newTestMachine(t, cs, v, false, 100)
m.payloadSet = true
m.remoteCertSet = true
err := m.requireComplete()
require.NoError(t, err)
assert.False(t, m.Failed())
})
}
func TestMachineAESCipher(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertStateWithCipher(
t, ca, caKey, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
noiseutil.CipherAESGCM,
)
respCS := newTestCertStateWithCipher(
t, ca, caKey, "resp",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
noiseutil.CipherAESGCM,
)
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
require.NoError(t, err)
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("works"), pt1)
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
require.NoError(t, err)
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
require.NoError(t, err)
assert.Equal(t, []byte("back"), pt2)
}
func TestResultFields(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
assert.True(t, initR.Initiator)
assert.False(t, respR.Initiator)
assert.NotZero(t, initR.HandshakeTime)
assert.NotZero(t, respR.HandshakeTime)
assert.NotNil(t, initR.RemoteCert)
assert.NotNil(t, respR.RemoteCert)
}
func TestMachineBufferReuse(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 1000)
respM := newTestMachine(t, respCS, v, false, 2000)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
t.Run("response writes into provided buffer", func(t *testing.T) {
buf := make([]byte, 0, 4096)
resp, result, err := respM.ProcessPacket(buf, msg1)
require.NoError(t, err)
require.NotNil(t, result)
assert.NotEmpty(t, resp, "response should have content")
assert.Equal(t, &buf[:1][0], &resp[:1][0],
"response should reuse the provided buffer's backing array")
})
t.Run("initiate writes into provided buffer", func(t *testing.T) {
initM2 := newTestMachine(t, initCS, v, true, 3000)
buf := make([]byte, 0, 4096)
msg, err := initM2.Initiate(buf)
require.NoError(t, err)
assert.NotEmpty(t, msg, "initiate should have content")
assert.Equal(t, &buf[:1][0], &msg[:1][0],
"initiate should reuse the provided buffer's backing array")
})
t.Run("nil out still works", func(t *testing.T) {
initM2 := newTestMachine(t, initCS, v, true, 4000)
respM2 := newTestMachine(t, respCS, v, false, 5000)
msg1, err := initM2.Initiate(nil)
require.NoError(t, err)
resp, _, err := respM2.ProcessPacket(nil, msg1)
require.NoError(t, err)
out, result, err := initM2.ProcessPacket(nil, resp)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Nil(t, out, "initiator should have no response for IX msg2")
})
}
func TestMachineMsgIndexTracking(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
v := testVerifier(caPool)
initM := newTestMachine(t, initCS, v, true, 100)
respM := newTestMachine(t, respCS, v, false, 200)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
resp1, result1, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
assert.NotNil(t, result1)
_, result2, err := initM.ProcessPacket(nil, resp1)
require.NoError(t, err)
assert.NotNil(t, result2)
}
func TestMachineThreeMessagePattern(t *testing.T) {
registerTestXXInfo(t)
// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
// patterns correctly. XX flow:
// msg1 (I->R): [E] - payload only, no cert
// msg2 (R->I): [E, ee, S, es] - payload + cert
// msg3 (I->R): [S, se] - cert only (no payload, not first two)
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
v := testVerifier(caPool)
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
initM, err := NewMachine(
cert.Version2,
initCS.getCredential, v,
func() (uint32, error) { return 1000, nil },
true, header.HandshakeXXPSK0,
)
require.NoError(t, err)
respM, err := NewMachine(
cert.Version2,
respCS.getCredential, v,
func() (uint32, error) { return 2000, nil },
false, header.HandshakeXXPSK0,
)
require.NoError(t, err)
// msg1: initiator -> responder (E only, no cert)
msg1, err := initM.Initiate(nil)
require.NoError(t, err)
assert.NotEmpty(t, msg1)
// Responder processes msg1, should not complete yet, should produce msg2
msg2, result, err := respM.ProcessPacket(nil, msg1)
require.NoError(t, err)
assert.Nil(t, result, "XX should not complete on msg1")
assert.NotEmpty(t, msg2, "responder should produce msg2")
// Initiator processes msg2: gets responder's cert, produces msg3, and
// completes (WriteMessage for msg3 derives keys)
msg3, initResult, err := initM.ProcessPacket(nil, msg2)
require.NoError(t, err)
require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
assert.NotEmpty(t, msg3, "initiator should produce msg3")
assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
// Responder processes msg3: gets initiator's cert and completes
_, respResult, err := respM.ProcessPacket(nil, msg3)
require.NoError(t, err)
require.NotNil(t, respResult, "XX responder should complete on msg3")
assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
// Verify keys work
ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
require.NoError(t, err)
pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
require.NoError(t, err)
assert.Equal(t, []byte("three messages"), pt1)
}
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
// should exercise the case where cert arrives in msg3 and verify that completing
// without it fails.
func TestMachineExpiredCert(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519,
time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
expCert, _, expKeyPEM, _ := ct.NewTestCert(
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
)
expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
require.NoError(t, err)
expHsBytes, err := expCert.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
expiredCS := &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
},
}
respCS := newTestCertState(
t, ca, caKey, "responder",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, expiredCS, testVerifier(caPool),
respCS, testVerifier(caPool),
)
require.ErrorContains(t, err, "verify cert")
assert.True(t, respM.Failed())
}
func TestMachineNoCertNetworks(t *testing.T) {
ca, _, caKey, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca)
caHsBytes, err := ca.MarshalForHandshakes()
require.NoError(t, err)
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
noNetCS := &testCertState{
version: cert.Version2,
creds: map[cert.Version]*Credential{
cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
},
}
respCS := newTestCertState(
t, ca, caKey, "responder",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, noNetCS, testVerifier(caPool),
respCS, testVerifier(caPool),
)
require.Error(t, err)
assert.True(t, respM.Failed())
}
func TestMachineDifferentCAs(t *testing.T) {
ca1, _, caKey1, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
ca2, _, caKey2, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
initCS := newTestCertState(
t, ca1, caKey1, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCS := newTestCertState(
t, ca2, caKey2, "resp",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
)
_, respM, _, _, err := initiateHandshake(
t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
respCS, testVerifier(ct.NewTestCAPool(ca2)),
)
require.ErrorContains(t, err, "verify cert")
assert.True(t, respM.Failed())
}
func TestMachineVersionNegotiation(t *testing.T) {
ca1, _, caKey1, _ := ct.NewTestCaCert(
cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
ca2, _, caKey2, _ := ct.NewTestCaCert(
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
)
caPool := ct.NewTestCAPool(ca1, ca2)
makeMultiVersionResp := func(t *testing.T) *testCertState {
t.Helper()
respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
ca1.NotBefore(), ca1.NotAfter(),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
)
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
respHsV1, _ := respCertV1.MarshalForHandshakes()
respHsV2, _ := respCertV2.MarshalForHandshakes()
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
return &testCertState{
version: cert.Version1,
creds: map[cert.Version]*Credential{
cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
},
}
}
t.Run("responder matches initiator version", func(t *testing.T) {
initCS := newTestCertState(
t, ca2, caKey2, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCS := makeMultiVersionResp(t)
v := testVerifier(caPool)
initM, _, respResult, resp, err := initiateHandshake(
t, initCS, v,
respCS, v,
)
require.NoError(t, err)
require.NotNil(t, respResult)
assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
"responder should negotiate to initiator's version")
_, initResult, err := initM.ProcessPacket(nil, resp)
require.NoError(t, err)
require.NotNil(t, initResult)
assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
"initiator should see V2 cert from responder")
})
t.Run("responder keeps version when no match available", func(t *testing.T) {
initCS := newTestCertState(
t, ca2, caKey2, "init",
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
)
respCert, _, respKeyPEM, _ := ct.NewTestCert(
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
ca1.NotBefore(), ca1.NotAfter(),
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
)
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
respHs, _ := respCert.MarshalForHandshakes()
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
respCS := &testCertState{
version: cert.Version1,
creds: map[cert.Version]*Credential{
cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
},
}
v := testVerifier(caPool)
_, _, respResult, _, err := initiateHandshake(
t, initCS, v,
respCS, v,
)
require.NoError(t, err)
require.NotNil(t, respResult)
assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
"responder should keep V1 when V2 not available")
})
}

54
handshake/patterns.go Normal file
View File

@@ -0,0 +1,54 @@
package handshake
import (
"fmt"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
)
// msgFlags tracks what application data a handshake message carries.
type msgFlags struct {
expectsPayload bool // message carries indexes and time
expectsCert bool // message carries the certificate
}
// subtypeInfo bundles the noise pattern with the per-message flags for a
// given handshake subtype.
type subtypeInfo struct {
pattern noise.HandshakePattern
msgs []msgFlags
}
// subtypeInfos defines the noise pattern and message content layout for each
// handshake subtype.
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
// IX: 2 messages, both carry payload and cert
header.HandshakeIXPSK0: {
pattern: noise.HandshakeIX,
msgs: []msgFlags{
{expectsPayload: true, expectsCert: true},
{expectsPayload: true, expectsCert: true},
},
},
// XX: 3 messages
// msg1 (I->R): payload only
// msg2 (R->I): payload + cert
// msg3 (I->R): cert only
//header.HandshakeXXPSK0: {
// pattern: noise.HandshakeXX,
// msgs: []msgFlags{
// {expectsPayload: true, expectsCert: false},
// {expectsPayload: true, expectsCert: true},
// {expectsPayload: false, expectsCert: true},
// },
//},
}
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
if info, ok := subtypeInfos[subtype]; ok {
return info, nil
}
return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
}

View File

@@ -0,0 +1,63 @@
package handshake
import (
"testing"
"github.com/flynn/noise"
"github.com/slackhq/nebula/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSubtypeInfo(t *testing.T) {
t.Run("IX", func(t *testing.T) {
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
require.Len(t, info.msgs, 2)
// msg1: payload + cert
assert.True(t, info.msgs[0].expectsPayload)
assert.True(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
})
t.Run("XX", func(t *testing.T) {
registerTestXXInfo(t)
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
require.NoError(t, err)
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
require.Len(t, info.msgs, 3)
// msg1: payload only
assert.True(t, info.msgs[0].expectsPayload)
assert.False(t, info.msgs[0].expectsCert)
// msg2: payload + cert
assert.True(t, info.msgs[1].expectsPayload)
assert.True(t, info.msgs[1].expectsCert)
// msg3: cert only
assert.False(t, info.msgs[2].expectsPayload)
assert.True(t, info.msgs[2].expectsCert)
})
t.Run("unknown subtype returns error", func(t *testing.T) {
_, err := subtypeInfoFor(99)
require.ErrorIs(t, err, ErrUnknownSubtype)
})
}
// registerTestXXInfo temporarily registers XX subtype info for testing.
func registerTestXXInfo(t *testing.T) {
t.Helper()
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
pattern: noise.HandshakeXX,
msgs: []msgFlags{
{expectsPayload: true, expectsCert: false},
{expectsPayload: true, expectsCert: true},
{expectsPayload: false, expectsCert: true},
},
}
t.Cleanup(func() {
delete(subtypeInfos, header.HandshakeXXPSK0)
})
}

173
handshake/payload.go Normal file
View File

@@ -0,0 +1,173 @@
package handshake
import (
"errors"
"math"
"google.golang.org/protobuf/encoding/protowire"
)
var (
errInvalidHandshakeMessage = errors.New("invalid handshake message")
errInvalidHandshakeDetails = errors.New("invalid handshake details")
)
// Payload represents the decoded fields of a handshake message.
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
type Payload struct {
Cert []byte
InitiatorIndex uint32
ResponderIndex uint32
Time uint64
CertVersion uint32
}
// Proto field numbers for NebulaHandshakeDetails
const (
fieldCert = 1 // bytes
fieldInitiatorIndex = 2 // uint32
fieldResponderIndex = 3 // uint32
fieldTime = 5 // uint64
fieldCertVersion = 8 // uint32
)
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
// Returns out (which may be nil), with the marshalled Payload appended to it.
func MarshalPayload(out []byte, p Payload) []byte {
var details []byte
if len(p.Cert) > 0 {
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
details = protowire.AppendBytes(details, p.Cert)
}
if p.InitiatorIndex != 0 {
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.InitiatorIndex))
}
if p.ResponderIndex != 0 {
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.ResponderIndex))
}
if p.Time != 0 {
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
details = protowire.AppendVarint(details, p.Time)
}
if p.CertVersion != 0 {
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = protowire.AppendVarint(details, uint64(p.CertVersion))
}
out = protowire.AppendTag(out, 1, protowire.BytesType)
out = protowire.AppendBytes(out, details)
return out
}
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
func UnmarshalPayload(b []byte) (Payload, error) {
var p Payload
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
switch {
case num == 1 && typ == protowire.BytesType:
details, n := protowire.ConsumeBytes(b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
if err := unmarshalPayloadDetails(&p, details); err != nil {
return p, err
}
default:
n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return p, errInvalidHandshakeMessage
}
b = b[n:]
}
}
return p, nil
}
func unmarshalPayloadDetails(p *Payload, b []byte) error {
for len(b) > 0 {
num, typ, n := protowire.ConsumeTag(b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
// For known field numbers, reject any non-matching wire type as a
// hard error rather than silently skipping. The caller will catch
// missing-field cases downstream, but a wire-type mismatch on a tag
// we know is a peer protocol violation worth flagging here.
// Repeated occurrences of a singular field follow proto3 last-wins.
switch num {
case fieldCert:
if typ != protowire.BytesType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeBytes(b)
if n < 0 {
return errInvalidHandshakeDetails
}
p.Cert = append([]byte(nil), v...)
b = b[n:]
case fieldInitiatorIndex:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.InitiatorIndex = uint32(v)
b = b[n:]
case fieldResponderIndex:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.ResponderIndex = uint32(v)
b = b[n:]
case fieldTime:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 {
return errInvalidHandshakeDetails
}
p.Time = v
b = b[n:]
case fieldCertVersion:
if typ != protowire.VarintType {
return errInvalidHandshakeDetails
}
v, n := protowire.ConsumeVarint(b)
if n < 0 || v > math.MaxUint32 {
return errInvalidHandshakeDetails
}
p.CertVersion = uint32(v)
b = b[n:]
default:
n := protowire.ConsumeFieldValue(num, typ, b)
if n < 0 {
return errInvalidHandshakeDetails
}
b = b[n:]
}
}
return nil
}

361
handshake/payload_test.go Normal file
View File

@@ -0,0 +1,361 @@
package handshake
import (
"bytes"
"math"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"
)
func TestPayloadRoundTrip(t *testing.T) {
t.Run("all fields set", func(t *testing.T) {
data := MarshalPayload(nil, Payload{
Cert: []byte("test-cert-bytes"),
CertVersion: 2,
InitiatorIndex: 12345,
ResponderIndex: 67890,
Time: 1234567890,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
assert.Equal(t, uint32(12345), got.InitiatorIndex)
assert.Equal(t, uint32(67890), got.ResponderIndex)
assert.Equal(t, uint64(1234567890), got.Time)
assert.Equal(t, uint32(2), got.CertVersion)
})
t.Run("minimal fields", func(t *testing.T) {
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(1), got.InitiatorIndex)
assert.Equal(t, uint32(0), got.ResponderIndex)
assert.Equal(t, uint64(0), got.Time)
assert.Nil(t, got.Cert)
})
t.Run("empty payload", func(t *testing.T) {
data := MarshalPayload(nil, Payload{})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(0), got.InitiatorIndex)
})
t.Run("large cert bytes", func(t *testing.T) {
bigCert := make([]byte, 4096)
for i := range bigCert {
bigCert[i] = byte(i % 256)
}
data := MarshalPayload(nil, Payload{
Cert: bigCert,
CertVersion: 2,
InitiatorIndex: 999,
})
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, bigCert, got.Cert)
assert.Equal(t, uint32(999), got.InitiatorIndex)
})
t.Run("append to existing buffer", func(t *testing.T) {
prefix := []byte("prefix")
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
assert.Equal(t, []byte("prefix"), data[:6])
got, err := UnmarshalPayload(data[6:])
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
}
func TestPayloadUnknownFields(t *testing.T) {
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
// Marshal a normal payload then append an unknown field (field 99, varint)
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
data = protowire.AppendTag(data, 99, protowire.VarintType)
data = protowire.AppendVarint(data, 12345)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
t.Run("unknown field in details is skipped", func(t *testing.T) {
// Build details with a known field + unknown field
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 77)
// Unknown field 50, varint
details = protowire.AppendTag(details, 50, protowire.VarintType)
details = protowire.AppendVarint(details, 9999)
// Another known field after the unknown one
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 88)
// Wrap in outer message
var data []byte
data = protowire.AppendTag(data, 1, protowire.BytesType)
data = protowire.AppendBytes(data, details)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(77), got.InitiatorIndex)
assert.Equal(t, uint32(88), got.ResponderIndex)
})
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
// Fields 6 and 7 are reserved in the proto definition
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 100)
details = protowire.AppendTag(details, 6, protowire.VarintType)
details = protowire.AppendVarint(details, 1)
details = protowire.AppendTag(details, 7, protowire.VarintType)
details = protowire.AppendVarint(details, 2)
var data []byte
data = protowire.AppendTag(data, 1, protowire.BytesType)
data = protowire.AppendBytes(data, details)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
assert.Equal(t, uint32(100), got.InitiatorIndex)
})
}
func TestPayloadBytesConsumed(t *testing.T) {
t.Run("all bytes consumed on valid input", func(t *testing.T) {
original := Payload{
Cert: []byte("cert"),
CertVersion: 2,
InitiatorIndex: 100,
ResponderIndex: 200,
Time: 999,
}
data := MarshalPayload(nil, original)
got, err := UnmarshalPayload(data)
require.NoError(t, err)
// Re-marshal and compare — proves we consumed and reproduced all fields
remarshaled := MarshalPayload(nil, got)
assert.Equal(t, data, remarshaled)
})
}
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
// so UnmarshalPayload can reach unmarshalPayloadDetails.
func wrapDetails(details []byte) []byte {
var out []byte
out = protowire.AppendTag(out, 1, protowire.BytesType)
out = protowire.AppendBytes(out, details)
return out
}
func TestPayloadUnmarshalErrors(t *testing.T) {
t.Run("nil input", func(t *testing.T) {
got, err := UnmarshalPayload(nil)
require.NoError(t, err)
assert.Equal(t, uint32(0), got.InitiatorIndex)
})
t.Run("truncated outer tag", func(t *testing.T) {
_, err := UnmarshalPayload([]byte{0x80})
assert.Error(t, err)
})
t.Run("truncated outer details field", func(t *testing.T) {
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
assert.Error(t, err)
})
t.Run("truncated outer unknown field", func(t *testing.T) {
// Valid tag for unknown field 99 varint, but no value follows
var data []byte
data = protowire.AppendTag(data, 99, protowire.VarintType)
_, err := UnmarshalPayload(data)
assert.Error(t, err)
})
t.Run("truncated details tag", func(t *testing.T) {
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
assert.Error(t, err)
})
t.Run("truncated cert bytes", func(t *testing.T) {
// Field 1 (cert), bytes type, length 10 but only 2 bytes
var details []byte
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated initiator index varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = append(details, 0x80) // incomplete varint
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated responder index varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated time varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated cert version varint", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = append(details, 0x80)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("truncated unknown field in details", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, 50, protowire.VarintType)
details = append(details, 0x80) // incomplete varint
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
// fieldCert as Varint instead of Bytes.
var details []byte
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
details = protowire.AppendVarint(details, 42)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
// fieldInitiatorIndex as Bytes instead of Varint.
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("time with wrong wire type rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
details = protowire.AppendBytes(details, []byte{1, 2, 3})
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
// Per proto3, multiple instances of a singular field are accepted and
// the last value wins. We keep this behavior so that peers using
// alternative encoders aren't rejected.
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 1)
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, 42)
got, err := UnmarshalPayload(wrapDetails(details))
require.NoError(t, err)
assert.Equal(t, uint32(42), got.InitiatorIndex)
})
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
details = protowire.AppendVarint(details, math.MaxUint32+1)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
t.Run("cert version varint overflow rejected", func(t *testing.T) {
var details []byte
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
details = protowire.AppendVarint(details, math.MaxUint32+1)
_, err := UnmarshalPayload(wrapDetails(details))
assert.Error(t, err)
})
}
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
// never panics, and for any input that parses cleanly, that re-marshal +
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
// not "unauthenticated injection."
func FuzzPayload(f *testing.F) {
// Seed corpus with a handful of known-good shapes.
f.Add(MarshalPayload(nil, Payload{}))
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
f.Add(MarshalPayload(nil, Payload{
Cert: []byte("seed-cert"),
InitiatorIndex: 1,
ResponderIndex: 2,
Time: 3,
CertVersion: 2,
}))
f.Add([]byte{})
f.Add([]byte{0xff})
f.Fuzz(func(t *testing.T, data []byte) {
p1, err := UnmarshalPayload(data)
if err != nil {
return
}
// For any input that parses, re-marshaling and re-parsing must
// yield an equivalent Payload. This catches dispatch bugs (e.g.
// emitting a field on marshal that we don't accept on parse) and
// any non-idempotent parsing behavior.
b2 := MarshalPayload(nil, p1)
p2, err := UnmarshalPayload(b2)
if err != nil {
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
}
if !payloadsEqual(p1, p2) {
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
}
})
}
func payloadsEqual(a, b Payload) bool {
return bytes.Equal(a.Cert, b.Cert) &&
a.InitiatorIndex == b.InitiatorIndex &&
a.ResponderIndex == b.ResponderIndex &&
a.Time == b.Time &&
a.CertVersion == b.CertVersion
}

View File

@@ -1,813 +0,0 @@
package nebula
import (
"bytes"
"context"
"log/slog"
"net/netip"
"time"
"github.com/flynn/noise"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header"
)
// NOISE IX Handshakes
// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
cs := f.pki.getCertState()
v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
// If we're connecting to a v6 address we should encourage use of a V2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
}
crt := cs.getCertificate(v)
if crt == nil {
f.l.Error("Unable to handshake with host because no certificate is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.Error("Failed to create connection state",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", v,
)
return false
}
hh.hostinfo.ConnectionState = ci
hs := &NebulaHandshake{
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: crtHs,
CertVersion: uint32(v),
},
}
hsBytes, err := hs.Marshal()
if err != nil {
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"certVersion", v,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hh.hostinfo.vpnAddrs,
"handshake", m{"stage": 0, "style": "ix_psk0"},
)
return false
}
// We are sending handshake packet 1, so we don't expect to receive
// handshake packet 1 from the responder
ci.window.Update(f.l, 1)
hh.hostinfo.HandshakePacket[0] = msg
hh.ready = true
return true
}
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState()
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.Error("Unable to handshake with host because no certificate is available",
"from", via,
"handshake", m{"stage": 0, "style": "ix_psk0"},
"certVersion", cs.initiatingVersion,
)
return
}
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.Error("Failed to create connection state",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.Error("Failed unmarshal handshake message",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, fperr := rc.Fingerprint()
if fperr != nil {
fp = "<error generating certificate fingerprint>"
}
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
slog.Any("certVpnNetworks", rc.Networks()),
slog.String("certFingerprint", fp),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
return
}
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
if myCertOtherVersion == nil {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
"error", err,
"from", via,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"cert", remoteCert,
)
}
} else {
// Record the certificate we are actually using
ci.myCert = myCertOtherVersion
}
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"cert", remoteCert,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
vpnNetworks := remoteCert.Certificate.Networks()
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.Error("Refusing to handshake with myself",
"vpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
}
if !via.IsRelayed {
// We only want to apply the remote allow list for direct tunnels here
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", vpnAddrs,
"from", via,
)
}
return
}
}
myIndex, err := generateIndex(f.l)
if err != nil {
f.l.Error("Failed to generate index",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
hostinfo := &HostInfo{
ConnectionState: ci,
localIndexId: myIndex,
remoteIndexId: hs.Details.InitiatorIndex,
vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: hs.Details.Time,
relayState: RelayState{
relays: nil,
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
"myCertVersion", ci.myCert.Version(),
)
return
}
hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
hsBytes, err := hs.Marshal()
if err != nil {
f.l.Error("Failed to marshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
if err != nil {
f.l.Error("Failed to call noise.WriteMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
} else if dKey == nil || eKey == nil {
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
copy(hostinfo.HandshakePacket[0], packet[header.Len:])
// Regardless of whether you are the sender or receiver, you should arrive here
// and complete standing up the connection.
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
copy(hostinfo.HandshakePacket[2], msg)
// We are sending handshake packet 2, so we don't expect to receive
// handshake packet 2 from the initiator.
ci.window.Update(f.l, 2)
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
if err != nil {
switch err {
case ErrAlreadySeen:
// Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, via) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed {
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.Error("Failed to send handshake message",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
"error", err,
)
} else {
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
}
return
} else {
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.Info("Handshake message sent",
"vpnAddrs", existing.vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cached", true,
)
return
}
case ErrExistingHostInfo:
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
f.l.Info("Handshake too old",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"oldHandshakeTime", existing.lastHandshakeTime,
"newHandshakeTime", hostinfo.lastHandshakeTime,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
return
case ErrLocalIndexCollision:
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
f.l.Error("Failed to add HostInfo due to localIndex collision",
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"localIndex", hostinfo.localIndexId,
"collision", existing.vpnAddrs,
)
return
default:
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
// And we forget to update it here
f.l.Error("Failed to add HostInfo to HostMap",
"error", err,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 1, "style": "ix_psk0"},
)
return
}
}
// Do the send
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
if !via.IsRelayed {
err = f.outside.WriteTo(msg, via.UdpAddr)
log := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
if err != nil {
log.Error("Failed to send handshake", "error", err)
} else {
log.Info("Handshake message sent")
}
} else {
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
// it's correctly marked as working.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.Info("Handshake message sent",
"vpnAddrs", vpnAddrs,
"relay", via.relayHI.vpnAddrs[0],
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
}
f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
return
}
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
}
hh.Lock()
defer hh.Unlock()
hostinfo := hh.hostinfo
if !via.IsRelayed {
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
)
}
return false
}
}
ci := hostinfo.ConnectionState
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.Error("Failed to call noise.ReadMessage",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"header", h,
)
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
// near future
return false
} else if dKey == nil || eKey == nil {
f.l.Error("Noise did not arrive at a key",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
// the handshake state machine. Tear it down
return true
}
hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.Error("Failed unmarshal handshake message",
"error", err,
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
return true
}
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
if err != nil {
f.l.Info("Handshake did not contain a certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
if err != nil {
fp, err := rc.Fingerprint()
if err != nil {
fp = "<error generating certificate fingerprint>"
}
attrs := []slog.Attr{
slog.Any("error", err),
slog.Any("from", via),
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
slog.String("certFingerprint", fp),
slog.Any("certVpnNetworks", rc.Networks()),
}
if f.l.Enabled(context.Background(), slog.LevelDebug) {
attrs = append(attrs, slog.Any("cert", rc))
}
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
// callers grow conditionally, which has no pair-form equivalent.
//nolint:sloglint
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
return true
}
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
f.l.Info("public key mismatch between certificate and handshake",
"from", via,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"cert", remoteCert,
)
return true
}
if len(remoteCert.Certificate.Networks()) == 0 {
f.l.Info("No networks in certificate",
"error", err,
"from", via,
"vpnAddrs", hostinfo.vpnAddrs,
"cert", remoteCert,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
return true
}
vpnNetworks := remoteCert.Certificate.Networks()
certName := remoteCert.Certificate.Name()
certVersion := remoteCert.Certificate.Version()
fingerprint := remoteCert.Fingerprint
issuer := remoteCert.Certificate.Issuer()
hostinfo.remoteIndexId = hs.Details.ResponderIndex
hostinfo.lastHandshakeTime = hs.Details.Time
// Store their cert and our symmetric keys
ci.peerCert = remoteCert
ci.dKey = NewNebulaCipherState(dKey)
ci.eKey = NewNebulaCipherState(eKey)
// Make sure the current udpAddr being used is set for responding
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
correctHostResponded := false
anyVpnAddrsInCommon := false
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
for i, network := range vpnNetworks {
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
if hostinfo.vpnAddrs[0] == network.Addr() {
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
correctHostResponded = true
}
}
// Ensure the right host responded
if !correctHostResponded {
f.l.Info("Incorrect host responded to handshake",
"intendedVpnAddrs", hostinfo.vpnAddrs,
"haveVpnNetworks", vpnNetworks,
"from", via,
"certName", certName,
"certVersion", certVersion,
"handshake", m{"stage": 2, "style": "ix_psk0"},
)
// Release our old handshake from pending, it should not continue
f.handshakeManager.DeleteHostInfo(hostinfo)
// Create a new hostinfo/handshake for the intended vpn ip
//TODO is hostinfo.vpnAddrs[0] always the address to use?
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
// Block the current used address
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via)
f.l.Info("Blocked addresses for handshakes",
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
"vpnNetworks", vpnNetworks,
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
)
// Swap the packet store to benefit the original intended recipient
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnAddrs = vpnAddrs
f.sendCloseTunnel(hostinfo)
})
return true
}
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)
duration := time.Since(hh.startTime).Nanoseconds()
msgRxL := f.l.With(
"vpnAddrs", vpnAddrs,
"from", via,
"certName", certName,
"certVersion", certVersion,
"fingerprint", fingerprint,
"issuer", issuer,
"initiatorIndex", hs.Details.InitiatorIndex,
"responderIndex", hs.Details.ResponderIndex,
"remoteIndex", h.RemoteIndex,
"handshake", m{"stage": 2, "style": "ix_psk0"},
"durationNs", duration,
"sentCachedPackets", len(hh.packetStore),
)
if anyVpnAddrsInCommon {
msgRxL.Info("Handshake message received")
} else {
//todo warn if not lighthouse or relay?
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
}
// Build up the radix for the firewall if we have subnets in the cert
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Sending stored packets",
"count", len(hh.packetStore),
)
}
if len(hh.packetStore) > 0 {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
}
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
return false
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/rcrowley/go-metrics" "github.com/rcrowley/go-metrics"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
) )
@@ -23,6 +24,18 @@ const (
DefaultHandshakeRetries = 10 DefaultHandshakeRetries = 10
DefaultHandshakeTriggerBuffer = 64 DefaultHandshakeTriggerBuffer = 64
DefaultUseRelays = true DefaultUseRelays = true
// maxCachedPackets is how many unsent packets we'll buffer per pending
// handshake before dropping further ones.
maxCachedPackets = 100
// HandshakePacket map keys mirror the IX protocol stage convention:
// stage 0 = the initiator's first message (and what the responder
// receives, stripped of header)
// stage 2 = the responder's reply
// Other handshake patterns will need new keys when added.
handshakePacketStage0 uint8 = 0
handshakePacketStage2 uint8 = 2
) )
var ( var (
@@ -76,10 +89,11 @@ type HandshakeHostInfo struct {
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo hostinfo *HostInfo
machine *handshake.Machine // The handshake state machine, set during stage 0 (initiator) or beginHandshake (responder multi-message)
} }
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) { func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
if len(hh.packetStore) < 100 { if len(hh.packetStore) < maxCachedPackets {
tempPacket := make([]byte, len(packet)) tempPacket := make([]byte, len(packet))
copy(tempPacket, packet) copy(tempPacket, packet)
@@ -137,6 +151,18 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
} }
func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) { func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
// Gate on known handshake subtypes. Unknown subtypes (or future ones we
// don't yet support) are dropped here rather than silently routed through
// the IX path. Add a case when introducing a new pattern.
switch h.Subtype {
case header.HandshakeIXPSK0:
// supported
default:
hm.l.Debug("dropping handshake with unsupported subtype",
"from", via, "subtype", h.Subtype)
return
}
// First remote allow list check before we know the vpnIp // First remote allow list check before we know the vpnIp
if !via.IsRelayed { if !via.IsRelayed {
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) { if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
@@ -145,19 +171,27 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
} }
} }
switch h.Subtype { // First message of a new handshake. The wire format requires RemoteIndex
case header.HandshakeIXPSK0: // to be zero here (the initiator has no responder index to fill in yet),
switch h.MessageCounter { // and generateIndex never allocates 0, so any non-zero RemoteIndex on a
case 1: // stage-1 packet is malformed or someone probing for an index collision.
ixHandshakeStage1(hm.f, via, packet, h) // Drop without paying the cost of running noise on a pending Machine.
if h.MessageCounter == 1 {
if h.RemoteIndex != 0 {
hm.l.Debug("dropping stage-1 handshake with non-zero RemoteIndex",
"from", via, "remoteIndex", h.RemoteIndex)
return
}
hm.beginHandshake(via, packet, h)
return
}
case 2: // Continuation message must match a pending handshake by index.
newHostinfo := hm.queryIndex(h.RemoteIndex) // Anything else is an orphaned packet (e.g., late retransmit after
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h) // timeout) and is dropped.
if tearDown && newHostinfo != nil { if hh := hm.queryIndex(h.RemoteIndex); hh != nil {
hm.DeleteHostInfo(newHostinfo.hostinfo) hm.continueHandshake(via, hh, packet)
} return
}
} }
} }
@@ -183,13 +217,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo := hh.hostinfo hostinfo := hh.hostinfo
// If we are out of time, clean up // If we are out of time, clean up
if hh.counter >= hm.config.retries { if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields := []any{
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
"initiatorIndex", hh.hostinfo.localIndexId, "initiatorIndex", hh.hostinfo.localIndexId,
"remoteIndex", hh.hostinfo.remoteIndexId, "remoteIndex", hh.hostinfo.remoteIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"},
"durationNs", time.Since(hh.startTime).Nanoseconds(), "durationNs", time.Since(hh.startTime).Nanoseconds(),
) }
// hh.machine can be nil here if buildStage0Packet never succeeded
// (e.g., no certificate available). In that case there's no useful
// handshake metadata to log.
if hh.machine != nil {
fields = append(fields, "handshake", m{
"stage": uint64(hh.machine.MessageIndex()),
"style": header.SubTypeName(header.Handshake, hh.machine.Subtype()),
})
}
hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields...)
hm.metricTimedOut.Inc(1) hm.metricTimedOut.Inc(1)
hm.DeleteHostInfo(hostinfo) hm.DeleteHostInfo(hostinfo)
return return
@@ -200,12 +243,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// Check if we have a handshake packet to transmit yet // Check if we have a handshake packet to transmit yet
if !hh.ready { if !hh.ready {
if !ixHandshakeStage0(hm.f, hh) { if !hm.buildStage0Packet(hh) {
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter)) hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
return return
} }
} }
// TODO: this hardcodes "always retransmit stage 0", which is correct for
// IX (the initiator only ever sends one packet, msg1) but wrong the
// moment a 3+ message pattern lands. The retry loop should resend the
// most recent outgoing message, not always stage 0. That implies
// HandshakeHostInfo tracking a single "currentOutbound" packet (bytes +
// header metadata) that gets replaced as the handshake progresses,
// instead of indexing into HandshakePacket.
stage0 := hostinfo.HandshakePacket[handshakePacketStage0]
hsFields := m{
"stage": uint64(hh.machine.MessageIndex()),
"style": header.SubTypeName(header.Handshake, hh.machine.Subtype()),
}
// Get a remotes object if we don't already have one. // Get a remotes object if we don't already have one.
// This is mainly to protect us as this should never be the case // This is mainly to protect us as this should never be the case
// NB ^ This comment doesn't jive. It's how the thing gets initialized. // NB ^ This comment doesn't jive. It's how the thing gets initialized.
@@ -239,13 +295,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []netip.AddrPort var sentTo []netip.AddrPort
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) { hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1) hm.messageMetrics.Tx(header.Handshake, hh.machine.Subtype(), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr) err := hm.outside.WriteTo(stage0, addr)
if err != nil { if err != nil {
hostinfo.logger(hm.l).Error("Failed to send handshake message", hostinfo.logger(hm.l).Error("Failed to send handshake message",
"udpAddr", addr, "udpAddr", addr,
"initiatorIndex", hostinfo.localIndexId, "initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"}, "handshake", hsFields,
"error", err, "error", err,
) )
@@ -260,13 +316,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
hostinfo.logger(hm.l).Info("Handshake message sent", hostinfo.logger(hm.l).Info("Handshake message sent",
"udpAddrs", sentTo, "udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId, "initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"}, "handshake", hsFields,
) )
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) { } else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(hm.l).Debug("Handshake message sent", hostinfo.logger(hm.l).Debug("Handshake message sent",
"udpAddrs", sentTo, "udpAddrs", sentTo,
"initiatorIndex", hostinfo.localIndexId, "initiatorIndex", hostinfo.localIndexId,
"handshake", m{"stage": 1, "style": "ix_psk0"}, "handshake", hsFields,
) )
} }
@@ -348,7 +404,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
switch existingRelay.State { switch existingRelay.State {
case Established: case Established:
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false)
case Disestablished: case Disestablished:
// Mark this relay as 'requested' // Mark this relay as 'requested'
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
@@ -587,7 +643,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
// allocateIndex generates a unique localIndexId for this HostInfo // allocateIndex generates a unique localIndexId for this HostInfo
// and adds it to the pendingHostMap. Will error if we are unable to generate // and adds it to the pendingHostMap. Will error if we are unable to generate
// a unique localIndexId // a unique localIndexId
func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) (uint32, error) {
hm.mainHostMap.RLock() hm.mainHostMap.RLock()
defer hm.mainHostMap.RUnlock() defer hm.mainHostMap.RUnlock()
hm.Lock() hm.Lock()
@@ -596,7 +652,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
for range 32 { for range 32 {
index, err := generateIndex(hm.l) index, err := generateIndex(hm.l)
if err != nil { if err != nil {
return err return 0, err
} }
_, inPending := hm.indexes[index] _, inPending := hm.indexes[index]
@@ -605,11 +661,11 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
if !inMain && !inPending { if !inMain && !inPending {
hh.hostinfo.localIndexId = index hh.hostinfo.localIndexId = index
hm.indexes[index] = hh hm.indexes[index] = hh
return nil return index, nil
} }
} }
return errors.New("failed to generate unique localIndexId") return 0, errors.New("failed to generate unique localIndexId")
} }
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
@@ -728,3 +784,524 @@ func generateIndex(l *slog.Logger) (uint32, error) {
func hsTimeout(tries int64, interval time.Duration) time.Duration { func hsTimeout(tries int64, interval time.Duration) time.Duration {
return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval))) return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
} }
// buildStage0Packet creates the initial handshake packet for the initiator.
func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool {
cs := hm.f.pki.getCertState()
v := cs.DefaultVersion()
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
} else if v < cert.Version2 {
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}
}
cred := cs.GetCredential(v)
if cred == nil {
hm.f.l.Error("Unable to handshake with host because no certificate is available",
"vpnAddrs", hh.hostinfo.vpnAddrs, "certVersion", v)
return false
}
machine, err := handshake.NewMachine(
v, cs.GetCredential,
hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) },
true, header.HandshakeIXPSK0,
)
if err != nil {
hm.f.l.Error("Failed to create handshake machine",
"vpnAddrs", hh.hostinfo.vpnAddrs, "error", err)
return false
}
msg, err := machine.Initiate(nil)
if err != nil {
hm.f.l.Error("Failed to initiate handshake",
"vpnAddrs", hh.hostinfo.vpnAddrs, "error", err)
return false
}
// hostinfo.ConnectionState stays nil until the handshake completes in
// continueHandshake. Pre-completion control surfaces guard with nil
// checks; the data plane never observes a pending hostinfo.
hh.hostinfo.HandshakePacket[handshakePacketStage0] = msg
hh.machine = machine
hh.ready = true
return true
}
// beginHandshake handles an incoming handshake packet that doesn't match any
// existing pending handshake. It creates a new responder Machine and processes
// the first message.
func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *header.H) {
f := hm.f
cs := f.pki.getCertState()
v := cs.DefaultVersion()
if cs.GetCredential(v) == nil {
f.l.Error("Unable to handshake with host because no certificate is available",
"from", via, "certVersion", v)
return
}
machine, err := handshake.NewMachine(
v, cs.GetCredential,
hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) },
false, header.HandshakeIXPSK0,
)
if err != nil {
f.l.Error("Failed to create handshake machine", "from", via, "error", err)
return
}
response, result, err := machine.ProcessPacket(nil, packet)
if err != nil {
f.l.Error("Failed to process handshake packet", "from", via, "error", err)
return
}
if result == nil {
// Multi-message pattern: the responder Machine would need to be
// registered in hm.indexes so a future inbound packet finds it via
// continueHandshake. The current manager doesn't do that yet, so
// fail loudly rather than silently dropping the in-flight handshake.
// TODO: support multi-message responder flows (XX, pqIX, etc.).
// See also the IX-shaped cipher key assignment in handshake.Machine.
f.l.Error("multi-message handshake responder is not supported",
"from", via, "error", handshake.ErrMultiMessageUnsupported)
return
}
remoteCert := result.RemoteCert
if remoteCert == nil {
f.l.Error("Handshake did not produce a peer certificate", "from", via)
return
}
// Validate peer identity
vpnAddrs, anyVpnAddrsInCommon, ok := hm.validatePeerCert(via, remoteCert)
if !ok {
return
}
hostinfo := &HostInfo{
ConnectionState: newConnectionStateFromResult(result),
localIndexId: result.LocalIndex,
remoteIndexId: result.RemoteIndex,
vpnAddrs: vpnAddrs,
HandshakePacket: make(map[uint8][]byte, 0),
lastHandshakeTime: result.HandshakeTime,
relayState: RelayState{
relays: nil,
relayForByAddr: map[netip.Addr]*Relay{},
relayForByIdx: map[uint32]*Relay{},
},
}
msg := "Handshake message received"
if !anyVpnAddrsInCommon {
msg = "Handshake message received, but no vpnNetworks in common."
}
f.l.Info(msg,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", remoteCert.Certificate.Name(),
"certVersion", remoteCert.Certificate.Version(),
"fingerprint", remoteCert.Fingerprint,
"issuer", remoteCert.Certificate.Issuer(),
"initiatorIndex", result.RemoteIndex,
"responderIndex", result.LocalIndex,
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
)
// packet aliases the listener's incoming buffer, so this copy must stay.
hostinfo.HandshakePacket[handshakePacketStage0] = make([]byte, len(packet[header.Len:]))
copy(hostinfo.HandshakePacket[handshakePacketStage0], packet[header.Len:])
// response was freshly allocated by ProcessPacket; safe to retain directly.
if response != nil {
hostinfo.HandshakePacket[handshakePacketStage2] = response
}
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
}
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
existing, err := hm.CheckAndComplete(hostinfo, handshakePacketStage0, f)
if err != nil {
hm.handleCheckAndCompleteError(err, existing, hostinfo, via)
return
}
hm.sendHandshakeResponse(via, response, hostinfo, false)
f.connectionManager.AddTrafficWatch(hostinfo)
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
}
// continueHandshake feeds an incoming packet to an existing pending handshake Machine.
func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostInfo, packet []byte) {
f := hm.f
hh.Lock()
defer hh.Unlock()
// Re-verify hh is still tracked. Between queryIndex returning and us taking
// hh.Lock, handleOutbound may have timed out and deleted it. Once we hold
// hh.Lock no other deleter can race our index: handleOutbound also takes
// hh.Lock first, and handleRecvError targets a main-hostmap entry with a
// different localIndexId.
hm.RLock()
cur, ok := hm.indexes[hh.hostinfo.localIndexId]
hm.RUnlock()
if !ok || cur != hh {
return
}
hostinfo := hh.hostinfo
if !via.IsRelayed {
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
return
}
}
machine := hh.machine
if machine == nil {
f.l.Error("No handshake machine available for continuation",
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
hm.DeleteHostInfo(hostinfo)
return
}
response, result, err := machine.ProcessPacket(nil, packet)
if err != nil {
// Recoverable errors are routine noise, log at Debug. Fatal errors get a Warn.
if machine.Failed() {
f.l.Warn("Failed to process handshake packet, abandoning",
"vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err)
hm.DeleteHostInfo(hostinfo)
} else {
f.l.Debug("Failed to process handshake packet",
"vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err)
}
return
}
if response != nil {
hm.sendHandshakeResponse(via, response, hostinfo, false)
}
if result == nil {
return
}
// Handshake complete; build the ConnectionState now that we have keys and a verified peer cert.
hostinfo.ConnectionState = newConnectionStateFromResult(result)
remoteCert := result.RemoteCert
if remoteCert == nil {
f.l.Error("Handshake completed without peer certificate",
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
hm.DeleteHostInfo(hostinfo)
return
}
vpnNetworks := remoteCert.Certificate.Networks()
hostinfo.remoteIndexId = result.RemoteIndex
hostinfo.lastHandshakeTime = result.HandshakeTime
if !via.IsRelayed {
hostinfo.SetRemote(via.UdpAddr)
} else {
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
}
// Verify correct host responded (initiator check)
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
correctHostResponded := false
anyVpnAddrsInCommon := false
for i, network := range vpnNetworks {
// inside.go drops self-routed packets at the firewall stage, but we'd
// rather not let a self-handshake complete in the first place: it
// wastes a hostmap slot, suppresses no log, and obscures routing
// misconfig. Explicit refusal here mirrors the responder-side check
// in validatePeerCert.
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.Error("Refusing to handshake with myself",
"vpnNetworks", vpnNetworks,
"from", via,
"certName", remoteCert.Certificate.Name(),
"certVersion", remoteCert.Certificate.Version(),
"fingerprint", remoteCert.Fingerprint,
"issuer", remoteCert.Certificate.Issuer(),
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
)
hm.DeleteHostInfo(hostinfo)
return
}
vpnAddrs[i] = network.Addr()
if hostinfo.vpnAddrs[0] == network.Addr() {
correctHostResponded = true
}
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
}
if !correctHostResponded {
f.l.Info("Incorrect host responded to handshake",
"intendedVpnAddrs", hostinfo.vpnAddrs,
"haveVpnNetworks", vpnNetworks,
"from", via,
"certName", remoteCert.Certificate.Name(),
"certVersion", remoteCert.Certificate.Version(),
"fingerprint", remoteCert.Fingerprint,
"issuer", remoteCert.Certificate.Issuer(),
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
)
hm.DeleteHostInfo(hostinfo)
hm.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(via)
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}
hostinfo.vpnAddrs = vpnAddrs
f.sendCloseTunnel(hostinfo)
})
return
}
duration := time.Since(hh.startTime).Nanoseconds()
msg := "Handshake message received"
if !anyVpnAddrsInCommon {
msg = "Handshake message received, but no vpnNetworks in common."
}
f.l.Info(msg,
"vpnAddrs", vpnAddrs,
"from", via,
"certName", remoteCert.Certificate.Name(),
"certVersion", remoteCert.Certificate.Version(),
"fingerprint", remoteCert.Fingerprint,
"issuer", remoteCert.Certificate.Issuer(),
"initiatorIndex", result.LocalIndex,
"responderIndex", result.RemoteIndex,
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
"durationNs", duration,
"sentCachedPackets", len(hh.packetStore),
)
hostinfo.vpnAddrs = vpnAddrs
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
hm.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo)
if len(hh.packetStore) > 0 {
if f.l.Enabled(context.Background(), slog.LevelDebug) {
hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore))
}
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
}
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
f.metricHandshakes.Update(duration)
// Don't wait for UpdateWorker
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
f.lightHouse.TriggerUpdate()
}
}
// validatePeerCert checks the peer certificate for self-connection and remote allow list.
// Returns the VPN addrs, whether any of them fall within one of our own VPN
// networks, and true if valid; false if rejected.
func (hm *HandshakeManager) validatePeerCert(via ViaSender, remoteCert *cert.CachedCertificate) ([]netip.Addr, bool, bool) {
f := hm.f
vpnNetworks := remoteCert.Certificate.Networks()
// The cert package rejects host certs with no networks at parse time, so
// reaching this state would mean an invariant was bypassed elsewhere.
// Refuse explicitly so downstream code (which indexes vpnAddrs[0]) can't
// panic if that invariant ever changes.
if len(vpnNetworks) == 0 {
f.l.Info("No networks in certificate",
"from", via, "cert", remoteCert)
return nil, false, false
}
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
anyVpnAddrsInCommon := false
for i, network := range vpnNetworks {
if f.myVpnAddrsTable.Contains(network.Addr()) {
f.l.Error("Refusing to handshake with myself",
"vpnNetworks", vpnNetworks,
"from", via,
"certName", remoteCert.Certificate.Name(),
"certVersion", remoteCert.Certificate.Version(),
"fingerprint", remoteCert.Fingerprint,
"issuer", remoteCert.Certificate.Issuer(),
)
return nil, false, false
}
vpnAddrs[i] = network.Addr()
if f.myVpnNetworksTable.Contains(network.Addr()) {
anyVpnAddrsInCommon = true
}
}
if !via.IsRelayed {
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
"vpnAddrs", vpnAddrs, "from", via)
return nil, false, false
}
}
return vpnAddrs, anyVpnAddrsInCommon, true
}
// sendHandshakeResponse sends a handshake response via the appropriate transport.
// cached is true when msg is a stored response being retransmitted because
// the peer's stage-1 retransmit landed (the ErrAlreadySeen path); false on a
// fresh response.
func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hostinfo *HostInfo, cached bool) {
if msg == nil {
return
}
f := hm.f
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
// Common log fields. peerCert may be nil during intermediate
// multi-message flows (handshake hasn't completed yet); skip the cert
// block if so.
logFields := []any{
"vpnAddrs", hostinfo.vpnAddrs,
"handshake", m{"stage": uint64(2), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)},
"cached", cached,
"initiatorIndex", hostinfo.remoteIndexId,
"responderIndex", hostinfo.localIndexId,
}
if peerCert := hostinfo.ConnectionState.peerCert; peerCert != nil {
logFields = append(logFields,
"certName", peerCert.Certificate.Name(),
"certVersion", peerCert.Certificate.Version(),
"fingerprint", peerCert.Fingerprint,
"issuer", peerCert.Certificate.Issuer(),
)
}
if !via.IsRelayed {
fields := append(logFields, "from", via)
err := f.outside.WriteTo(msg, via.UdpAddr)
if err != nil {
f.l.Error("Failed to send handshake message", append(fields, "error", err)...)
} else {
f.l.Info("Handshake message sent", fields...)
}
} else {
if via.relay == nil {
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
return
}
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
// We received a valid handshake on this relay, so make sure the relay
// state reflects that, in case it had been marked Disestablished.
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...)
}
}
// handleCheckAndCompleteError handles errors from CheckAndComplete.
// This only fires from the responder-side beginHandshake path, after the
// peer cert has been validated and ConnectionState populated, so peerCert
// is always non-nil for the cases that log it.
func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hostinfo *HostInfo, via ViaSender) {
f := hm.f
peerCert := hostinfo.ConnectionState.peerCert
hsFields := m{"stage": uint64(1), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)}
switch err {
case ErrAlreadySeen:
if existing.SetRemoteIfPreferred(f.hostMap, via) {
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
// Resend the original response. The peer is committed to that response's
// ephemeral keys; a freshly-built one would have different keys and break
// the tunnel even though both sides "completed" the handshake.
if msg := existing.HandshakePacket[handshakePacketStage2]; msg != nil {
hm.sendHandshakeResponse(via, msg, existing, true)
}
case ErrExistingHostInfo:
f.l.Info("Handshake too old",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", peerCert.Certificate.Name(),
"certVersion", peerCert.Certificate.Version(),
"fingerprint", peerCert.Fingerprint,
"issuer", peerCert.Certificate.Issuer(),
"oldHandshakeTime", existing.lastHandshakeTime,
"newHandshakeTime", hostinfo.lastHandshakeTime,
"initiatorIndex", hostinfo.remoteIndexId,
"responderIndex", hostinfo.localIndexId,
"handshake", hsFields,
)
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
case ErrLocalIndexCollision:
f.l.Error("Failed to add HostInfo due to localIndex collision",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"certName", peerCert.Certificate.Name(),
"certVersion", peerCert.Certificate.Version(),
"fingerprint", peerCert.Fingerprint,
"issuer", peerCert.Certificate.Issuer(),
"localIndex", hostinfo.localIndexId,
"initiatorIndex", hostinfo.remoteIndexId,
"responderIndex", hostinfo.localIndexId,
"handshake", hsFields,
)
default:
f.l.Error("Failed to add HostInfo to HostMap",
"vpnAddrs", hostinfo.vpnAddrs,
"from", via,
"error", err,
"certName", peerCert.Certificate.Name(),
"certVersion", peerCert.Certificate.Version(),
"fingerprint", peerCert.Fingerprint,
"issuer", peerCert.Certificate.Issuer(),
"initiatorIndex", hostinfo.remoteIndexId,
"responderIndex", hostinfo.localIndexId,
"handshake", hsFields,
)
}
}
// certVerifier returns a CertVerifier that validates certs against the current CA pool.
func (hm *HandshakeManager) certVerifier() handshake.CertVerifier {
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
return hm.f.pki.GetCAPool().VerifyCertificate(time.Now(), c)
}
}

View File

@@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
@@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
initiatingVersion: cert.Version1, initiatingVersion: cert.Version1,
privateKey: []byte{}, privateKey: []byte{},
v1Cert: &dummyCert{version: cert.Version1}, v1Cert: &dummyCert{version: cert.Version1},
v1HandshakeBytes: []byte{}, v1Credential: nil,
} }
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
func (mw *mockEncWriter) GetCertState() *CertState { func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{initiatingVersion: cert.Version2} return &CertState{initiatingVersion: cert.Version2}
} }
func TestValidatePeerCert(t *testing.T) {
l := test.NewLogger()
myNetwork := netip.MustParsePrefix("10.0.0.1/24")
myAddrTable := new(bart.Lite)
myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen()))
myNetTable := new(bart.Lite)
myNetTable.Insert(myNetwork.Masked())
newHM := func() *HandshakeManager {
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
hm.f = &Interface{
handshakeManager: hm,
pki: &PKI{},
l: l,
myVpnAddrsTable: myAddrTable,
myVpnNetworksTable: myNetTable,
lightHouse: hm.lightHouse,
}
return hm
}
cached := func(networks ...netip.Prefix) *cert.CachedCertificate {
return &cert.CachedCertificate{
Certificate: &dummyCert{name: "peer", networks: networks},
}
}
via := ViaSender{
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
IsRelayed: true, // skip the remote allow list (covered separately)
}
t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) {
hm := newHM()
// 10.0.0.2 falls inside our 10.0.0.0/24
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24")))
assert.True(t, ok)
assert.True(t, common)
assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs)
})
t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) {
hm := newHM()
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24")))
assert.True(t, ok)
assert.False(t, common)
assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs)
})
t.Run("any matching network is enough", func(t *testing.T) {
hm := newHM()
addrs, common, ok := hm.validatePeerCert(via, cached(
netip.MustParsePrefix("192.168.1.5/24"),
netip.MustParsePrefix("10.0.0.42/24"),
))
assert.True(t, ok)
assert.True(t, common)
assert.Len(t, addrs, 2)
})
t.Run("self-handshake is rejected", func(t *testing.T) {
hm := newHM()
// 10.0.0.1 is in myVpnAddrsTable
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24")))
assert.False(t, ok)
assert.False(t, common)
assert.Nil(t, addrs)
})
t.Run("cert with no networks is rejected", func(t *testing.T) {
hm := newHM()
addrs, common, ok := hm.validatePeerCert(via, cached())
assert.False(t, ok)
assert.False(t, common)
assert.Nil(t, addrs)
})
}
func TestHandleIncomingDispatch(t *testing.T) {
l := test.NewLogger()
newHM := func() *HandshakeManager {
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
hm.f = &Interface{
handshakeManager: hm,
pki: &PKI{},
l: l,
}
return hm
}
via := ViaSender{
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
IsRelayed: true, // bypass remote allow list
}
// A packet body of zero length is fine for these tests: dispatch is
// gated on header fields, and we assert that we never reach noise/cert
// processing for any of the malformed shapes here.
pkt := make([]byte, header.Len)
t.Run("unsupported subtype dropped", func(t *testing.T) {
hm := newHM()
h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1}
hm.HandleIncoming(via, pkt, h)
assert.Empty(t, hm.indexes, "no pending handshake should be created")
})
t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) {
hm := newHM()
h := &header.H{
Type: header.Handshake,
Subtype: header.HandshakeIXPSK0,
RemoteIndex: 0xdeadbeef,
MessageCounter: 1,
}
hm.HandleIncoming(via, pkt, h)
assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine")
})
t.Run("continuation with no matching pending index dropped", func(t *testing.T) {
hm := newHM()
h := &header.H{
Type: header.Handshake,
Subtype: header.HandshakeIXPSK0,
RemoteIndex: 0xcafef00d,
MessageCounter: 2,
}
hm.HandleIncoming(via, pkt, h)
assert.Empty(t, hm.indexes, "orphan stage-2 must not create state")
})
}

View File

@@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string {
} }
func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{8, 0} return fileDescriptor_2d65afa7693df5ef, []int{6, 0}
} }
type NebulaMeta struct { type NebulaMeta struct {
@@ -489,142 +489,6 @@ func (m *NebulaPing) GetTime() uint64 {
return 0 return 0
} }
type NebulaHandshake struct {
Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"`
Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"`
}
func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} }
func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) }
func (*NebulaHandshake) ProtoMessage() {}
func (*NebulaHandshake) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{6}
}
func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *NebulaHandshake) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaHandshake.Merge(m, src)
}
func (m *NebulaHandshake) XXX_Size() int {
return m.Size()
}
func (m *NebulaHandshake) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaHandshake.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo
func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails {
if m != nil {
return m.Details
}
return nil
}
func (m *NebulaHandshake) GetHmac() []byte {
if m != nil {
return m.Hmac
}
return nil
}
type NebulaHandshakeDetails struct {
Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"`
InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"`
ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"`
Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"`
Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"`
CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"`
}
func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} }
func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) }
func (*NebulaHandshakeDetails) ProtoMessage() {}
func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{7}
}
func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) {
xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src)
}
func (m *NebulaHandshakeDetails) XXX_Size() int {
return m.Size()
}
func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() {
xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m)
}
var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo
func (m *NebulaHandshakeDetails) GetCert() []byte {
if m != nil {
return m.Cert
}
return nil
}
func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 {
if m != nil {
return m.InitiatorIndex
}
return 0
}
func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 {
if m != nil {
return m.ResponderIndex
}
return 0
}
func (m *NebulaHandshakeDetails) GetCookie() uint64 {
if m != nil {
return m.Cookie
}
return 0
}
func (m *NebulaHandshakeDetails) GetTime() uint64 {
if m != nil {
return m.Time
}
return 0
}
func (m *NebulaHandshakeDetails) GetCertVersion() uint32 {
if m != nil {
return m.CertVersion
}
return 0
}
type NebulaControl struct { type NebulaControl struct {
Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"`
InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"`
@@ -639,7 +503,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} }
func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (m *NebulaControl) String() string { return proto.CompactTextString(m) }
func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) ProtoMessage() {}
func (*NebulaControl) Descriptor() ([]byte, []int) { func (*NebulaControl) Descriptor() ([]byte, []int) {
return fileDescriptor_2d65afa7693df5ef, []int{8} return fileDescriptor_2d65afa7693df5ef, []int{6}
} }
func (m *NebulaControl) XXX_Unmarshal(b []byte) error { func (m *NebulaControl) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b) return m.Unmarshal(b)
@@ -729,65 +593,55 @@ func init() {
proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort")
proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort")
proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing")
proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake")
proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails")
proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl") proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl")
} }
func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) }
var fileDescriptor_2d65afa7693df5ef = []byte{ var fileDescriptor_2d65afa7693df5ef = []byte{
// 785 bytes of a gzipped FileDescriptorProto // 665 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x54, 0xcd, 0x6e, 0xd3, 0x5c,
0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, 0x10, 0x8d, 0x1d, 0x27, 0x69, 0x27, 0x4d, 0x3e, 0x7f, 0x53, 0x51, 0x12, 0x24, 0xac, 0xe0, 0x45,
0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, 0x55, 0xb1, 0x48, 0x51, 0x5a, 0xba, 0xa6, 0x2d, 0x42, 0xa9, 0xd4, 0x9f, 0x70, 0x55, 0x8a, 0xc4,
0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, 0xce, 0xb5, 0x2f, 0x8d, 0x55, 0xc7, 0x37, 0xb5, 0x6f, 0x50, 0xf3, 0x16, 0x3c, 0x0c, 0x0f, 0x01,
0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, 0xbb, 0x2e, 0x59, 0xa2, 0x66, 0xc9, 0x92, 0x17, 0x40, 0xf7, 0xfa, 0xbf, 0x31, 0xb0, 0xbb, 0x33,
0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, 0xe7, 0x9c, 0x99, 0xc9, 0xc9, 0x8c, 0x61, 0xcd, 0xa7, 0x97, 0x33, 0xcf, 0xea, 0x4f, 0x03, 0xc6,
0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, 0x19, 0xd6, 0xa3, 0xc8, 0xfc, 0xa9, 0x02, 0x9c, 0xca, 0xe7, 0x09, 0xe5, 0x16, 0x0e, 0x40, 0x3b,
0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, 0x9f, 0x4f, 0x69, 0x47, 0xe9, 0x29, 0x5b, 0xed, 0x81, 0xd1, 0x8f, 0x35, 0x19, 0xa3, 0x7f, 0x42,
0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, 0xc3, 0xd0, 0xba, 0xa2, 0x82, 0x45, 0x24, 0x17, 0x77, 0xa0, 0xf1, 0x9a, 0x72, 0xcb, 0xf5, 0xc2,
0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, 0x8e, 0xda, 0x53, 0xb6, 0x9a, 0x83, 0xee, 0xb2, 0x2c, 0x26, 0x90, 0x84, 0x69, 0xfe, 0x52, 0xa0,
0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, 0x99, 0x2b, 0x85, 0x2b, 0xa0, 0x9d, 0x32, 0x9f, 0xea, 0x15, 0x6c, 0xc1, 0xea, 0x90, 0x85, 0xfc,
0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, 0xed, 0x8c, 0x06, 0x73, 0x5d, 0x41, 0x84, 0x76, 0x1a, 0x12, 0x3a, 0xf5, 0xe6, 0xba, 0x8a, 0x4f,
0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, 0x60, 0x43, 0xe4, 0xde, 0x4d, 0x1d, 0x8b, 0xd3, 0x53, 0xc6, 0xdd, 0x8f, 0xae, 0x6d, 0x71, 0x97,
0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, 0xf9, 0x7a, 0x15, 0xbb, 0xf0, 0x48, 0x60, 0x27, 0xec, 0x13, 0x75, 0x0a, 0x90, 0x96, 0x40, 0xa3,
0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, 0x99, 0x6f, 0x8f, 0x0b, 0x50, 0x0d, 0xdb, 0x00, 0x02, 0x7a, 0x3f, 0x66, 0xd6, 0xc4, 0xd5, 0xeb,
0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, 0xb8, 0x0e, 0xff, 0x65, 0x71, 0xd4, 0xb6, 0x21, 0x26, 0x1b, 0x59, 0x7c, 0x7c, 0x38, 0xa6, 0xf6,
0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, 0xb5, 0xbe, 0x22, 0x26, 0x4b, 0xc3, 0x88, 0xb2, 0x8a, 0x4f, 0xa1, 0x5b, 0x3e, 0xd9, 0xbe, 0x7d,
0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, 0xad, 0x83, 0xf9, 0x4d, 0x85, 0xff, 0x97, 0x4c, 0x41, 0x13, 0xe0, 0xcc, 0x73, 0x2e, 0xa6, 0xfe,
0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, 0xbe, 0xe3, 0x04, 0xd2, 0xfa, 0xd6, 0x81, 0xda, 0x51, 0x48, 0x2e, 0x8b, 0x9b, 0xd0, 0x48, 0x08,
0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, 0x75, 0x69, 0xf2, 0x5a, 0x62, 0xb2, 0xc8, 0x91, 0x04, 0xc4, 0x3e, 0xe8, 0x67, 0x9e, 0x43, 0xa8,
0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, 0x67, 0xcd, 0xe3, 0x54, 0xd8, 0xa9, 0xf5, 0xaa, 0x71, 0xc5, 0x25, 0x0c, 0x07, 0xd0, 0x2a, 0x92,
0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, 0x1b, 0xbd, 0xea, 0x52, 0xf5, 0x22, 0x05, 0x77, 0xa1, 0x79, 0xb1, 0x2b, 0x9e, 0x23, 0x16, 0x70,
0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, 0xf1, 0xa7, 0x0b, 0x05, 0x26, 0x8a, 0x0c, 0x22, 0x79, 0x9a, 0x54, 0xed, 0x65, 0x2a, 0xed, 0x81,
0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, 0x6a, 0x2f, 0xa7, 0xca, 0x68, 0xd8, 0x81, 0x86, 0xcd, 0x66, 0x3e, 0xa7, 0x41, 0xa7, 0x2a, 0x8c,
0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, 0x21, 0x49, 0x68, 0x6e, 0x82, 0x26, 0x7f, 0x71, 0x1b, 0xd4, 0xa1, 0x2b, 0x5d, 0xd3, 0x88, 0x3a,
0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, 0x74, 0x45, 0x7c, 0xcc, 0xe4, 0x26, 0x6a, 0x44, 0x3d, 0x66, 0xe6, 0x2e, 0x40, 0x36, 0x06, 0x62,
0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, 0xa4, 0x8a, 0x5c, 0x26, 0x51, 0x05, 0x04, 0x4d, 0x60, 0x52, 0xd3, 0x22, 0xf2, 0x6d, 0xbe, 0x02,
0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, 0xc8, 0xc6, 0xf8, 0x57, 0x8f, 0xb4, 0x42, 0x35, 0x57, 0xe1, 0x36, 0x39, 0xac, 0x91, 0xeb, 0x5f,
0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, 0xfd, 0xfd, 0xb0, 0x04, 0xa3, 0xe4, 0xb0, 0x10, 0xb4, 0x73, 0x77, 0x42, 0xe3, 0x3e, 0xf2, 0x6d,
0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, 0x9a, 0x4b, 0x67, 0x23, 0xc4, 0x7a, 0x05, 0x57, 0xa1, 0x16, 0x2d, 0xa1, 0x62, 0x7e, 0xa9, 0x42,
0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, 0x2b, 0x2a, 0x7c, 0xc8, 0x7c, 0x1e, 0x30, 0x0f, 0x5f, 0x16, 0xba, 0x3f, 0x2b, 0x76, 0x8f, 0x49,
0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, 0x25, 0x03, 0xbc, 0x80, 0xf5, 0x23, 0xdf, 0xe5, 0xae, 0xc5, 0x59, 0x20, 0x57, 0xe0, 0xc8, 0x77,
0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, 0xe8, 0x6d, 0xec, 0x53, 0x19, 0x24, 0x14, 0x84, 0x86, 0x53, 0xe6, 0x3b, 0x34, 0xaf, 0x88, 0x7c,
0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, 0x29, 0x83, 0xf0, 0x39, 0xb4, 0x93, 0xa5, 0x3c, 0x67, 0xf2, 0xaf, 0xd1, 0xd2, 0x03, 0x78, 0x80,
0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, 0xe4, 0x97, 0xfb, 0x4d, 0xc0, 0x26, 0x92, 0x5d, 0x4b, 0xd9, 0x4b, 0x18, 0xf6, 0xa1, 0x99, 0x2f,
0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, 0x5c, 0x76, 0x38, 0x79, 0x42, 0x7a, 0x0c, 0x69, 0xf1, 0x46, 0x89, 0xa2, 0x48, 0x31, 0x87, 0x7f,
0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, 0xfa, 0x8e, 0x6d, 0x00, 0x1e, 0x06, 0xd4, 0xe2, 0x54, 0xf2, 0x09, 0xbd, 0x99, 0xd1, 0x90, 0xeb,
0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, 0x0a, 0x3e, 0x86, 0xf5, 0x42, 0x5e, 0x58, 0x12, 0x52, 0x5d, 0x3d, 0xd8, 0xf9, 0x7a, 0x6f, 0x28,
0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, 0x77, 0xf7, 0x86, 0xf2, 0xe3, 0xde, 0x50, 0x3e, 0x2f, 0x8c, 0xca, 0xdd, 0xc2, 0xa8, 0x7c, 0x5f,
0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, 0x18, 0x95, 0x0f, 0xdd, 0x2b, 0x97, 0x8f, 0x67, 0x97, 0x7d, 0x9b, 0x4d, 0xb6, 0x43, 0xcf, 0xb2,
0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, 0xaf, 0xc7, 0x37, 0xdb, 0xd1, 0x48, 0x97, 0x75, 0xf9, 0x39, 0xdf, 0xf9, 0x1d, 0x00, 0x00, 0xff,
0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, 0xff, 0x51, 0x0a, 0xe3, 0xd7, 0xde, 0x05, 0x00, 0x00,
0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf,
0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f,
0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8,
0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65,
0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9,
0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94,
0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00,
0x00,
} }
func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { func (m *NebulaMeta) Marshal() (dAtA []byte, err error) {
@@ -1072,103 +926,6 @@ func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) {
return len(dAtA) - i, nil return len(dAtA) - i, nil
} }
func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Hmac) > 0 {
i -= len(m.Hmac)
copy(dAtA[i:], m.Hmac)
i = encodeVarintNebula(dAtA, i, uint64(len(m.Hmac)))
i--
dAtA[i] = 0x12
}
if m.Details != nil {
{
size, err := m.Details.MarshalToSizedBuffer(dAtA[:i])
if err != nil {
return 0, err
}
i -= size
i = encodeVarintNebula(dAtA, i, uint64(size))
}
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.CertVersion != 0 {
i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion))
i--
dAtA[i] = 0x40
}
if m.Time != 0 {
i = encodeVarintNebula(dAtA, i, uint64(m.Time))
i--
dAtA[i] = 0x28
}
if m.Cookie != 0 {
i = encodeVarintNebula(dAtA, i, uint64(m.Cookie))
i--
dAtA[i] = 0x20
}
if m.ResponderIndex != 0 {
i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex))
i--
dAtA[i] = 0x18
}
if m.InitiatorIndex != 0 {
i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex))
i--
dAtA[i] = 0x10
}
if len(m.Cert) > 0 {
i -= len(m.Cert)
copy(dAtA[i:], m.Cert)
i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *NebulaControl) Marshal() (dAtA []byte, err error) { func (m *NebulaControl) Marshal() (dAtA []byte, err error) {
size := m.Size() size := m.Size()
dAtA = make([]byte, size) dAtA = make([]byte, size)
@@ -1375,51 +1132,6 @@ func (m *NebulaPing) Size() (n int) {
return n return n
} }
func (m *NebulaHandshake) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Details != nil {
l = m.Details.Size()
n += 1 + l + sovNebula(uint64(l))
}
l = len(m.Hmac)
if l > 0 {
n += 1 + l + sovNebula(uint64(l))
}
return n
}
func (m *NebulaHandshakeDetails) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Cert)
if l > 0 {
n += 1 + l + sovNebula(uint64(l))
}
if m.InitiatorIndex != 0 {
n += 1 + sovNebula(uint64(m.InitiatorIndex))
}
if m.ResponderIndex != 0 {
n += 1 + sovNebula(uint64(m.ResponderIndex))
}
if m.Cookie != 0 {
n += 1 + sovNebula(uint64(m.Cookie))
}
if m.Time != 0 {
n += 1 + sovNebula(uint64(m.Time))
}
if m.CertVersion != 0 {
n += 1 + sovNebula(uint64(m.CertVersion))
}
return n
}
func (m *NebulaControl) Size() (n int) { func (m *NebulaControl) Size() (n int) {
if m == nil { if m == nil {
return 0 return 0
@@ -2236,305 +1948,6 @@ func (m *NebulaPing) Unmarshal(dAtA []byte) error {
} }
return nil return nil
} }
func (m *NebulaHandshake) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType)
}
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
msglen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if msglen < 0 {
return ErrInvalidLengthNebula
}
postIndex := iNdEx + msglen
if postIndex < 0 {
return ErrInvalidLengthNebula
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
if m.Details == nil {
m.Details = &NebulaHandshakeDetails{}
}
if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthNebula
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthNebula
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...)
if m.Hmac == nil {
m.Hmac = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipNebula(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthNebula
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthNebula
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthNebula
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...)
if m.Cert == nil {
m.Cert = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType)
}
m.InitiatorIndex = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.InitiatorIndex |= uint32(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 3:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType)
}
m.ResponderIndex = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.ResponderIndex |= uint32(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 4:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType)
}
m.Cookie = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Cookie |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 5:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType)
}
m.Time = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Time |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 8:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType)
}
m.CertVersion = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowNebula
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.CertVersion |= uint32(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipNebula(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthNebula
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *NebulaControl) Unmarshal(dAtA []byte) error { func (m *NebulaControl) Unmarshal(dAtA []byte) error {
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0

View File

@@ -60,21 +60,9 @@ message NebulaPing {
uint64 Time = 2; uint64 Time = 2;
} }
message NebulaHandshake { // NebulaHandshake / NebulaHandshakeDetails moved to
NebulaHandshakeDetails Details = 1; // handshake/handshake.proto. The handshake package speaks that wire format
bytes Hmac = 2; // directly via a hand-written encoder/decoder.
}
message NebulaHandshakeDetails {
bytes Cert = 1;
uint32 InitiatorIndex = 2;
uint32 ResponderIndex = 3;
uint64 Cookie = 4;
uint64 Time = 5;
uint32 CertVersion = 8;
// reserved for WIP multiport
reserved 6, 7;
}
message NebulaControl { message NebulaControl {
enum MessageType { enum MessageType {

119
pki.go
View File

@@ -15,9 +15,12 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/flynn/noise"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/util" "github.com/slackhq/nebula/util"
) )
@@ -29,10 +32,10 @@ type PKI struct {
type CertState struct { type CertState struct {
v1Cert cert.Certificate v1Cert cert.Certificate
v1HandshakeBytes []byte v1Credential *handshake.Credential
v2Cert cert.Certificate v2Cert cert.Certificate
v2HandshakeBytes []byte v2Credential *handshake.Credential
initiatingVersion cert.Version initiatingVersion cert.Version
privateKey []byte privateKey []byte
@@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error {
} }
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
newState, err := newCertStateFromConfig(c) var cipher string
var currentState *CertState
if initial {
cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"unknown cipher",
m{"cipher": cipher},
nil,
)
}
} else {
// Cipher cant be hot swapped so just leave it at what it was before
currentState = p.cs.Load()
cipher = currentState.cipher
}
newState, err := newCertStateFromConfig(c, cipher)
if err != nil { if err != nil {
return util.NewContextualError("Could not load client cert", nil, err) return util.NewContextualError("Could not load client cert", nil, err)
} }
if !initial { if currentState != nil {
currentState := p.cs.Load()
if newState.v1Cert != nil { if newState.v1Cert != nil {
if currentState.v1Cert == nil { if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState(). //adding certs is fine, actually. Networks-in-common confirmed in newCertState().
@@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
) )
} }
} }
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"unknown cipher",
m{"cipher": newState.cipher},
nil,
)
}
} }
p.cs.Store(newState) p.cs.Store(newState)
@@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate {
return c return c
} }
// DefaultVersion returns the preferred cert version for initiating handshakes.
func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion }
// GetCredential returns the pre-computed handshake credential for the given version, or nil.
func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential {
switch v {
case cert.Version1:
return cs.v1Credential
case cert.Version2:
return cs.v2Credential
}
return nil
}
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
switch v { switch v {
case cert.Version1: case cert.Version1:
@@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
return nil return nil
} }
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) {
// Callers must check if the return []byte is nil. var dhFunc noise.DHFunc
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { switch curve {
switch v { case cert.Curve_CURVE25519:
case cert.Version1: dhFunc = noise.DH25519
return cs.v1HandshakeBytes case cert.Curve_P256:
case cert.Version2: if pkcs11backed {
return cs.v2HandshakeBytes dhFunc = noiseutil.DHP256PKCS11
default: } else {
return nil dhFunc = noiseutil.DHP256
} }
default:
return nil, fmt.Errorf("unsupported curve: %s", curve)
}
if cipher == "chachapoly" {
return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil
}
return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil
} }
func (cs *CertState) String() string { func (cs *CertState) String() string {
@@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
return json.Marshal(msg) return json.Marshal(msg)
} }
func newCertStateFromConfig(c *config.C) (*CertState, error) { func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) {
var err error var err error
privPathOrPEM := c.GetString("pki.key", "") privPathOrPEM := c.GetString("pki.key", "")
@@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion) return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
} }
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey) return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher)
} }
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) {
cs := CertState{ cs := CertState{
privateKey: privateKey, privateKey: privateKey,
pkcs11Backed: pkcs11backed, pkcs11Backed: pkcs11backed,
cipher: cipher,
myVpnNetworksTable: new(bart.Lite), myVpnNetworksTable: new(bart.Lite),
myVpnAddrsTable: new(bart.Lite), myVpnAddrsTable: new(bart.Lite),
myVpnBroadcastAddrsTable: new(bart.Lite), myVpnBroadcastAddrsTable: new(bart.Lite),
@@ -384,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
v1hs, err := v1.MarshalForHandshakes() v1hs, err := v1.MarshalForHandshakes()
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err)
}
ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher)
if err != nil {
return nil, err
} }
cs.v1Cert = v1 cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs)
if cs.initiatingVersion == 0 { if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version1 cs.initiatingVersion = cert.Version1
@@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
v2hs, err := v2.MarshalForHandshakes() v2hs, err := v2.MarshalForHandshakes()
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err)
}
ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher)
if err != nil {
return nil, err
} }
cs.v2Cert = v2 cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs)
if cs.initiatingVersion == 0 { if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version2 cs.initiatingVersion = cert.Version2