mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Handshake state machine (#1656)
This commit is contained in:
@@ -163,3 +163,55 @@ func P256Keypair() ([]byte, []byte) {
|
||||
pubkey := privkey.PublicKey()
|
||||
return pubkey.Bytes(), privkey.Bytes()
|
||||
}
|
||||
|
||||
// DummyCert is a minimal cert.Certificate implementation for testing error paths.
|
||||
type DummyCert struct {
|
||||
Version_ cert.Version
|
||||
Curve_ cert.Curve
|
||||
Groups_ []string
|
||||
IsCA_ bool
|
||||
Issuer_ string
|
||||
Name_ string
|
||||
Networks_ []netip.Prefix
|
||||
NotAfter_ time.Time
|
||||
NotBefore_ time.Time
|
||||
PublicKey_ []byte
|
||||
Signature_ []byte
|
||||
UnsafeNetworks_ []netip.Prefix
|
||||
}
|
||||
|
||||
func (d *DummyCert) Version() cert.Version { return d.Version_ }
|
||||
func (d *DummyCert) Curve() cert.Curve { return d.Curve_ }
|
||||
func (d *DummyCert) Groups() []string { return d.Groups_ }
|
||||
func (d *DummyCert) IsCA() bool { return d.IsCA_ }
|
||||
func (d *DummyCert) Issuer() string { return d.Issuer_ }
|
||||
func (d *DummyCert) Name() string { return d.Name_ }
|
||||
func (d *DummyCert) Networks() []netip.Prefix { return d.Networks_ }
|
||||
func (d *DummyCert) NotAfter() time.Time { return d.NotAfter_ }
|
||||
func (d *DummyCert) NotBefore() time.Time { return d.NotBefore_ }
|
||||
func (d *DummyCert) PublicKey() []byte { return d.PublicKey_ }
|
||||
func (d *DummyCert) Signature() []byte { return d.Signature_ }
|
||||
func (d *DummyCert) UnsafeNetworks() []netip.Prefix { return d.UnsafeNetworks_ }
|
||||
func (d *DummyCert) Fingerprint() (string, error) { return "", nil }
|
||||
func (d *DummyCert) CheckSignature(key []byte) bool { return false }
|
||||
func (d *DummyCert) MarshalForHandshakes() ([]byte, error) { return nil, nil }
|
||||
func (d *DummyCert) MarshalPEM() ([]byte, error) { return nil, nil }
|
||||
func (d *DummyCert) MarshalJSON() ([]byte, error) { return nil, nil }
|
||||
func (d *DummyCert) Marshal() ([]byte, error) { return nil, nil }
|
||||
func (d *DummyCert) String() string { return "dummy" }
|
||||
func (d *DummyCert) Copy() cert.Certificate { return d }
|
||||
func (d *DummyCert) VerifyPrivateKey(c cert.Curve, k []byte) error { return nil }
|
||||
func (d *DummyCert) Expired(time.Time) bool { return false }
|
||||
func (d *DummyCert) MarshalPublicKeyPEM() []byte { return nil }
|
||||
func (d *DummyCert) PublicKeyPEM() []byte { return nil }
|
||||
|
||||
// NewTestCAPool creates a CAPool from the given CA certificates, panicking on error.
|
||||
func NewTestCAPool(cas ...cert.Certificate) *cert.CAPool {
|
||||
pool := cert.NewCAPool()
|
||||
for _, ca := range cas {
|
||||
if err := pool.AddCA(ca); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/overlay/overlaytest"
|
||||
@@ -47,7 +46,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@@ -80,7 +79,6 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
@@ -130,7 +128,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@@ -163,7 +161,6 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
@@ -215,7 +212,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@@ -249,7 +246,6 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
||||
}
|
||||
hostinfo.ConnectionState = &ConnectionState{
|
||||
myCert: &dummyCert{version: cert.Version1},
|
||||
H: &noise.HandshakeState{},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
@@ -342,7 +338,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
cs := &CertState{
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{},
|
||||
v1HandshakeBytes: []byte{},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
lh := newTestLighthouse()
|
||||
@@ -372,7 +368,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
||||
ConnectionState: &ConnectionState{
|
||||
myCert: &dummyCert{},
|
||||
peerCert: cachedPeerCert,
|
||||
H: &noise.HandshakeState{},
|
||||
},
|
||||
}
|
||||
nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
)
|
||||
|
||||
const ReplayWindow = 1024
|
||||
@@ -17,7 +14,6 @@ const ReplayWindow = 1024
|
||||
type ConnectionState struct {
|
||||
eKey *NebulaCipherState
|
||||
dKey *NebulaCipherState
|
||||
H *noise.HandshakeState
|
||||
myCert cert.Certificate
|
||||
peerCert *cert.CachedCertificate
|
||||
initiator bool
|
||||
@@ -26,55 +22,24 @@ type ConnectionState struct {
|
||||
writeLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewConnectionState(cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch crt.Curve() {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
if cs.pkcs11Backed {
|
||||
dhFunc = noiseutil.DHP256PKCS11
|
||||
} else {
|
||||
dhFunc = noiseutil.DHP256
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
|
||||
}
|
||||
|
||||
var ncs noise.CipherSuite
|
||||
if cs.cipher == "chachapoly" {
|
||||
ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
} else {
|
||||
ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
|
||||
}
|
||||
|
||||
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}
|
||||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: ncs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
//NOTE: These should come from CertState (pki.go) when we finally implement it
|
||||
PresharedKey: []byte{},
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewConnectionState: %s", err)
|
||||
}
|
||||
|
||||
// The queue and ready params prevent a counter race that would happen when
|
||||
// sending stored packets and simultaneously accepting new traffic.
|
||||
// newConnectionStateFromResult builds a fully-populated ConnectionState from a
|
||||
// completed handshake.Result. It seeds messageCounter and the replay window so
|
||||
// that the post-handshake message indices already used on the wire don't count
|
||||
// as missed traffic in the data plane.
|
||||
func newConnectionStateFromResult(r *handshake.Result) *ConnectionState {
|
||||
ci := &ConnectionState{
|
||||
H: hs,
|
||||
initiator: initiator,
|
||||
myCert: r.MyCert,
|
||||
initiator: r.Initiator,
|
||||
peerCert: r.RemoteCert,
|
||||
eKey: NewNebulaCipherState(r.EKey),
|
||||
dKey: NewNebulaCipherState(r.DKey),
|
||||
window: NewBits(ReplayWindow),
|
||||
myCert: crt,
|
||||
}
|
||||
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
|
||||
ci.messageCounter.Add(2)
|
||||
|
||||
return ci, nil
|
||||
ci.messageCounter.Add(r.MessageIndex)
|
||||
for i := uint64(1); i <= r.MessageIndex; i++ {
|
||||
ci.window.Update(nil, i)
|
||||
}
|
||||
return ci
|
||||
}
|
||||
|
||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
|
||||
114
connection_state_test.go
Normal file
114
connection_state_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
ct "github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// runTestHandshake runs a complete IX handshake between two freshly-built
|
||||
// peers and returns the initiator and responder Results. Used to produce
|
||||
// real cipher states for tests that need to exercise post-handshake glue.
|
||||
func runTestHandshake(t *testing.T) (initR, respR *handshake.Result) {
|
||||
t.Helper()
|
||||
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
makeCreds := func(name string, networks []netip.Prefix) handshake.GetCredentialFunc {
|
||||
c, _, rawKey, _ := ct.NewTestCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
|
||||
)
|
||||
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawKey)
|
||||
require.NoError(t, err)
|
||||
hsBytes, err := c.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
cred := handshake.NewCredential(c, hsBytes, priv, ncs)
|
||||
return func(v cert.Version) *handshake.Credential {
|
||||
if v == cert.Version2 {
|
||||
return cred
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
verifier := func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||
return caPool.VerifyCertificate(time.Now(), c)
|
||||
}
|
||||
|
||||
initCreds := makeCreds("initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCreds := makeCreds("responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initM, err := handshake.NewMachine(
|
||||
cert.Version2, initCreds, verifier,
|
||||
func() (uint32, error) { return 1000, nil },
|
||||
true, header.HandshakeIXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM, err := handshake.NewMachine(
|
||||
cert.Version2, respCreds, verifier,
|
||||
func() (uint32, error) { return 2000, nil },
|
||||
false, header.HandshakeIXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, respR, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respR)
|
||||
|
||||
_, initR, err = initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initR)
|
||||
|
||||
return initR, respR
|
||||
}
|
||||
|
||||
func TestNewConnectionStateFromResult(t *testing.T) {
|
||||
initR, respR := runTestHandshake(t)
|
||||
|
||||
t.Run("initiator", func(t *testing.T) {
|
||||
ci := newConnectionStateFromResult(initR)
|
||||
assert.True(t, ci.initiator)
|
||||
assert.Equal(t, initR.MyCert, ci.myCert)
|
||||
assert.Equal(t, initR.RemoteCert, ci.peerCert)
|
||||
assert.NotNil(t, ci.eKey)
|
||||
assert.NotNil(t, ci.dKey)
|
||||
|
||||
// IX has 2 handshake messages; the next data-plane send is counter=3.
|
||||
assert.Equal(t, uint64(2), ci.messageCounter.Load(),
|
||||
"messageCounter must equal Result.MessageIndex so the next send is N+1")
|
||||
|
||||
// Both handshake counters must be marked seen so they don't appear lost.
|
||||
// Check returns false if an index has already been recorded.
|
||||
assert.False(t, ci.window.Check(nil, 1), "counter 1 must already be seen")
|
||||
assert.False(t, ci.window.Check(nil, 2), "counter 2 must already be seen")
|
||||
// Counter 3 is the next data-plane message and must NOT be pre-marked.
|
||||
assert.True(t, ci.window.Check(nil, 3), "counter 3 must not be pre-seeded")
|
||||
})
|
||||
|
||||
t.Run("responder", func(t *testing.T) {
|
||||
ci := newConnectionStateFromResult(respR)
|
||||
assert.False(t, ci.initiator)
|
||||
assert.Equal(t, respR.MyCert, ci.myCert)
|
||||
assert.Equal(t, respR.RemoteCert, ci.peerCert)
|
||||
assert.NotNil(t, ci.eKey)
|
||||
assert.NotNil(t, ci.dKey)
|
||||
assert.Equal(t, uint64(2), ci.messageCounter.Load())
|
||||
})
|
||||
}
|
||||
@@ -1033,7 +1033,7 @@ func TestNewFirewallFromConfig(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
// Test a bad rule definition
|
||||
c := &dummyCert{}
|
||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
|
||||
cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil, "aes")
|
||||
require.NoError(t, err)
|
||||
|
||||
conf := config.NewC(test.NewLogger())
|
||||
|
||||
57
handshake/credential.go
Normal file
57
handshake/credential.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
)
|
||||
|
||||
// Credential holds everything needed to participate in a handshake
|
||||
// at a given cert version. Version and Curve are read from Cert; the public
|
||||
// half of the static keypair likewise comes from Cert.PublicKey().
|
||||
type Credential struct {
|
||||
Cert cert.Certificate // the certificate
|
||||
Bytes []byte // pre-marshaled certificate bytes
|
||||
privateKey []byte // static private key (public half lives in Cert)
|
||||
cipherSuite noise.CipherSuite // pre-built cipher suite (DH + cipher + hash)
|
||||
}
|
||||
|
||||
// NewCredential creates a Credential with all material needed for handshake
|
||||
// participation. The cipherSuite should be pre-built by the caller with the
|
||||
// appropriate DH function, cipher, and hash.
|
||||
func NewCredential(
|
||||
c cert.Certificate,
|
||||
hsBytes []byte,
|
||||
privateKey []byte,
|
||||
cipherSuite noise.CipherSuite,
|
||||
) *Credential {
|
||||
return &Credential{
|
||||
Cert: c,
|
||||
Bytes: hsBytes,
|
||||
privateKey: privateKey,
|
||||
cipherSuite: cipherSuite,
|
||||
}
|
||||
}
|
||||
|
||||
// buildHandshakeState creates a noise.HandshakeState from this credential.
|
||||
func (hc *Credential) buildHandshakeState(initiator bool, pattern noise.HandshakePattern) (*noise.HandshakeState, error) {
|
||||
return noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: hc.cipherSuite,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: noise.DHKey{Private: hc.privateKey, Public: hc.Cert.PublicKey()},
|
||||
PresharedKey: []byte{},
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// GetCredentialFunc returns the handshake credential for the given version,
|
||||
// or nil if that version is not available.
|
||||
//
|
||||
// Implementations must return credentials drawn from a snapshot stable for
|
||||
// the lifetime of any single Machine. The Machine may call this multiple
|
||||
// times during a handshake (e.g. when negotiating to the peer's version)
|
||||
// and assumes the underlying static keypair is consistent across calls.
|
||||
type GetCredentialFunc func(v cert.Version) *Credential
|
||||
21
handshake/errors.go
Normal file
21
handshake/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package handshake
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInitiateOnResponder = errors.New("initiate called on responder")
|
||||
ErrInitiateAlreadyCalled = errors.New("initiate already called")
|
||||
ErrInitiateNotCalled = errors.New("initiate must be called before ProcessPacket for initiators")
|
||||
ErrPacketTooShort = errors.New("packet too short")
|
||||
ErrPublicKeyMismatch = errors.New("public key mismatch between certificate and handshake")
|
||||
ErrIncompleteHandshake = errors.New("handshake completed without receiving required content")
|
||||
ErrMachineFailed = errors.New("handshake machine has failed")
|
||||
ErrUnknownSubtype = errors.New("unknown handshake subtype")
|
||||
ErrMissingContent = errors.New("expected handshake content but message was empty")
|
||||
ErrUnexpectedContent = errors.New("received unexpected handshake content")
|
||||
ErrIndexAllocation = errors.New("failed to allocate local index")
|
||||
ErrNoCredential = errors.New("no handshake credential available for cert version")
|
||||
ErrAsymmetricCipherKeys = errors.New("noise produced only one cipher key")
|
||||
ErrMultiMessageUnsupported = errors.New("multi-message handshake patterns are not yet supported by the manager")
|
||||
ErrSubtypeMismatch = errors.New("packet subtype does not match handshake machine subtype")
|
||||
)
|
||||
29
handshake/handshake.proto
Normal file
29
handshake/handshake.proto
Normal file
@@ -0,0 +1,29 @@
|
||||
// This file documents the wire format the nebula handshake speaks. It is
|
||||
// not run through protoc; the encoder/decoder in payload.go is hand-written
|
||||
// against this shape directly to keep the parser narrow and panic-free.
|
||||
//
|
||||
// Any change to the wire format must be reflected here, and adding a new
|
||||
// field requires updating MarshalPayload / unmarshalPayloadDetails together
|
||||
// with the field-uniqueness and wire-type checks in those functions.
|
||||
|
||||
syntax = "proto3";
|
||||
package nebula.handshake;
|
||||
|
||||
message NebulaHandshake {
|
||||
NebulaHandshakeDetails Details = 1;
|
||||
bytes Hmac = 2;
|
||||
}
|
||||
|
||||
message NebulaHandshakeDetails {
|
||||
bytes Cert = 1;
|
||||
uint32 InitiatorIndex = 2;
|
||||
uint32 ResponderIndex = 3;
|
||||
// Cookie was reserved for an anti-DoS mechanism that was never
|
||||
// implemented. No released version of nebula has ever populated it; the
|
||||
// hand-written parser silently skips it on read.
|
||||
uint64 Cookie = 4 [deprecated = true];
|
||||
uint64 Time = 5;
|
||||
uint32 CertVersion = 8;
|
||||
// reserved for WIP multiport
|
||||
reserved 6, 7;
|
||||
}
|
||||
116
handshake/helpers_test.go
Normal file
116
handshake/helpers_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
ct "github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCertState holds cert material for a test peer.
|
||||
type testCertState struct {
|
||||
version cert.Version
|
||||
creds map[cert.Version]*Credential
|
||||
}
|
||||
|
||||
func (s *testCertState) getCredential(v cert.Version) *Credential {
|
||||
return s.creds[v]
|
||||
}
|
||||
|
||||
func newTestCertState(
|
||||
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||
) *testCertState {
|
||||
return newTestCertStateWithCipher(t, ca, caKey, name, networks, noise.CipherChaChaPoly)
|
||||
}
|
||||
|
||||
func newTestCertStateWithCipher(
|
||||
t *testing.T, ca cert.Certificate, caKey []byte, name string, networks []netip.Prefix,
|
||||
cipher noise.CipherFunc,
|
||||
) *testCertState {
|
||||
t.Helper()
|
||||
c, _, rawPrivKey, _ := ct.NewTestCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||
name, ca.NotBefore(), ca.NotAfter(), networks, nil, nil,
|
||||
)
|
||||
|
||||
priv, _, _, err := cert.UnmarshalPrivateKeyFromPEM(rawPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
hsBytes, err := c.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, cipher, noise.HashSHA256)
|
||||
return &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(c, hsBytes, priv, ncs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func testVerifier(pool *cert.CAPool) CertVerifier {
|
||||
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||
return pool.VerifyCertificate(time.Now(), c)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestMachine(
|
||||
t *testing.T,
|
||||
cs *testCertState,
|
||||
verifier CertVerifier,
|
||||
initiator bool,
|
||||
localIndex uint32,
|
||||
) *Machine {
|
||||
t.Helper()
|
||||
m, err := NewMachine(
|
||||
cs.version, cs.getCredential,
|
||||
verifier, func() (uint32, error) { return localIndex, nil },
|
||||
initiator, header.HandshakeIXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return m
|
||||
}
|
||||
|
||||
func initiateHandshake(
|
||||
t *testing.T,
|
||||
initCS *testCertState, initVerifier CertVerifier,
|
||||
respCS *testCertState, respVerifier CertVerifier,
|
||||
) (initM, respM *Machine, respResult *Result, resp []byte, err error) {
|
||||
t.Helper()
|
||||
initM = newTestMachine(t, initCS, initVerifier, true, 100)
|
||||
msg1, merr := initM.Initiate(nil)
|
||||
require.NoError(t, merr)
|
||||
|
||||
respM = newTestMachine(t, respCS, respVerifier, false, 200)
|
||||
resp, respResult, err = respM.ProcessPacket(nil, msg1)
|
||||
return
|
||||
}
|
||||
|
||||
func doFullHandshake(
|
||||
t *testing.T, initCS, respCS *testCertState, caPool *cert.CAPool,
|
||||
) (initResult, respResult *Result) {
|
||||
t.Helper()
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, respResult, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
require.NotEmpty(t, resp)
|
||||
|
||||
_, initResult, err = initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult)
|
||||
|
||||
return initResult, respResult
|
||||
}
|
||||
444
handshake/machine.go
Normal file
444
handshake/machine.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// IndexAllocator is called by the Machine to allocate a local index for the
|
||||
// handshake. It is called at most once, when the first outgoing message that
|
||||
// carries a payload is built.
|
||||
//
|
||||
// Implementations MUST NOT return 0. Zero is reserved as a sentinel meaning
|
||||
// "no index assigned" on the wire and in the payload-presence checks. If an
|
||||
// allocator ever returned 0, a legitimate handshake's payload could be
|
||||
// indistinguishable from an empty one and would be rejected.
|
||||
type IndexAllocator func() (uint32, error)
|
||||
|
||||
// CertVerifier is called by the Machine after reconstructing the peer's
|
||||
// certificate from the handshake. The verifier performs all validation
|
||||
// (CA trust, expiry, policy checks, allow lists).
|
||||
type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error)
|
||||
|
||||
// Result contains the results of a successful handshake.
|
||||
// Returned by ProcessPacket when the handshake is complete.
|
||||
type Result struct {
|
||||
EKey *noise.CipherState
|
||||
DKey *noise.CipherState
|
||||
MyCert cert.Certificate
|
||||
RemoteCert *cert.CachedCertificate
|
||||
RemoteIndex uint32
|
||||
LocalIndex uint32
|
||||
HandshakeTime uint64
|
||||
MessageIndex uint64 // number of messages exchanged during the handshake
|
||||
Initiator bool
|
||||
}
|
||||
|
||||
// Machine drives a Noise handshake through N messages. It handles Noise
|
||||
// protocol operations, certificate reconstruction, and payload encoding.
|
||||
// Certificate validation is delegated to the caller via CertVerifier.
|
||||
//
|
||||
// A Machine is not safe for concurrent use. The caller must ensure that
|
||||
// Initiate and ProcessPacket are not called concurrently.
|
||||
//
|
||||
// Error contract: when ProcessPacket or Initiate returns an error, callers
|
||||
// must check Failed() to decide what to do next. If Failed() is false the
|
||||
// underlying noise state was not advanced (the packet was rejected before
|
||||
// ReadMessage took effect, or the rejection is non-fatal like a stale
|
||||
// retransmit) and the Machine can accept another packet. If Failed() is
|
||||
// true the Machine is unrecoverable and the caller must abandon it.
|
||||
type Machine struct {
|
||||
hs *noise.HandshakeState
|
||||
getCred GetCredentialFunc
|
||||
allocIndex IndexAllocator
|
||||
verifier CertVerifier
|
||||
result *Result
|
||||
msgs []msgFlags
|
||||
myVersion cert.Version
|
||||
subtype header.MessageSubType
|
||||
indexAllocated bool
|
||||
remoteCertSet bool
|
||||
payloadSet bool
|
||||
failed bool
|
||||
}
|
||||
|
||||
// NewMachine creates a handshake state machine. The subtype determines both
|
||||
// the noise pattern and the per-message content layout. The credential for
|
||||
// `version` is fetched via getCred and used to seed the noise.HandshakeState.
|
||||
// IndexAllocator is called lazily when the first outgoing payload is built.
|
||||
func NewMachine(
|
||||
version cert.Version,
|
||||
getCred GetCredentialFunc,
|
||||
verifier CertVerifier,
|
||||
allocIndex IndexAllocator,
|
||||
initiator bool,
|
||||
subtype header.MessageSubType,
|
||||
) (*Machine, error) {
|
||||
info, err := subtypeInfoFor(subtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cred := getCred(version)
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrNoCredential, version)
|
||||
}
|
||||
|
||||
hs, err := cred.buildHandshakeState(initiator, info.pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build noise state: %w", err)
|
||||
}
|
||||
|
||||
return &Machine{
|
||||
hs: hs,
|
||||
subtype: subtype,
|
||||
msgs: info.msgs,
|
||||
getCred: getCred,
|
||||
allocIndex: allocIndex,
|
||||
verifier: verifier,
|
||||
myVersion: version,
|
||||
result: &Result{
|
||||
Initiator: initiator,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Failed returns true if the Machine is in an unrecoverable state.
|
||||
func (m *Machine) Failed() bool {
|
||||
return m.failed
|
||||
}
|
||||
|
||||
// Subtype returns the handshake subtype this Machine was built for.
|
||||
func (m *Machine) Subtype() header.MessageSubType {
|
||||
return m.subtype
|
||||
}
|
||||
|
||||
// MessageIndex returns the noise handshake message index, which equals the
|
||||
// wire counter of the most recently sent or received message.
|
||||
func (m *Machine) MessageIndex() int {
|
||||
return m.hs.MessageIndex()
|
||||
}
|
||||
|
||||
// requireComplete checks that both a peer cert and payload have been received.
|
||||
// Marks the machine as failed if not.
|
||||
func (m *Machine) requireComplete() error {
|
||||
if !m.payloadSet || !m.remoteCertSet {
|
||||
m.failed = true
|
||||
return ErrIncompleteHandshake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// myMsgFlags returns the flags for the current outgoing message.
|
||||
func (m *Machine) myMsgFlags() msgFlags {
|
||||
idx := m.hs.MessageIndex()
|
||||
if idx < len(m.msgs) {
|
||||
return m.msgs[idx]
|
||||
}
|
||||
return msgFlags{}
|
||||
}
|
||||
|
||||
// peerMsgFlags returns the flags for the message we just read.
|
||||
func (m *Machine) peerMsgFlags() msgFlags {
|
||||
idx := m.hs.MessageIndex() - 1
|
||||
if idx >= 0 && idx < len(m.msgs) {
|
||||
return m.msgs[idx]
|
||||
}
|
||||
return msgFlags{}
|
||||
}
|
||||
|
||||
// Initiate produces the first handshake message. Only valid for initiators,
|
||||
// and must be called exactly once before ProcessPacket.
|
||||
//
|
||||
// out is a destination buffer the message is appended to and returned. Pass
|
||||
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||
// buf[:0]) with sufficient capacity to avoid allocation.
|
||||
//
|
||||
// An error return may not indicate a fatal condition, check Failed() to
|
||||
// determine if the Machine can still be used.
|
||||
func (m *Machine) Initiate(out []byte) ([]byte, error) {
|
||||
if m.failed {
|
||||
return nil, ErrMachineFailed
|
||||
}
|
||||
if !m.result.Initiator {
|
||||
m.failed = true
|
||||
return nil, ErrInitiateOnResponder
|
||||
}
|
||||
if m.hs.MessageIndex() != 0 {
|
||||
m.failed = true
|
||||
return nil, ErrInitiateAlreadyCalled
|
||||
}
|
||||
|
||||
// At MessageIndex=0 with RemoteIndex still zero, buildResponse produces
|
||||
// header counter 1 and remote index 0, which is what the initial message needs.
|
||||
out, _, _, err := m.buildResponse(out)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ProcessPacket handles an incoming handshake message. It advances the Noise
|
||||
// state, validates the peer certificate via the verifier, and optionally
|
||||
// produces a response.
|
||||
//
|
||||
// out is a destination buffer the response is appended to and returned. Pass
|
||||
// nil to allocate fresh, or pass a re-used buffer sliced to length 0 (e.g.
|
||||
// buf[:0]) with sufficient capacity to avoid allocation. The returned slice
|
||||
// is nil when no outgoing message is produced (handshake complete on this
|
||||
// side, or final message of a multi-message pattern).
|
||||
//
|
||||
// Returns a non-nil Result when the handshake is complete.
|
||||
// An error return may not indicate a fatal condition, check Failed() to
|
||||
// determine if the Machine can still be used.
|
||||
func (m *Machine) ProcessPacket(out, packet []byte) ([]byte, *Result, error) {
|
||||
if m.failed {
|
||||
return nil, nil, ErrMachineFailed
|
||||
}
|
||||
if len(packet) < header.Len {
|
||||
return nil, nil, ErrPacketTooShort
|
||||
}
|
||||
// Reject packets whose subtype doesn't match the one this Machine was
|
||||
// built for. A pending handshake that suddenly receives a different
|
||||
// subtype on its index is either a stray packet that matched by chance
|
||||
// or a peer protocol violation; drop it without failing the Machine so
|
||||
// the legitimate retransmit can still complete.
|
||||
if header.MessageSubType(packet[1]) != m.subtype {
|
||||
return nil, nil, ErrSubtypeMismatch
|
||||
}
|
||||
if m.result.Initiator && m.hs.MessageIndex() == 0 {
|
||||
m.failed = true
|
||||
return nil, nil, ErrInitiateNotCalled
|
||||
}
|
||||
|
||||
// The (eKey, dKey) ordering here is correct for IX, where the initiator
|
||||
// completes the handshake by reading the responder's stage-2 message.
|
||||
// noise returns (cs1, cs2) where cs1 is the initiator->responder cipher.
|
||||
// For 3-message patterns where a responder finishes by reading the final
|
||||
// message, this ordering would be wrong; revisit when XX/pqIX lands.
|
||||
msg, eKey, dKey, err := m.hs.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
// Noise ReadMessage failed. The noise library checkpoints and rolls back
|
||||
// on failure, so the Machine is still alive. The caller can retry with
|
||||
// a different packet.
|
||||
return nil, nil, fmt.Errorf("noise ReadMessage: %w", err)
|
||||
}
|
||||
|
||||
// From here on, noise state has advanced. Any error is fatal.
|
||||
flags := m.peerMsgFlags()
|
||||
|
||||
if err := m.processPayload(msg, flags); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If ReadMessage derived keys, the handshake is complete. Noise should
|
||||
// always produce both keys together; asymmetry is a protocol invariant
|
||||
// violation.
|
||||
if eKey != nil || dKey != nil {
|
||||
if eKey == nil || dKey == nil {
|
||||
m.failed = true
|
||||
return nil, nil, ErrAsymmetricCipherKeys
|
||||
}
|
||||
if err := m.requireComplete(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return nil, m.completed(eKey, dKey), nil
|
||||
}
|
||||
|
||||
// ReadMessage didn't complete, produce the next outgoing message
|
||||
out, dk, ek, err := m.buildResponse(out)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if ek != nil || dk != nil {
|
||||
if ek == nil || dk == nil {
|
||||
m.failed = true
|
||||
return nil, nil, ErrAsymmetricCipherKeys
|
||||
}
|
||||
if err := m.requireComplete(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return out, m.completed(ek, dk), nil
|
||||
}
|
||||
|
||||
return out, nil, nil
|
||||
}
|
||||
|
||||
func (m *Machine) completed(eKey, dKey *noise.CipherState) *Result {
|
||||
m.result.EKey = eKey
|
||||
m.result.DKey = dKey
|
||||
m.result.MessageIndex = uint64(m.hs.MessageIndex())
|
||||
return m.result
|
||||
}
|
||||
|
||||
func (m *Machine) processPayload(msg []byte, flags msgFlags) error {
|
||||
if len(msg) == 0 {
|
||||
if flags.expectsPayload || flags.expectsCert {
|
||||
m.failed = true
|
||||
return ErrMissingContent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
payload, err := UnmarshalPayload(msg)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("unmarshal handshake: %w", err)
|
||||
}
|
||||
|
||||
// Assert the payload contains exactly what we expect
|
||||
hasPayloadData := payload.InitiatorIndex != 0 || payload.ResponderIndex != 0 || payload.Time != 0
|
||||
if hasPayloadData != flags.expectsPayload {
|
||||
m.failed = true
|
||||
return ErrUnexpectedContent
|
||||
}
|
||||
|
||||
hasCertData := len(payload.Cert) > 0
|
||||
if hasCertData != flags.expectsCert {
|
||||
m.failed = true
|
||||
return ErrUnexpectedContent
|
||||
}
|
||||
|
||||
// Process payload
|
||||
if flags.expectsPayload {
|
||||
if m.result.Initiator {
|
||||
m.result.RemoteIndex = payload.ResponderIndex
|
||||
} else {
|
||||
m.result.RemoteIndex = payload.InitiatorIndex
|
||||
}
|
||||
m.result.HandshakeTime = payload.Time
|
||||
m.payloadSet = true
|
||||
}
|
||||
|
||||
// Process certificate
|
||||
if flags.expectsCert {
|
||||
if err := m.validateCert(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) validateCert(payload Payload) error {
|
||||
cred := m.getCred(m.myVersion)
|
||||
if cred == nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||
}
|
||||
rc, err := cert.Recombine(
|
||||
cert.Version(payload.CertVersion),
|
||||
payload.Cert,
|
||||
m.hs.PeerStatic(),
|
||||
cred.Cert.Curve(),
|
||||
)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("recombine cert: %w", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(rc.PublicKey(), m.hs.PeerStatic()) {
|
||||
m.failed = true
|
||||
return ErrPublicKeyMismatch
|
||||
}
|
||||
|
||||
// Version negotiation, if the peer sent a different version and we have it, switch
|
||||
if rc.Version() != m.myVersion {
|
||||
if m.getCred(rc.Version()) != nil {
|
||||
m.myVersion = rc.Version()
|
||||
}
|
||||
}
|
||||
|
||||
verified, err := m.verifier(rc)
|
||||
if err != nil {
|
||||
m.failed = true
|
||||
return fmt.Errorf("verify cert: %w", err)
|
||||
}
|
||||
|
||||
m.result.RemoteCert = verified
|
||||
m.remoteCertSet = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Machine) marshalOutgoing(flags msgFlags) ([]byte, error) {
|
||||
if !flags.expectsPayload && !flags.expectsCert {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var p Payload
|
||||
if flags.expectsPayload {
|
||||
if !m.indexAllocated {
|
||||
index, err := m.allocIndex()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrIndexAllocation, err)
|
||||
}
|
||||
m.result.LocalIndex = index
|
||||
m.indexAllocated = true
|
||||
}
|
||||
|
||||
if m.result.Initiator {
|
||||
p.InitiatorIndex = m.result.LocalIndex
|
||||
} else {
|
||||
p.ResponderIndex = m.result.LocalIndex
|
||||
p.InitiatorIndex = m.result.RemoteIndex
|
||||
}
|
||||
p.Time = uint64(time.Now().UnixNano())
|
||||
}
|
||||
if flags.expectsCert {
|
||||
cred := m.getCred(m.myVersion)
|
||||
if cred == nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrNoCredential, m.myVersion)
|
||||
}
|
||||
p.Cert = cred.Bytes
|
||||
p.CertVersion = uint32(cred.Cert.Version())
|
||||
m.result.MyCert = cred.Cert
|
||||
}
|
||||
|
||||
return MarshalPayload(nil, p), nil
|
||||
}
|
||||
|
||||
func (m *Machine) buildResponse(out []byte) ([]byte, *noise.CipherState, *noise.CipherState, error) {
|
||||
flags := m.myMsgFlags()
|
||||
hsBytes, err := m.marshalOutgoing(flags)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Extend out by header.Len to make room for the header. slices.Grow is a
|
||||
// no-op when the cap is already sufficient (the zero-copy case where the
|
||||
// caller passed a pre-sized buffer). header.Encode overwrites the new
|
||||
// bytes, so they don't need to be zeroed.
|
||||
start := len(out)
|
||||
out = slices.Grow(out, header.Len)[:start+header.Len]
|
||||
header.Encode(
|
||||
out[start:],
|
||||
header.Version, header.Handshake, m.subtype,
|
||||
m.result.RemoteIndex,
|
||||
uint64(m.hs.MessageIndex()+1),
|
||||
)
|
||||
|
||||
// noise.WriteMessage appends the encrypted handshake message to out,
|
||||
// reusing capacity when present.
|
||||
//
|
||||
// The (dKey, eKey) ordering here is correct for IX, where the responder
|
||||
// completes the handshake by writing the stage-2 message. noise returns
|
||||
// (cs1, cs2) where cs1 is the initiator->responder cipher (which is the
|
||||
// responder's decrypt key). For 3-message patterns where an initiator
|
||||
// finishes by writing the final message, this ordering would be wrong;
|
||||
// revisit when XX/pqIX lands.
|
||||
out, dKey, eKey, err := m.hs.WriteMessage(out, hsBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("noise WriteMessage: %w", err)
|
||||
}
|
||||
|
||||
return out, dKey, eKey, nil
|
||||
}
|
||||
662
handshake/machine_test.go
Normal file
662
handshake/machine_test.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
ct "github.com/slackhq/nebula/cert_test"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMachineIXHappyPath(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
initCS := newTestCertState(t, ca, caKey, "initiator", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "responder", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
assert.Equal(t, "responder", initR.RemoteCert.Certificate.Name())
|
||||
assert.Equal(t, "initiator", respR.RemoteCert.Certificate.Name())
|
||||
|
||||
assert.Equal(t, uint32(1000), initR.LocalIndex)
|
||||
assert.Equal(t, uint32(2000), initR.RemoteIndex)
|
||||
assert.Equal(t, uint32(2000), respR.LocalIndex)
|
||||
assert.Equal(t, uint32(1000), respR.RemoteIndex)
|
||||
|
||||
assert.Equal(t, uint64(2), initR.MessageIndex, "IX has 2 messages")
|
||||
assert.Equal(t, uint64(2), respR.MessageIndex, "IX has 2 messages")
|
||||
|
||||
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("hello"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("hello"), pt1)
|
||||
|
||||
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("world"))
|
||||
require.NoError(t, err)
|
||||
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("world"), pt2)
|
||||
}
|
||||
|
||||
func TestMachineInitiateErrors(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("initiate on responder", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, err := m.Initiate(nil)
|
||||
require.ErrorIs(t, err, ErrInitiateOnResponder)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("initiate called twice", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, true, 100)
|
||||
_, err := m.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
_, err = m.Initiate(nil)
|
||||
require.ErrorIs(t, err, ErrInitiateAlreadyCalled)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("process packet before initiate on initiator", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, true, 100)
|
||||
_, _, err := m.ProcessPacket(nil, make([]byte, 100))
|
||||
require.ErrorIs(t, err, ErrInitiateNotCalled)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("calling failed machine", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, err := m.Initiate(nil) // fails: responder
|
||||
require.Error(t, err)
|
||||
_, err = m.Initiate(nil) // fails: already failed
|
||||
require.ErrorIs(t, err, ErrMachineFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineProcessPacketErrors(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("packet too short", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
_, _, err := m.ProcessPacket(nil, []byte{1, 2, 3})
|
||||
require.ErrorIs(t, err, ErrPacketTooShort)
|
||||
assert.False(t, m.Failed(), "short packet should not kill machine")
|
||||
})
|
||||
|
||||
t.Run("noise decryption failure is recoverable", func(t *testing.T) {
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
resp, _, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
corrupted := make([]byte, len(resp))
|
||||
copy(corrupted, resp)
|
||||
for i := header.Len; i < len(corrupted); i++ {
|
||||
corrupted[i] ^= 0xff
|
||||
}
|
||||
_, _, err = initM.ProcessPacket(nil, corrupted)
|
||||
require.Error(t, err)
|
||||
assert.False(t, initM.Failed(), "noise failure should be recoverable")
|
||||
|
||||
// And the machine should still complete a real handshake afterward.
|
||||
_, result, err := initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result, "initiator should complete on the legitimate response")
|
||||
})
|
||||
|
||||
t.Run("invalid cert is fatal", func(t *testing.T) {
|
||||
otherCA, _, otherCAKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
otherCS := newTestCertState(t, otherCA, otherCAKey, "other", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initM := newTestMachine(t, otherCS, testVerifier(ct.NewTestCAPool(otherCA)), true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
_, _, err = respM.ProcessPacket(nil, msg1)
|
||||
require.Error(t, err)
|
||||
assert.True(t, respM.Failed(), "cert validation failure should kill machine")
|
||||
})
|
||||
|
||||
t.Run("subtype mismatch is recoverable", func(t *testing.T) {
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Mutate the subtype byte (offset 1 in the header) to a value the
|
||||
// responder Machine wasn't built for.
|
||||
bad := make([]byte, len(msg1))
|
||||
copy(bad, msg1)
|
||||
bad[1] = 0xff
|
||||
|
||||
respM := newTestMachine(t, cs, v, false, 200)
|
||||
_, _, err = respM.ProcessPacket(nil, bad)
|
||||
require.ErrorIs(t, err, ErrSubtypeMismatch)
|
||||
assert.False(t, respM.Failed(), "subtype mismatch should not kill the machine")
|
||||
|
||||
// And the machine should still complete a real handshake afterward.
|
||||
resp, result, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result, "responder should complete on the legitimate stage-1 packet")
|
||||
assert.NotEmpty(t, resp, "responder should produce a stage-2 reply")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMachineProcessPayload exercises processPayload's internal validation
|
||||
// directly. Most of these failure modes can't be reached black-box once the
|
||||
// subtype check at the top of ProcessPacket gates external callers, so we
|
||||
// drive them by hand here for coverage.
|
||||
func TestMachineProcessPayload(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("empty message with expects fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload(nil, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.ErrorIs(t, err, ErrMissingContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("empty message with no expects passes", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload(nil, msgFlags{})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("malformed protobuf is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.processPayload([]byte{0xff, 0xff, 0xff}, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.Error(t, err)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("unexpected payload data is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// A payload with index data when none was expected.
|
||||
bytes := MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("unexpected cert data is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// A payload with cert when none was expected.
|
||||
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: false, expectsCert: false})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("missing payload data when expected is fatal", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
// Cert present, but no index/time fields.
|
||||
bytes := MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2})
|
||||
err := m.processPayload(bytes, msgFlags{expectsPayload: true, expectsCert: true})
|
||||
require.ErrorIs(t, err, ErrUnexpectedContent)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
}
|
||||
|
||||
// TestMachineRequireComplete checks the fail-on-incomplete-handshake path
|
||||
// directly. Like processPayload above this isn't reachable from a normal IX
|
||||
// flow, so we drive it by hand.
|
||||
func TestMachineRequireComplete(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
cs := newTestCertState(t, ca, caKey, "test", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
t.Run("missing both fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("payload only fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.payloadSet = true
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("cert only fails", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.remoteCertSet = true
|
||||
err := m.requireComplete()
|
||||
require.ErrorIs(t, err, ErrIncompleteHandshake)
|
||||
assert.True(t, m.Failed())
|
||||
})
|
||||
|
||||
t.Run("both set passes", func(t *testing.T) {
|
||||
m := newTestMachine(t, cs, v, false, 100)
|
||||
m.payloadSet = true
|
||||
m.remoteCertSet = true
|
||||
err := m.requireComplete()
|
||||
require.NoError(t, err)
|
||||
assert.False(t, m.Failed())
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineAESCipher(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
initCS := newTestCertStateWithCipher(
|
||||
t, ca, caKey, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
noiseutil.CipherAESGCM,
|
||||
)
|
||||
respCS := newTestCertStateWithCipher(
|
||||
t, ca, caKey, "resp",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
noiseutil.CipherAESGCM,
|
||||
)
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
ct1, err := initR.EKey.Encrypt(nil, nil, []byte("works"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respR.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("works"), pt1)
|
||||
|
||||
ct2, err := respR.EKey.Encrypt(nil, nil, []byte("back"))
|
||||
require.NoError(t, err)
|
||||
pt2, err := initR.DKey.Decrypt(nil, nil, ct2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("back"), pt2)
|
||||
}
|
||||
|
||||
func TestResultFields(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initR, respR := doFullHandshake(t, initCS, respCS, caPool)
|
||||
|
||||
assert.True(t, initR.Initiator)
|
||||
assert.False(t, respR.Initiator)
|
||||
assert.NotZero(t, initR.HandshakeTime)
|
||||
assert.NotZero(t, respR.HandshakeTime)
|
||||
assert.NotNil(t, initR.RemoteCert)
|
||||
assert.NotNil(t, respR.RemoteCert)
|
||||
}
|
||||
|
||||
func TestMachineBufferReuse(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 1000)
|
||||
respM := newTestMachine(t, respCS, v, false, 2000)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("response writes into provided buffer", func(t *testing.T) {
|
||||
buf := make([]byte, 0, 4096)
|
||||
resp, result, err := respM.ProcessPacket(buf, msg1)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.NotEmpty(t, resp, "response should have content")
|
||||
assert.Equal(t, &buf[:1][0], &resp[:1][0],
|
||||
"response should reuse the provided buffer's backing array")
|
||||
})
|
||||
|
||||
t.Run("initiate writes into provided buffer", func(t *testing.T) {
|
||||
initM2 := newTestMachine(t, initCS, v, true, 3000)
|
||||
buf := make([]byte, 0, 4096)
|
||||
msg, err := initM2.Initiate(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, msg, "initiate should have content")
|
||||
assert.Equal(t, &buf[:1][0], &msg[:1][0],
|
||||
"initiate should reuse the provided buffer's backing array")
|
||||
})
|
||||
|
||||
t.Run("nil out still works", func(t *testing.T) {
|
||||
initM2 := newTestMachine(t, initCS, v, true, 4000)
|
||||
respM2 := newTestMachine(t, respCS, v, false, 5000)
|
||||
|
||||
msg1, err := initM2.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, _, err := respM2.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
out, result, err := initM2.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Nil(t, out, "initiator should have no response for IX msg2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMachineMsgIndexTracking(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM := newTestMachine(t, initCS, v, true, 100)
|
||||
respM := newTestMachine(t, respCS, v, false, 200)
|
||||
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp1, result1, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result1)
|
||||
|
||||
_, result2, err := initM.ProcessPacket(nil, resp1)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result2)
|
||||
}
|
||||
|
||||
func TestMachineThreeMessagePattern(t *testing.T) {
|
||||
registerTestXXInfo(t)
|
||||
|
||||
// Use HandshakeXX (3 messages) to verify the Machine handles multi-message
|
||||
// patterns correctly. XX flow:
|
||||
// msg1 (I->R): [E] - payload only, no cert
|
||||
// msg2 (R->I): [E, ee, S, es] - payload + cert
|
||||
// msg3 (I->R): [S, se] - cert only (no payload, not first two)
|
||||
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initCS := newTestCertState(t, ca, caKey, "init", []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")})
|
||||
respCS := newTestCertState(t, ca, caKey, "resp", []netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")})
|
||||
|
||||
initM, err := NewMachine(
|
||||
cert.Version2,
|
||||
initCS.getCredential, v,
|
||||
func() (uint32, error) { return 1000, nil },
|
||||
true, header.HandshakeXXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
respM, err := NewMachine(
|
||||
cert.Version2,
|
||||
respCS.getCredential, v,
|
||||
func() (uint32, error) { return 2000, nil },
|
||||
false, header.HandshakeXXPSK0,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// msg1: initiator -> responder (E only, no cert)
|
||||
msg1, err := initM.Initiate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, msg1)
|
||||
|
||||
// Responder processes msg1, should not complete yet, should produce msg2
|
||||
msg2, result, err := respM.ProcessPacket(nil, msg1)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, result, "XX should not complete on msg1")
|
||||
assert.NotEmpty(t, msg2, "responder should produce msg2")
|
||||
|
||||
// Initiator processes msg2: gets responder's cert, produces msg3, and
|
||||
// completes (WriteMessage for msg3 derives keys)
|
||||
msg3, initResult, err := initM.ProcessPacket(nil, msg2)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult, "XX initiator should complete after reading msg2 and writing msg3")
|
||||
assert.NotEmpty(t, msg3, "initiator should produce msg3")
|
||||
assert.Equal(t, "resp", initResult.RemoteCert.Certificate.Name())
|
||||
|
||||
// Responder processes msg3: gets initiator's cert and completes
|
||||
_, respResult, err := respM.ProcessPacket(nil, msg3)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult, "XX responder should complete on msg3")
|
||||
assert.Equal(t, "init", respResult.RemoteCert.Certificate.Name())
|
||||
|
||||
assert.Equal(t, uint64(3), initResult.MessageIndex, "XX has 3 messages")
|
||||
assert.Equal(t, uint64(3), respResult.MessageIndex, "XX has 3 messages")
|
||||
|
||||
// Verify keys work
|
||||
ct1, err := initResult.EKey.Encrypt(nil, nil, []byte("three messages"))
|
||||
require.NoError(t, err)
|
||||
pt1, err := respResult.DKey.Decrypt(nil, nil, ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("three messages"), pt1)
|
||||
}
|
||||
|
||||
// NOTE: ErrIncompleteHandshake is tested implicitly. It can't be triggered with
|
||||
// IX since the cert is always in the payload. A 3-message pattern test (HybridIX)
|
||||
// should exercise the case where cert arrives in msg3 and verify that completing
|
||||
// without it fails.
|
||||
|
||||
func TestMachineExpiredCert(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519,
|
||||
time.Now().Add(-24*time.Hour), time.Now().Add(24*time.Hour),
|
||||
nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
expCert, _, expKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, ca, caKey,
|
||||
"expired", time.Now().Add(-2*time.Hour), time.Now().Add(-1*time.Hour),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}, nil, nil,
|
||||
)
|
||||
expKey, _, _, err := cert.UnmarshalPrivateKeyFromPEM(expKeyPEM)
|
||||
require.NoError(t, err)
|
||||
expHsBytes, err := expCert.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
expiredCS := &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(expCert, expHsBytes, expKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
respCS := newTestCertState(
|
||||
t, ca, caKey, "responder",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, expiredCS, testVerifier(caPool),
|
||||
respCS, testVerifier(caPool),
|
||||
)
|
||||
require.ErrorContains(t, err, "verify cert")
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineNoCertNetworks(t *testing.T) {
|
||||
ca, _, caKey, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca)
|
||||
|
||||
caHsBytes, err := ca.MarshalForHandshakes()
|
||||
require.NoError(t, err)
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
noNetCS := &testCertState{
|
||||
version: cert.Version2,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version2: NewCredential(ca, caHsBytes, caKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
respCS := newTestCertState(
|
||||
t, ca, caKey, "responder",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, noNetCS, testVerifier(caPool),
|
||||
respCS, testVerifier(caPool),
|
||||
)
|
||||
require.Error(t, err)
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineDifferentCAs(t *testing.T) {
|
||||
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
|
||||
initCS := newTestCertState(
|
||||
t, ca1, caKey1, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
respCS := newTestCertState(
|
||||
t, ca2, caKey2, "resp",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")},
|
||||
)
|
||||
|
||||
_, respM, _, _, err := initiateHandshake(
|
||||
t, initCS, testVerifier(ct.NewTestCAPool(ca1)),
|
||||
respCS, testVerifier(ct.NewTestCAPool(ca2)),
|
||||
)
|
||||
require.ErrorContains(t, err, "verify cert")
|
||||
assert.True(t, respM.Failed())
|
||||
}
|
||||
|
||||
func TestMachineVersionNegotiation(t *testing.T) {
|
||||
ca1, _, caKey1, _ := ct.NewTestCaCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
ca2, _, caKey2, _ := ct.NewTestCaCert(
|
||||
cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil,
|
||||
)
|
||||
caPool := ct.NewTestCAPool(ca1, ca2)
|
||||
|
||||
makeMultiVersionResp := func(t *testing.T) *testCertState {
|
||||
t.Helper()
|
||||
respCertV1, _, respKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||
ca1.NotBefore(), ca1.NotAfter(),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||
)
|
||||
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||
respCertV2, _ := ct.NewTestCertDifferentVersion(respCertV1, cert.Version2, ca2, caKey2)
|
||||
respHsV1, _ := respCertV1.MarshalForHandshakes()
|
||||
respHsV2, _ := respCertV2.MarshalForHandshakes()
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
return &testCertState{
|
||||
version: cert.Version1,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version1: NewCredential(respCertV1, respHsV1, respKey, ncs),
|
||||
cert.Version2: NewCredential(respCertV2, respHsV2, respKey, ncs),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("responder matches initiator version", func(t *testing.T) {
|
||||
initCS := newTestCertState(
|
||||
t, ca2, caKey2, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
respCS := makeMultiVersionResp(t)
|
||||
v := testVerifier(caPool)
|
||||
|
||||
initM, _, respResult, resp, err := initiateHandshake(
|
||||
t, initCS, v,
|
||||
respCS, v,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
|
||||
assert.Equal(t, cert.Version2, respResult.MyCert.Version(),
|
||||
"responder should negotiate to initiator's version")
|
||||
|
||||
_, initResult, err := initM.ProcessPacket(nil, resp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, initResult)
|
||||
assert.Equal(t, cert.Version2, initResult.RemoteCert.Certificate.Version(),
|
||||
"initiator should see V2 cert from responder")
|
||||
})
|
||||
|
||||
t.Run("responder keeps version when no match available", func(t *testing.T) {
|
||||
initCS := newTestCertState(
|
||||
t, ca2, caKey2, "init",
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")},
|
||||
)
|
||||
|
||||
respCert, _, respKeyPEM, _ := ct.NewTestCert(
|
||||
cert.Version1, cert.Curve_CURVE25519, ca1, caKey1, "resp",
|
||||
ca1.NotBefore(), ca1.NotAfter(),
|
||||
[]netip.Prefix{netip.MustParsePrefix("10.0.0.2/24")}, nil, nil,
|
||||
)
|
||||
respKey, _, _, _ := cert.UnmarshalPrivateKeyFromPEM(respKeyPEM)
|
||||
respHs, _ := respCert.MarshalForHandshakes()
|
||||
ncs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
respCS := &testCertState{
|
||||
version: cert.Version1,
|
||||
creds: map[cert.Version]*Credential{
|
||||
cert.Version1: NewCredential(respCert, respHs, respKey, ncs),
|
||||
},
|
||||
}
|
||||
|
||||
v := testVerifier(caPool)
|
||||
_, _, respResult, _, err := initiateHandshake(
|
||||
t, initCS, v,
|
||||
respCS, v,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, respResult)
|
||||
|
||||
assert.Equal(t, cert.Version1, respResult.MyCert.Version(),
|
||||
"responder should keep V1 when V2 not available")
|
||||
})
|
||||
}
|
||||
54
handshake/patterns.go
Normal file
54
handshake/patterns.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// msgFlags tracks what application data a handshake message carries.
|
||||
type msgFlags struct {
|
||||
expectsPayload bool // message carries indexes and time
|
||||
expectsCert bool // message carries the certificate
|
||||
}
|
||||
|
||||
// subtypeInfo bundles the noise pattern with the per-message flags for a
|
||||
// given handshake subtype.
|
||||
type subtypeInfo struct {
|
||||
pattern noise.HandshakePattern
|
||||
msgs []msgFlags
|
||||
}
|
||||
|
||||
// subtypeInfos defines the noise pattern and message content layout for each
|
||||
// handshake subtype.
|
||||
var subtypeInfos = map[header.MessageSubType]subtypeInfo{
|
||||
// IX: 2 messages, both carry payload and cert
|
||||
header.HandshakeIXPSK0: {
|
||||
pattern: noise.HandshakeIX,
|
||||
msgs: []msgFlags{
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
},
|
||||
},
|
||||
|
||||
// XX: 3 messages
|
||||
// msg1 (I->R): payload only
|
||||
// msg2 (R->I): payload + cert
|
||||
// msg3 (I->R): cert only
|
||||
//header.HandshakeXXPSK0: {
|
||||
// pattern: noise.HandshakeXX,
|
||||
// msgs: []msgFlags{
|
||||
// {expectsPayload: true, expectsCert: false},
|
||||
// {expectsPayload: true, expectsCert: true},
|
||||
// {expectsPayload: false, expectsCert: true},
|
||||
// },
|
||||
//},
|
||||
}
|
||||
|
||||
func subtypeInfoFor(subtype header.MessageSubType) (subtypeInfo, error) {
|
||||
if info, ok := subtypeInfos[subtype]; ok {
|
||||
return info, nil
|
||||
}
|
||||
return subtypeInfo{}, fmt.Errorf("%w: %d", ErrUnknownSubtype, subtype)
|
||||
}
|
||||
63
handshake/patterns_test.go
Normal file
63
handshake/patterns_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubtypeInfo(t *testing.T) {
|
||||
t.Run("IX", func(t *testing.T) {
|
||||
info, err := subtypeInfoFor(header.HandshakeIXPSK0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noise.HandshakeIX.Name, info.pattern.Name)
|
||||
require.Len(t, info.msgs, 2)
|
||||
// msg1: payload + cert
|
||||
assert.True(t, info.msgs[0].expectsPayload)
|
||||
assert.True(t, info.msgs[0].expectsCert)
|
||||
// msg2: payload + cert
|
||||
assert.True(t, info.msgs[1].expectsPayload)
|
||||
assert.True(t, info.msgs[1].expectsCert)
|
||||
})
|
||||
|
||||
t.Run("XX", func(t *testing.T) {
|
||||
registerTestXXInfo(t)
|
||||
info, err := subtypeInfoFor(header.HandshakeXXPSK0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noise.HandshakeXX.Name, info.pattern.Name)
|
||||
require.Len(t, info.msgs, 3)
|
||||
// msg1: payload only
|
||||
assert.True(t, info.msgs[0].expectsPayload)
|
||||
assert.False(t, info.msgs[0].expectsCert)
|
||||
// msg2: payload + cert
|
||||
assert.True(t, info.msgs[1].expectsPayload)
|
||||
assert.True(t, info.msgs[1].expectsCert)
|
||||
// msg3: cert only
|
||||
assert.False(t, info.msgs[2].expectsPayload)
|
||||
assert.True(t, info.msgs[2].expectsCert)
|
||||
})
|
||||
|
||||
t.Run("unknown subtype returns error", func(t *testing.T) {
|
||||
_, err := subtypeInfoFor(99)
|
||||
require.ErrorIs(t, err, ErrUnknownSubtype)
|
||||
})
|
||||
}
|
||||
|
||||
// registerTestXXInfo temporarily registers XX subtype info for testing.
|
||||
func registerTestXXInfo(t *testing.T) {
|
||||
t.Helper()
|
||||
subtypeInfos[header.HandshakeXXPSK0] = subtypeInfo{
|
||||
pattern: noise.HandshakeXX,
|
||||
msgs: []msgFlags{
|
||||
{expectsPayload: true, expectsCert: false},
|
||||
{expectsPayload: true, expectsCert: true},
|
||||
{expectsPayload: false, expectsCert: true},
|
||||
},
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
delete(subtypeInfos, header.HandshakeXXPSK0)
|
||||
})
|
||||
}
|
||||
173
handshake/payload.go
Normal file
173
handshake/payload.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidHandshakeMessage = errors.New("invalid handshake message")
|
||||
errInvalidHandshakeDetails = errors.New("invalid handshake details")
|
||||
)
|
||||
|
||||
// Payload represents the decoded fields of a handshake message.
|
||||
// Wire format is protobuf-compatible with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||
type Payload struct {
|
||||
Cert []byte
|
||||
InitiatorIndex uint32
|
||||
ResponderIndex uint32
|
||||
Time uint64
|
||||
CertVersion uint32
|
||||
}
|
||||
|
||||
// Proto field numbers for NebulaHandshakeDetails
|
||||
const (
|
||||
fieldCert = 1 // bytes
|
||||
fieldInitiatorIndex = 2 // uint32
|
||||
fieldResponderIndex = 3 // uint32
|
||||
fieldTime = 5 // uint64
|
||||
fieldCertVersion = 8 // uint32
|
||||
)
|
||||
|
||||
// MarshalPayload encodes a handshake payload in protobuf wire format compatible
|
||||
// with NebulaHandshake{Details: NebulaHandshakeDetails{...}}.
|
||||
// Returns out (which may be nil), with the marshalled Payload appended to it.
|
||||
func MarshalPayload(out []byte, p Payload) []byte {
|
||||
var details []byte
|
||||
|
||||
if len(p.Cert) > 0 {
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, p.Cert)
|
||||
}
|
||||
if p.InitiatorIndex != 0 {
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.InitiatorIndex))
|
||||
}
|
||||
if p.ResponderIndex != 0 {
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.ResponderIndex))
|
||||
}
|
||||
if p.Time != 0 {
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, p.Time)
|
||||
}
|
||||
if p.CertVersion != 0 {
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, uint64(p.CertVersion))
|
||||
}
|
||||
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// UnmarshalPayload decodes a protobuf-encoded NebulaHandshake message.
|
||||
func UnmarshalPayload(b []byte) (Payload, error) {
|
||||
var p Payload
|
||||
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
|
||||
switch {
|
||||
case num == 1 && typ == protowire.BytesType:
|
||||
details, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
if err := unmarshalPayloadDetails(&p, details); err != nil {
|
||||
return p, err
|
||||
}
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return p, errInvalidHandshakeMessage
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func unmarshalPayloadDetails(p *Payload, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
num, typ, n := protowire.ConsumeTag(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
|
||||
// For known field numbers, reject any non-matching wire type as a
|
||||
// hard error rather than silently skipping. The caller will catch
|
||||
// missing-field cases downstream, but a wire-type mismatch on a tag
|
||||
// we know is a peer protocol violation worth flagging here.
|
||||
// Repeated occurrences of a singular field follow proto3 last-wins.
|
||||
switch num {
|
||||
case fieldCert:
|
||||
if typ != protowire.BytesType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeBytes(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.Cert = append([]byte(nil), v...)
|
||||
b = b[n:]
|
||||
case fieldInitiatorIndex:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.InitiatorIndex = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldResponderIndex:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.ResponderIndex = uint32(v)
|
||||
b = b[n:]
|
||||
case fieldTime:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.Time = v
|
||||
b = b[n:]
|
||||
case fieldCertVersion:
|
||||
if typ != protowire.VarintType {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
v, n := protowire.ConsumeVarint(b)
|
||||
if n < 0 || v > math.MaxUint32 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
p.CertVersion = uint32(v)
|
||||
b = b[n:]
|
||||
default:
|
||||
n := protowire.ConsumeFieldValue(num, typ, b)
|
||||
if n < 0 {
|
||||
return errInvalidHandshakeDetails
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
361
handshake/payload_test.go
Normal file
361
handshake/payload_test.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/encoding/protowire"
|
||||
)
|
||||
|
||||
func TestPayloadRoundTrip(t *testing.T) {
|
||||
t.Run("all fields set", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: []byte("test-cert-bytes"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 12345,
|
||||
ResponderIndex: 67890,
|
||||
Time: 1234567890,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []byte("test-cert-bytes"), got.Cert)
|
||||
assert.Equal(t, uint32(12345), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(67890), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(1234567890), got.Time)
|
||||
assert.Equal(t, uint32(2), got.CertVersion)
|
||||
})
|
||||
|
||||
t.Run("minimal fields", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 1})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(1), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(0), got.ResponderIndex)
|
||||
assert.Equal(t, uint64(0), got.Time)
|
||||
assert.Nil(t, got.Cert)
|
||||
})
|
||||
|
||||
t.Run("empty payload", func(t *testing.T) {
|
||||
data := MarshalPayload(nil, Payload{})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("large cert bytes", func(t *testing.T) {
|
||||
bigCert := make([]byte, 4096)
|
||||
for i := range bigCert {
|
||||
bigCert[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
data := MarshalPayload(nil, Payload{
|
||||
Cert: bigCert,
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 999,
|
||||
})
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, bigCert, got.Cert)
|
||||
assert.Equal(t, uint32(999), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("append to existing buffer", func(t *testing.T) {
|
||||
prefix := []byte("prefix")
|
||||
data := MarshalPayload(prefix, Payload{InitiatorIndex: 42})
|
||||
|
||||
assert.Equal(t, []byte("prefix"), data[:6])
|
||||
|
||||
got, err := UnmarshalPayload(data[6:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadUnknownFields(t *testing.T) {
|
||||
t.Run("unknown field in outer message is skipped", func(t *testing.T) {
|
||||
// Marshal a normal payload then append an unknown field (field 99, varint)
|
||||
data := MarshalPayload(nil, Payload{InitiatorIndex: 42})
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
data = protowire.AppendVarint(data, 12345)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("unknown field in details is skipped", func(t *testing.T) {
|
||||
// Build details with a known field + unknown field
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 77)
|
||||
// Unknown field 50, varint
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 9999)
|
||||
// Another known field after the unknown one
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 88)
|
||||
|
||||
// Wrap in outer message
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(77), got.InitiatorIndex)
|
||||
assert.Equal(t, uint32(88), got.ResponderIndex)
|
||||
})
|
||||
|
||||
t.Run("reserved fields 6 and 7 are skipped", func(t *testing.T) {
|
||||
// Fields 6 and 7 are reserved in the proto definition
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 100)
|
||||
details = protowire.AppendTag(details, 6, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, 7, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 2)
|
||||
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 1, protowire.BytesType)
|
||||
data = protowire.AppendBytes(data, details)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(100), got.InitiatorIndex)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadBytesConsumed(t *testing.T) {
|
||||
t.Run("all bytes consumed on valid input", func(t *testing.T) {
|
||||
original := Payload{
|
||||
Cert: []byte("cert"),
|
||||
CertVersion: 2,
|
||||
InitiatorIndex: 100,
|
||||
ResponderIndex: 200,
|
||||
Time: 999,
|
||||
}
|
||||
data := MarshalPayload(nil, original)
|
||||
|
||||
got, err := UnmarshalPayload(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Re-marshal and compare — proves we consumed and reproduced all fields
|
||||
remarshaled := MarshalPayload(nil, got)
|
||||
assert.Equal(t, data, remarshaled)
|
||||
})
|
||||
}
|
||||
|
||||
// wrapDetails wraps raw detail bytes in the outer NebulaHandshake envelope
|
||||
// so UnmarshalPayload can reach unmarshalPayloadDetails.
|
||||
func wrapDetails(details []byte) []byte {
|
||||
var out []byte
|
||||
out = protowire.AppendTag(out, 1, protowire.BytesType)
|
||||
out = protowire.AppendBytes(out, details)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestPayloadUnmarshalErrors(t *testing.T) {
|
||||
t.Run("nil input", func(t *testing.T) {
|
||||
got, err := UnmarshalPayload(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("truncated outer tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x80})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer details field", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload([]byte{0x0a, 0x64, 0x01, 0x02, 0x03, 0x04, 0x05})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated outer unknown field", func(t *testing.T) {
|
||||
// Valid tag for unknown field 99 varint, but no value follows
|
||||
var data []byte
|
||||
data = protowire.AppendTag(data, 99, protowire.VarintType)
|
||||
_, err := UnmarshalPayload(data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated details tag", func(t *testing.T) {
|
||||
_, err := UnmarshalPayload(wrapDetails([]byte{0x80}))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert bytes", func(t *testing.T) {
|
||||
// Field 1 (cert), bytes type, length 10 but only 2 bytes
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.BytesType)
|
||||
details = append(details, 0x0a, 0x01, 0x02) // length 10, only 2 bytes
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated initiator index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated responder index varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldResponderIndex, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated time varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated cert version varint", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = append(details, 0x80)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("truncated unknown field in details", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, 50, protowire.VarintType)
|
||||
details = append(details, 0x80) // incomplete varint
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldCert as Varint instead of Bytes.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCert, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("initiator index with wrong wire type rejected", func(t *testing.T) {
|
||||
// fieldInitiatorIndex as Bytes instead of Varint.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("time with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldTime, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version with wrong wire type rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.BytesType)
|
||||
details = protowire.AppendBytes(details, []byte{1, 2, 3})
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("repeated singular field follows proto3 last-wins", func(t *testing.T) {
|
||||
// Per proto3, multiple instances of a singular field are accepted and
|
||||
// the last value wins. We keep this behavior so that peers using
|
||||
// alternative encoders aren't rejected.
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 1)
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, 42)
|
||||
got, err := UnmarshalPayload(wrapDetails(details))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(42), got.InitiatorIndex)
|
||||
})
|
||||
|
||||
t.Run("initiator index varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldInitiatorIndex, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("cert version varint overflow rejected", func(t *testing.T) {
|
||||
var details []byte
|
||||
details = protowire.AppendTag(details, fieldCertVersion, protowire.VarintType)
|
||||
details = protowire.AppendVarint(details, math.MaxUint32+1)
|
||||
_, err := UnmarshalPayload(wrapDetails(details))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// FuzzPayload feeds arbitrary bytes through UnmarshalPayload to confirm it
|
||||
// never panics, and for any input that parses cleanly, that re-marshal +
|
||||
// re-parse is a fix-point. Inputs come from an authenticated peer (post-
|
||||
// noise-decrypt), so the threat model is "valid peer behaving arbitrarily,"
|
||||
// not "unauthenticated injection."
|
||||
func FuzzPayload(f *testing.F) {
|
||||
// Seed corpus with a handful of known-good shapes.
|
||||
f.Add(MarshalPayload(nil, Payload{}))
|
||||
f.Add(MarshalPayload(nil, Payload{Cert: []byte{1, 2, 3}, CertVersion: 2}))
|
||||
f.Add(MarshalPayload(nil, Payload{InitiatorIndex: 42, Time: 1}))
|
||||
f.Add(MarshalPayload(nil, Payload{
|
||||
Cert: []byte("seed-cert"),
|
||||
InitiatorIndex: 1,
|
||||
ResponderIndex: 2,
|
||||
Time: 3,
|
||||
CertVersion: 2,
|
||||
}))
|
||||
f.Add([]byte{})
|
||||
f.Add([]byte{0xff})
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
p1, err := UnmarshalPayload(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// For any input that parses, re-marshaling and re-parsing must
|
||||
// yield an equivalent Payload. This catches dispatch bugs (e.g.
|
||||
// emitting a field on marshal that we don't accept on parse) and
|
||||
// any non-idempotent parsing behavior.
|
||||
b2 := MarshalPayload(nil, p1)
|
||||
p2, err := UnmarshalPayload(b2)
|
||||
if err != nil {
|
||||
t.Fatalf("re-parse of self-marshaled payload failed: %v\nintermediate: %x\n", err, b2)
|
||||
}
|
||||
if !payloadsEqual(p1, p2) {
|
||||
t.Fatalf("re-marshal not idempotent\nfirst: %+v\nsecond: %+v", p1, p2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func payloadsEqual(a, b Payload) bool {
|
||||
return bytes.Equal(a.Cert, b.Cert) &&
|
||||
a.InitiatorIndex == b.InitiatorIndex &&
|
||||
a.ResponderIndex == b.ResponderIndex &&
|
||||
a.Time == b.Time &&
|
||||
a.CertVersion == b.CertVersion
|
||||
}
|
||||
813
handshake_ix.go
813
handshake_ix.go
@@ -1,813 +0,0 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
)
|
||||
|
||||
// NOISE IX Handshakes
|
||||
|
||||
// This function constructs a handshake packet, but does not actually send it
|
||||
// Sending is done by the handshake manager
|
||||
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
err := f.handshakeManager.allocateIndex(hh)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to generate index",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
// If we're connecting to a v6 address we should encourage use of a V2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
crt := cs.getCertificate(v)
|
||||
if crt == nil {
|
||||
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
crtHs := cs.getHandshakeBytes(v)
|
||||
if crtHs == nil {
|
||||
f.l.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(cs, crt, true, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to create connection state",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", v,
|
||||
)
|
||||
return false
|
||||
}
|
||||
hh.hostinfo.ConnectionState = ci
|
||||
|
||||
hs := &NebulaHandshake{
|
||||
Details: &NebulaHandshakeDetails{
|
||||
InitiatorIndex: hh.hostinfo.localIndexId,
|
||||
Time: uint64(time.Now().UnixNano()),
|
||||
Cert: crtHs,
|
||||
CertVersion: uint32(v),
|
||||
},
|
||||
}
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.Error("Failed to marshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"certVersion", v,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
h := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, 0, 1)
|
||||
|
||||
msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to call noise.WriteMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
)
|
||||
return false
|
||||
}
|
||||
|
||||
// We are sending handshake packet 1, so we don't expect to receive
|
||||
// handshake packet 1 from the responder
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
hh.hostinfo.HandshakePacket[0] = msg
|
||||
hh.ready = true
|
||||
return true
|
||||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
|
||||
cs := f.pki.getCertState()
|
||||
crt := cs.GetDefaultCertificate()
|
||||
if crt == nil {
|
||||
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 0, "style": "ix_psk0"},
|
||||
"certVersion", cs.initiatingVersion,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ci, err := NewConnectionState(cs, crt, false, noise.HandshakeIX)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to create connection state",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.Error("Failed to call noise.ReadMessage",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.Error("Failed unmarshal handshake message",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.Info("Handshake did not contain a certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, fperr := rc.Fingerprint()
|
||||
if fperr != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
attrs := []slog.Attr{
|
||||
slog.Any("error", err),
|
||||
slog.Any("from", via),
|
||||
slog.Any("handshake", m{"stage": 1, "style": "ix_psk0"}),
|
||||
slog.Any("certVpnNetworks", rc.Networks()),
|
||||
slog.String("certFingerprint", fp),
|
||||
}
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
attrs = append(attrs, slog.Any("cert", rc))
|
||||
}
|
||||
|
||||
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||
// callers grow conditionally, which has no pair-form equivalent.
|
||||
//nolint:sloglint
|
||||
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||
f.l.Info("public key mismatch between certificate and handshake",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if remoteCert.Certificate.Version() != ci.myCert.Version() {
|
||||
// We started off using the wrong certificate version, lets see if we can match the version that was sent to us
|
||||
myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version())
|
||||
if myCertOtherVersion == nil {
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("Might be unable to handshake with host due to missing certificate version",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Record the certificate we are actually using
|
||||
ci.myCert = myCertOtherVersion
|
||||
}
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.Info("No networks in certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"cert", remoteCert,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.Error("Refusing to handshake with myself",
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
}
|
||||
|
||||
if !via.IsRelayed {
|
||||
// We only want to apply the remote allow list for direct tunnels here
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
myIndex, err := generateIndex(f.l)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to generate index",
|
||||
"error", err,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: ci,
|
||||
localIndexId: myIndex,
|
||||
remoteIndexId: hs.Details.InitiatorIndex,
|
||||
vpnAddrs: vpnAddrs,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: hs.Details.Time,
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
||||
msgRxL := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
hs.Details.ResponderIndex = myIndex
|
||||
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
|
||||
if hs.Details.Cert == nil {
|
||||
msgRxL.Error("Unable to handshake with host because no certificate handshake bytes is available",
|
||||
"myCertVersion", ci.myCert.Version(),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hs.Details.CertVersion = uint32(ci.myCert.Version())
|
||||
// Update the time in case their clock is way off from ours
|
||||
hs.Details.Time = uint64(time.Now().UnixNano())
|
||||
|
||||
hsBytes, err := hs.Marshal()
|
||||
if err != nil {
|
||||
f.l.Error("Failed to marshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
|
||||
msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to call noise.WriteMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.Error("Noise did not arrive at a key",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo.HandshakePacket[0] = make([]byte, len(packet[header.Len:]))
|
||||
copy(hostinfo.HandshakePacket[0], packet[header.Len:])
|
||||
|
||||
// Regardless of whether you are the sender or receiver, you should arrive here
|
||||
// and complete standing up the connection.
|
||||
hostinfo.HandshakePacket[2] = make([]byte, len(msg))
|
||||
copy(hostinfo.HandshakePacket[2], msg)
|
||||
|
||||
// We are sending handshake packet 2, so we don't expect to receive
|
||||
// handshake packet 2 from the initiator.
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
if !via.IsRelayed {
|
||||
hostinfo.SetRemote(via.UdpAddr)
|
||||
}
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrAlreadySeen:
|
||||
// Update remote if preferred
|
||||
if existing.SetRemoteIfPreferred(f.hostMap, via) {
|
||||
// Send a test packet to ensure the other side has also switched to
|
||||
// the preferred remote
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}
|
||||
|
||||
msg = existing.HandshakePacket[2]
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if !via.IsRelayed {
|
||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to send handshake message",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
"error", err,
|
||||
)
|
||||
} else {
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
if via.relay == nil {
|
||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", existing.vpnAddrs,
|
||||
"relay", via.relayHI.vpnAddrs[0],
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cached", true,
|
||||
)
|
||||
return
|
||||
}
|
||||
case ErrExistingHostInfo:
|
||||
// This means there was an existing tunnel and this handshake was older than the one we are currently based on
|
||||
f.l.Info("Handshake too old",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"oldHandshakeTime", existing.lastHandshakeTime,
|
||||
"newHandshakeTime", hostinfo.lastHandshakeTime,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
return
|
||||
case ErrLocalIndexCollision:
|
||||
// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
|
||||
f.l.Error("Failed to add HostInfo due to localIndex collision",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"localIndex", hostinfo.localIndexId,
|
||||
"collision", existing.vpnAddrs,
|
||||
)
|
||||
return
|
||||
default:
|
||||
// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
|
||||
// And we forget to update it here
|
||||
f.l.Error("Failed to add HostInfo to HostMap",
|
||||
"error", err,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Do the send
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
if !via.IsRelayed {
|
||||
err = f.outside.WriteTo(msg, via.UdpAddr)
|
||||
log := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("Failed to send handshake", "error", err)
|
||||
} else {
|
||||
log.Info("Handshake message sent")
|
||||
}
|
||||
} else {
|
||||
if via.relay == nil {
|
||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
// I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure
|
||||
// it's correctly marked as working.
|
||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.Info("Handshake message sent",
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"relay", via.relayHI.vpnAddrs[0],
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
}
|
||||
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
|
||||
// Don't wait for UpdateWorker
|
||||
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||
f.lightHouse.TriggerUpdate()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
|
||||
if hh == nil {
|
||||
// Nothing here to tear down, got a bogus stage 2 packet
|
||||
return true
|
||||
}
|
||||
|
||||
hh.Lock()
|
||||
defer hh.Unlock()
|
||||
|
||||
hostinfo := hh.hostinfo
|
||||
if !via.IsRelayed {
|
||||
// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
ci := hostinfo.ConnectionState
|
||||
msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.Error("Failed to call noise.ReadMessage",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"header", h,
|
||||
)
|
||||
|
||||
// We don't want to tear down the connection on a bad ReadMessage because it could be an attacker trying
|
||||
// to DOS us. Every other error condition after should to allow a possible good handshake to complete in the
|
||||
// near future
|
||||
return false
|
||||
} else if dKey == nil || eKey == nil {
|
||||
f.l.Error("Noise did not arrive at a key",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// This should be impossible in IX but just in case, if we get here then there is no chance to recover
|
||||
// the handshake state machine. Tear it down
|
||||
return true
|
||||
}
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = hs.Unmarshal(msg)
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.Error("Failed unmarshal handshake message",
|
||||
"error", err,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
|
||||
return true
|
||||
}
|
||||
|
||||
rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
|
||||
if err != nil {
|
||||
f.l.Info("Handshake did not contain a certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc)
|
||||
if err != nil {
|
||||
fp, err := rc.Fingerprint()
|
||||
if err != nil {
|
||||
fp = "<error generating certificate fingerprint>"
|
||||
}
|
||||
|
||||
attrs := []slog.Attr{
|
||||
slog.Any("error", err),
|
||||
slog.Any("from", via),
|
||||
slog.Any("vpnAddrs", hostinfo.vpnAddrs),
|
||||
slog.Any("handshake", m{"stage": 2, "style": "ix_psk0"}),
|
||||
slog.String("certFingerprint", fp),
|
||||
slog.Any("certVpnNetworks", rc.Networks()),
|
||||
}
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
attrs = append(attrs, slog.Any("cert", rc))
|
||||
}
|
||||
|
||||
// LogAttrs is intentional: attrs is a pre-built []slog.Attr slice that
|
||||
// callers grow conditionally, which has no pair-form equivalent.
|
||||
//nolint:sloglint
|
||||
f.l.LogAttrs(context.Background(), slog.LevelInfo, "Invalid certificate from host", attrs...)
|
||||
return true
|
||||
}
|
||||
if !bytes.Equal(remoteCert.Certificate.PublicKey(), ci.H.PeerStatic()) {
|
||||
f.l.Info("public key mismatch between certificate and handshake",
|
||||
"from", via,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"cert", remoteCert,
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(remoteCert.Certificate.Networks()) == 0 {
|
||||
f.l.Info("No networks in certificate",
|
||||
"error", err,
|
||||
"from", via,
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"cert", remoteCert,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
certName := remoteCert.Certificate.Name()
|
||||
certVersion := remoteCert.Certificate.Version()
|
||||
fingerprint := remoteCert.Fingerprint
|
||||
issuer := remoteCert.Certificate.Issuer()
|
||||
|
||||
hostinfo.remoteIndexId = hs.Details.ResponderIndex
|
||||
hostinfo.lastHandshakeTime = hs.Details.Time
|
||||
|
||||
// Store their cert and our symmetric keys
|
||||
ci.peerCert = remoteCert
|
||||
ci.dKey = NewNebulaCipherState(dKey)
|
||||
ci.eKey = NewNebulaCipherState(eKey)
|
||||
|
||||
// Make sure the current udpAddr being used is set for responding
|
||||
if !via.IsRelayed {
|
||||
hostinfo.SetRemote(via.UdpAddr)
|
||||
} else {
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
correctHostResponded := false
|
||||
anyVpnAddrsInCommon := false
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
for i, network := range vpnNetworks {
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
||||
// todo is it more correct to see if any of hostinfo.vpnAddrs are in the cert? it should have len==1, but one day it might not?
|
||||
correctHostResponded = true
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the right host responded
|
||||
if !correctHostResponded {
|
||||
f.l.Info("Incorrect host responded to handshake",
|
||||
"intendedVpnAddrs", hostinfo.vpnAddrs,
|
||||
"haveVpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
)
|
||||
|
||||
// Release our old handshake from pending, it should not continue
|
||||
f.handshakeManager.DeleteHostInfo(hostinfo)
|
||||
|
||||
// Create a new hostinfo/handshake for the intended vpn ip
|
||||
//TODO is hostinfo.vpnAddrs[0] always the address to use?
|
||||
f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
// Block the current used address
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
newHH.hostinfo.remotes.BlockRemote(via)
|
||||
|
||||
f.l.Info("Blocked addresses for handshakes",
|
||||
"blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes(),
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges()),
|
||||
)
|
||||
|
||||
// Swap the packet store to benefit the original intended recipient
|
||||
newHH.packetStore = hh.packetStore
|
||||
hh.packetStore = []*cachedPacket{}
|
||||
|
||||
// Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
f.sendCloseTunnel(hostinfo)
|
||||
})
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Mark packet 2 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 2)
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
msgRxL := f.l.With(
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", certName,
|
||||
"certVersion", certVersion,
|
||||
"fingerprint", fingerprint,
|
||||
"issuer", issuer,
|
||||
"initiatorIndex", hs.Details.InitiatorIndex,
|
||||
"responderIndex", hs.Details.ResponderIndex,
|
||||
"remoteIndex", h.RemoteIndex,
|
||||
"handshake", m{"stage": 2, "style": "ix_psk0"},
|
||||
"durationNs", duration,
|
||||
"sentCachedPackets", len(hh.packetStore),
|
||||
)
|
||||
if anyVpnAddrsInCommon {
|
||||
msgRxL.Info("Handshake message received")
|
||||
} else {
|
||||
//todo warn if not lighthouse or relay?
|
||||
msgRxL.Info("Handshake message received, but no vpnNetworks in common.")
|
||||
}
|
||||
|
||||
// Build up the radix for the firewall if we have subnets in the cert
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
|
||||
f.handshakeManager.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Sending stored packets",
|
||||
"count", len(hh.packetStore),
|
||||
)
|
||||
}
|
||||
|
||||
if len(hh.packetStore) > 0 {
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for _, cp := range hh.packetStore {
|
||||
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
|
||||
}
|
||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||
}
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
// Don't wait for UpdateWorker
|
||||
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||
f.lightHouse.TriggerUpdate()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/udp"
|
||||
)
|
||||
@@ -23,6 +24,18 @@ const (
|
||||
DefaultHandshakeRetries = 10
|
||||
DefaultHandshakeTriggerBuffer = 64
|
||||
DefaultUseRelays = true
|
||||
|
||||
// maxCachedPackets is how many unsent packets we'll buffer per pending
|
||||
// handshake before dropping further ones.
|
||||
maxCachedPackets = 100
|
||||
|
||||
// HandshakePacket map keys mirror the IX protocol stage convention:
|
||||
// stage 0 = the initiator's first message (and what the responder
|
||||
// receives, stripped of header)
|
||||
// stage 2 = the responder's reply
|
||||
// Other handshake patterns will need new keys when added.
|
||||
handshakePacketStage0 uint8 = 0
|
||||
handshakePacketStage2 uint8 = 2
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -76,10 +89,11 @@ type HandshakeHostInfo struct {
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
machine *handshake.Machine // The handshake state machine, set during stage 0 (initiator) or beginHandshake (responder multi-message)
|
||||
}
|
||||
|
||||
func (hh *HandshakeHostInfo) cachePacket(l *slog.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
if len(hh.packetStore) < 100 {
|
||||
if len(hh.packetStore) < maxCachedPackets {
|
||||
tempPacket := make([]byte, len(packet))
|
||||
copy(tempPacket, packet)
|
||||
|
||||
@@ -137,6 +151,18 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
|
||||
// Gate on known handshake subtypes. Unknown subtypes (or future ones we
|
||||
// don't yet support) are dropped here rather than silently routed through
|
||||
// the IX path. Add a case when introducing a new pattern.
|
||||
switch h.Subtype {
|
||||
case header.HandshakeIXPSK0:
|
||||
// supported
|
||||
default:
|
||||
hm.l.Debug("dropping handshake with unsupported subtype",
|
||||
"from", via, "subtype", h.Subtype)
|
||||
return
|
||||
}
|
||||
|
||||
// First remote allow list check before we know the vpnIp
|
||||
if !via.IsRelayed {
|
||||
if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
|
||||
@@ -145,19 +171,27 @@ func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *head
|
||||
}
|
||||
}
|
||||
|
||||
switch h.Subtype {
|
||||
case header.HandshakeIXPSK0:
|
||||
switch h.MessageCounter {
|
||||
case 1:
|
||||
ixHandshakeStage1(hm.f, via, packet, h)
|
||||
// First message of a new handshake. The wire format requires RemoteIndex
|
||||
// to be zero here (the initiator has no responder index to fill in yet),
|
||||
// and generateIndex never allocates 0, so any non-zero RemoteIndex on a
|
||||
// stage-1 packet is malformed or someone probing for an index collision.
|
||||
// Drop without paying the cost of running noise on a pending Machine.
|
||||
if h.MessageCounter == 1 {
|
||||
if h.RemoteIndex != 0 {
|
||||
hm.l.Debug("dropping stage-1 handshake with non-zero RemoteIndex",
|
||||
"from", via, "remoteIndex", h.RemoteIndex)
|
||||
return
|
||||
}
|
||||
hm.beginHandshake(via, packet, h)
|
||||
return
|
||||
}
|
||||
|
||||
case 2:
|
||||
newHostinfo := hm.queryIndex(h.RemoteIndex)
|
||||
tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
|
||||
if tearDown && newHostinfo != nil {
|
||||
hm.DeleteHostInfo(newHostinfo.hostinfo)
|
||||
}
|
||||
}
|
||||
// Continuation message must match a pending handshake by index.
|
||||
// Anything else is an orphaned packet (e.g., late retransmit after
|
||||
// timeout) and is dropped.
|
||||
if hh := hm.queryIndex(h.RemoteIndex); hh != nil {
|
||||
hm.continueHandshake(via, hh, packet)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,13 +217,22 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo := hh.hostinfo
|
||||
// If we are out of time, clean up
|
||||
if hh.counter >= hm.config.retries {
|
||||
hh.hostinfo.logger(hm.l).Info("Handshake timed out",
|
||||
fields := []any{
|
||||
"udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()),
|
||||
"initiatorIndex", hh.hostinfo.localIndexId,
|
||||
"remoteIndex", hh.hostinfo.remoteIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"durationNs", time.Since(hh.startTime).Nanoseconds(),
|
||||
)
|
||||
}
|
||||
// hh.machine can be nil here if buildStage0Packet never succeeded
|
||||
// (e.g., no certificate available). In that case there's no useful
|
||||
// handshake metadata to log.
|
||||
if hh.machine != nil {
|
||||
fields = append(fields, "handshake", m{
|
||||
"stage": uint64(hh.machine.MessageIndex()),
|
||||
"style": header.SubTypeName(header.Handshake, hh.machine.Subtype()),
|
||||
})
|
||||
}
|
||||
hh.hostinfo.logger(hm.l).Info("Handshake timed out", fields...)
|
||||
hm.metricTimedOut.Inc(1)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
return
|
||||
@@ -200,12 +243,25 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
|
||||
// Check if we have a handshake packet to transmit yet
|
||||
if !hh.ready {
|
||||
if !ixHandshakeStage0(hm.f, hh) {
|
||||
if !hm.buildStage0Packet(hh) {
|
||||
hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this hardcodes "always retransmit stage 0", which is correct for
|
||||
// IX (the initiator only ever sends one packet, msg1) but wrong the
|
||||
// moment a 3+ message pattern lands. The retry loop should resend the
|
||||
// most recent outgoing message, not always stage 0. That implies
|
||||
// HandshakeHostInfo tracking a single "currentOutbound" packet (bytes +
|
||||
// header metadata) that gets replaced as the handshake progresses,
|
||||
// instead of indexing into HandshakePacket.
|
||||
stage0 := hostinfo.HandshakePacket[handshakePacketStage0]
|
||||
hsFields := m{
|
||||
"stage": uint64(hh.machine.MessageIndex()),
|
||||
"style": header.SubTypeName(header.Handshake, hh.machine.Subtype()),
|
||||
}
|
||||
|
||||
// Get a remotes object if we don't already have one.
|
||||
// This is mainly to protect us as this should never be the case
|
||||
// NB ^ This comment doesn't jive. It's how the thing gets initialized.
|
||||
@@ -239,13 +295,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
|
||||
var sentTo []netip.AddrPort
|
||||
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr netip.AddrPort, _ bool) {
|
||||
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
|
||||
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
|
||||
hm.messageMetrics.Tx(header.Handshake, hh.machine.Subtype(), 1)
|
||||
err := hm.outside.WriteTo(stage0, addr)
|
||||
if err != nil {
|
||||
hostinfo.logger(hm.l).Error("Failed to send handshake message",
|
||||
"udpAddr", addr,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"handshake", hsFields,
|
||||
"error", err,
|
||||
)
|
||||
|
||||
@@ -260,13 +316,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
hostinfo.logger(hm.l).Info("Handshake message sent",
|
||||
"udpAddrs", sentTo,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"handshake", hsFields,
|
||||
)
|
||||
} else if hm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(hm.l).Debug("Handshake message sent",
|
||||
"udpAddrs", sentTo,
|
||||
"initiatorIndex", hostinfo.localIndexId,
|
||||
"handshake", m{"stage": 1, "style": "ix_psk0"},
|
||||
"handshake", hsFields,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -348,7 +404,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
|
||||
switch existingRelay.State {
|
||||
case Established:
|
||||
hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String())
|
||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
|
||||
hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false)
|
||||
case Disestablished:
|
||||
// Mark this relay as 'requested'
|
||||
relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested)
|
||||
@@ -587,7 +643,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
|
||||
// allocateIndex generates a unique localIndexId for this HostInfo
|
||||
// and adds it to the pendingHostMap. Will error if we are unable to generate
|
||||
// a unique localIndexId
|
||||
func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
||||
func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) (uint32, error) {
|
||||
hm.mainHostMap.RLock()
|
||||
defer hm.mainHostMap.RUnlock()
|
||||
hm.Lock()
|
||||
@@ -596,7 +652,7 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
||||
for range 32 {
|
||||
index, err := generateIndex(hm.l)
|
||||
if err != nil {
|
||||
return err
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, inPending := hm.indexes[index]
|
||||
@@ -605,11 +661,11 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
|
||||
if !inMain && !inPending {
|
||||
hh.hostinfo.localIndexId = index
|
||||
hm.indexes[index] = hh
|
||||
return nil
|
||||
return index, nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("failed to generate unique localIndexId")
|
||||
return 0, errors.New("failed to generate unique localIndexId")
|
||||
}
|
||||
|
||||
func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
|
||||
@@ -728,3 +784,524 @@ func generateIndex(l *slog.Logger) (uint32, error) {
|
||||
func hsTimeout(tries int64, interval time.Duration) time.Duration {
|
||||
return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
|
||||
}
|
||||
|
||||
// buildStage0Packet creates the initial handshake packet for the initiator.
|
||||
func (hm *HandshakeManager) buildStage0Packet(hh *HandshakeHostInfo) bool {
|
||||
cs := hm.f.pki.getCertState()
|
||||
v := cs.DefaultVersion()
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
} else if v < cert.Version2 {
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cred := cs.GetCredential(v)
|
||||
if cred == nil {
|
||||
hm.f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs, "certVersion", v)
|
||||
return false
|
||||
}
|
||||
|
||||
machine, err := handshake.NewMachine(
|
||||
v, cs.GetCredential,
|
||||
hm.certVerifier(), func() (uint32, error) { return hm.allocateIndex(hh) },
|
||||
true, header.HandshakeIXPSK0,
|
||||
)
|
||||
if err != nil {
|
||||
hm.f.l.Error("Failed to create handshake machine",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
msg, err := machine.Initiate(nil)
|
||||
if err != nil {
|
||||
hm.f.l.Error("Failed to initiate handshake",
|
||||
"vpnAddrs", hh.hostinfo.vpnAddrs, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// hostinfo.ConnectionState stays nil until the handshake completes in
|
||||
// continueHandshake. Pre-completion control surfaces guard with nil
|
||||
// checks; the data plane never observes a pending hostinfo.
|
||||
hh.hostinfo.HandshakePacket[handshakePacketStage0] = msg
|
||||
hh.machine = machine
|
||||
hh.ready = true
|
||||
return true
|
||||
}
|
||||
|
||||
// beginHandshake handles an incoming handshake packet that doesn't match any
|
||||
// existing pending handshake. It creates a new responder Machine and processes
|
||||
// the first message.
|
||||
func (hm *HandshakeManager) beginHandshake(via ViaSender, packet []byte, h *header.H) {
|
||||
f := hm.f
|
||||
cs := f.pki.getCertState()
|
||||
|
||||
v := cs.DefaultVersion()
|
||||
if cs.GetCredential(v) == nil {
|
||||
f.l.Error("Unable to handshake with host because no certificate is available",
|
||||
"from", via, "certVersion", v)
|
||||
return
|
||||
}
|
||||
|
||||
machine, err := handshake.NewMachine(
|
||||
v, cs.GetCredential,
|
||||
hm.certVerifier(), func() (uint32, error) { return generateIndex(f.l) },
|
||||
false, header.HandshakeIXPSK0,
|
||||
)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to create handshake machine", "from", via, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
response, result, err := machine.ProcessPacket(nil, packet)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to process handshake packet", "from", via, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
// Multi-message pattern: the responder Machine would need to be
|
||||
// registered in hm.indexes so a future inbound packet finds it via
|
||||
// continueHandshake. The current manager doesn't do that yet, so
|
||||
// fail loudly rather than silently dropping the in-flight handshake.
|
||||
// TODO: support multi-message responder flows (XX, pqIX, etc.).
|
||||
// See also the IX-shaped cipher key assignment in handshake.Machine.
|
||||
f.l.Error("multi-message handshake responder is not supported",
|
||||
"from", via, "error", handshake.ErrMultiMessageUnsupported)
|
||||
return
|
||||
}
|
||||
|
||||
remoteCert := result.RemoteCert
|
||||
if remoteCert == nil {
|
||||
f.l.Error("Handshake did not produce a peer certificate", "from", via)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate peer identity
|
||||
vpnAddrs, anyVpnAddrsInCommon, ok := hm.validatePeerCert(via, remoteCert)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo := &HostInfo{
|
||||
ConnectionState: newConnectionStateFromResult(result),
|
||||
localIndexId: result.LocalIndex,
|
||||
remoteIndexId: result.RemoteIndex,
|
||||
vpnAddrs: vpnAddrs,
|
||||
HandshakePacket: make(map[uint8][]byte, 0),
|
||||
lastHandshakeTime: result.HandshakeTime,
|
||||
relayState: RelayState{
|
||||
relays: nil,
|
||||
relayForByAddr: map[netip.Addr]*Relay{},
|
||||
relayForByIdx: map[uint32]*Relay{},
|
||||
},
|
||||
}
|
||||
|
||||
msg := "Handshake message received"
|
||||
if !anyVpnAddrsInCommon {
|
||||
msg = "Handshake message received, but no vpnNetworks in common."
|
||||
}
|
||||
f.l.Info(msg,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", remoteCert.Certificate.Name(),
|
||||
"certVersion", remoteCert.Certificate.Version(),
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
"issuer", remoteCert.Certificate.Issuer(),
|
||||
"initiatorIndex", result.RemoteIndex,
|
||||
"responderIndex", result.LocalIndex,
|
||||
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
|
||||
)
|
||||
|
||||
// packet aliases the listener's incoming buffer, so this copy must stay.
|
||||
hostinfo.HandshakePacket[handshakePacketStage0] = make([]byte, len(packet[header.Len:]))
|
||||
copy(hostinfo.HandshakePacket[handshakePacketStage0], packet[header.Len:])
|
||||
|
||||
// response was freshly allocated by ProcessPacket; safe to retain directly.
|
||||
if response != nil {
|
||||
hostinfo.HandshakePacket[handshakePacketStage2] = response
|
||||
}
|
||||
|
||||
hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
|
||||
if !via.IsRelayed {
|
||||
hostinfo.SetRemote(via.UdpAddr)
|
||||
}
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
existing, err := hm.CheckAndComplete(hostinfo, handshakePacketStage0, f)
|
||||
if err != nil {
|
||||
hm.handleCheckAndCompleteError(err, existing, hostinfo, via)
|
||||
return
|
||||
}
|
||||
|
||||
hm.sendHandshakeResponse(via, response, hostinfo, false)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
|
||||
// Don't wait for UpdateWorker
|
||||
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||
f.lightHouse.TriggerUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
// continueHandshake feeds an incoming packet to an existing pending handshake Machine.
|
||||
func (hm *HandshakeManager) continueHandshake(via ViaSender, hh *HandshakeHostInfo, packet []byte) {
|
||||
f := hm.f
|
||||
|
||||
hh.Lock()
|
||||
defer hh.Unlock()
|
||||
|
||||
// Re-verify hh is still tracked. Between queryIndex returning and us taking
|
||||
// hh.Lock, handleOutbound may have timed out and deleted it. Once we hold
|
||||
// hh.Lock no other deleter can race our index: handleOutbound also takes
|
||||
// hh.Lock first, and handleRecvError targets a main-hostmap entry with a
|
||||
// different localIndexId.
|
||||
hm.RLock()
|
||||
cur, ok := hm.indexes[hh.hostinfo.localIndexId]
|
||||
hm.RUnlock()
|
||||
if !ok || cur != hh {
|
||||
return
|
||||
}
|
||||
|
||||
hostinfo := hh.hostinfo
|
||||
if !via.IsRelayed {
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
machine := hh.machine
|
||||
if machine == nil {
|
||||
f.l.Error("No handshake machine available for continuation",
|
||||
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
return
|
||||
}
|
||||
|
||||
response, result, err := machine.ProcessPacket(nil, packet)
|
||||
if err != nil {
|
||||
// Recoverable errors are routine noise, log at Debug. Fatal errors get a Warn.
|
||||
if machine.Failed() {
|
||||
f.l.Warn("Failed to process handshake packet, abandoning",
|
||||
"vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
} else {
|
||||
f.l.Debug("Failed to process handshake packet",
|
||||
"vpnAddrs", hostinfo.vpnAddrs, "from", via, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
hm.sendHandshakeResponse(via, response, hostinfo, false)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Handshake complete; build the ConnectionState now that we have keys and a verified peer cert.
|
||||
hostinfo.ConnectionState = newConnectionStateFromResult(result)
|
||||
|
||||
remoteCert := result.RemoteCert
|
||||
if remoteCert == nil {
|
||||
f.l.Error("Handshake completed without peer certificate",
|
||||
"vpnAddrs", hostinfo.vpnAddrs, "from", via)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
return
|
||||
}
|
||||
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
hostinfo.remoteIndexId = result.RemoteIndex
|
||||
hostinfo.lastHandshakeTime = result.HandshakeTime
|
||||
|
||||
if !via.IsRelayed {
|
||||
hostinfo.SetRemote(via.UdpAddr)
|
||||
} else {
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
}
|
||||
|
||||
// Verify correct host responded (initiator check)
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
correctHostResponded := false
|
||||
anyVpnAddrsInCommon := false
|
||||
for i, network := range vpnNetworks {
|
||||
// inside.go drops self-routed packets at the firewall stage, but we'd
|
||||
// rather not let a self-handshake complete in the first place: it
|
||||
// wastes a hostmap slot, suppresses no log, and obscures routing
|
||||
// misconfig. Explicit refusal here mirrors the responder-side check
|
||||
// in validatePeerCert.
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.Error("Refusing to handshake with myself",
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", remoteCert.Certificate.Name(),
|
||||
"certVersion", remoteCert.Certificate.Version(),
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
"issuer", remoteCert.Certificate.Issuer(),
|
||||
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
|
||||
)
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
return
|
||||
}
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if hostinfo.vpnAddrs[0] == network.Addr() {
|
||||
correctHostResponded = true
|
||||
}
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
}
|
||||
|
||||
if !correctHostResponded {
|
||||
f.l.Info("Incorrect host responded to handshake",
|
||||
"intendedVpnAddrs", hostinfo.vpnAddrs,
|
||||
"haveVpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", remoteCert.Certificate.Name(),
|
||||
"certVersion", remoteCert.Certificate.Version(),
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
"issuer", remoteCert.Certificate.Issuer(),
|
||||
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
|
||||
)
|
||||
|
||||
hm.DeleteHostInfo(hostinfo)
|
||||
hm.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
|
||||
newHH.hostinfo.remotes = hostinfo.remotes
|
||||
newHH.hostinfo.remotes.BlockRemote(via)
|
||||
newHH.packetStore = hh.packetStore
|
||||
hh.packetStore = []*cachedPacket{}
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
f.sendCloseTunnel(hostinfo)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(hh.startTime).Nanoseconds()
|
||||
msg := "Handshake message received"
|
||||
if !anyVpnAddrsInCommon {
|
||||
msg = "Handshake message received, but no vpnNetworks in common."
|
||||
}
|
||||
f.l.Info(msg,
|
||||
"vpnAddrs", vpnAddrs,
|
||||
"from", via,
|
||||
"certName", remoteCert.Certificate.Name(),
|
||||
"certVersion", remoteCert.Certificate.Version(),
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
"issuer", remoteCert.Certificate.Issuer(),
|
||||
"initiatorIndex", result.LocalIndex,
|
||||
"responderIndex", result.RemoteIndex,
|
||||
"handshake", m{"stage": uint64(machine.MessageIndex()), "style": header.SubTypeName(header.Handshake, machine.Subtype())},
|
||||
"durationNs", duration,
|
||||
"sentCachedPackets", len(hh.packetStore),
|
||||
)
|
||||
|
||||
hostinfo.vpnAddrs = vpnAddrs
|
||||
hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
|
||||
|
||||
hm.Complete(hostinfo, f)
|
||||
f.connectionManager.AddTrafficWatch(hostinfo)
|
||||
|
||||
if len(hh.packetStore) > 0 {
|
||||
if f.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||
hostinfo.logger(f.l).Debug("Sending stored packets", "count", len(hh.packetStore))
|
||||
}
|
||||
nb := make([]byte, 12, 12)
|
||||
out := make([]byte, mtu)
|
||||
for _, cp := range hh.packetStore {
|
||||
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
|
||||
}
|
||||
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
|
||||
}
|
||||
|
||||
hostinfo.remotes.RefreshFromHandshake(vpnAddrs)
|
||||
f.metricHandshakes.Update(duration)
|
||||
|
||||
// Don't wait for UpdateWorker
|
||||
if f.lightHouse.IsAnyLighthouseAddr(vpnAddrs) {
|
||||
f.lightHouse.TriggerUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
// validatePeerCert checks the peer certificate for self-connection and remote allow list.
|
||||
// Returns the VPN addrs, whether any of them fall within one of our own VPN
|
||||
// networks, and true if valid; false if rejected.
|
||||
func (hm *HandshakeManager) validatePeerCert(via ViaSender, remoteCert *cert.CachedCertificate) ([]netip.Addr, bool, bool) {
|
||||
f := hm.f
|
||||
vpnNetworks := remoteCert.Certificate.Networks()
|
||||
|
||||
// The cert package rejects host certs with no networks at parse time, so
|
||||
// reaching this state would mean an invariant was bypassed elsewhere.
|
||||
// Refuse explicitly so downstream code (which indexes vpnAddrs[0]) can't
|
||||
// panic if that invariant ever changes.
|
||||
if len(vpnNetworks) == 0 {
|
||||
f.l.Info("No networks in certificate",
|
||||
"from", via, "cert", remoteCert)
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
vpnAddrs := make([]netip.Addr, len(vpnNetworks))
|
||||
anyVpnAddrsInCommon := false
|
||||
|
||||
for i, network := range vpnNetworks {
|
||||
if f.myVpnAddrsTable.Contains(network.Addr()) {
|
||||
f.l.Error("Refusing to handshake with myself",
|
||||
"vpnNetworks", vpnNetworks,
|
||||
"from", via,
|
||||
"certName", remoteCert.Certificate.Name(),
|
||||
"certVersion", remoteCert.Certificate.Version(),
|
||||
"fingerprint", remoteCert.Fingerprint,
|
||||
"issuer", remoteCert.Certificate.Issuer(),
|
||||
)
|
||||
return nil, false, false
|
||||
}
|
||||
vpnAddrs[i] = network.Addr()
|
||||
if f.myVpnNetworksTable.Contains(network.Addr()) {
|
||||
anyVpnAddrsInCommon = true
|
||||
}
|
||||
}
|
||||
|
||||
if !via.IsRelayed {
|
||||
if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
|
||||
f.l.Debug("lighthouse.remote_allow_list denied incoming handshake",
|
||||
"vpnAddrs", vpnAddrs, "from", via)
|
||||
return nil, false, false
|
||||
}
|
||||
}
|
||||
|
||||
return vpnAddrs, anyVpnAddrsInCommon, true
|
||||
}
|
||||
|
||||
// sendHandshakeResponse sends a handshake response via the appropriate transport.
|
||||
// cached is true when msg is a stored response being retransmitted because
|
||||
// the peer's stage-1 retransmit landed (the ErrAlreadySeen path); false on a
|
||||
// fresh response.
|
||||
func (hm *HandshakeManager) sendHandshakeResponse(via ViaSender, msg []byte, hostinfo *HostInfo, cached bool) {
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
f := hm.f
|
||||
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
|
||||
|
||||
// Common log fields. peerCert may be nil during intermediate
|
||||
// multi-message flows (handshake hasn't completed yet); skip the cert
|
||||
// block if so.
|
||||
logFields := []any{
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"handshake", m{"stage": uint64(2), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)},
|
||||
"cached", cached,
|
||||
"initiatorIndex", hostinfo.remoteIndexId,
|
||||
"responderIndex", hostinfo.localIndexId,
|
||||
}
|
||||
if peerCert := hostinfo.ConnectionState.peerCert; peerCert != nil {
|
||||
logFields = append(logFields,
|
||||
"certName", peerCert.Certificate.Name(),
|
||||
"certVersion", peerCert.Certificate.Version(),
|
||||
"fingerprint", peerCert.Fingerprint,
|
||||
"issuer", peerCert.Certificate.Issuer(),
|
||||
)
|
||||
}
|
||||
|
||||
if !via.IsRelayed {
|
||||
fields := append(logFields, "from", via)
|
||||
err := f.outside.WriteTo(msg, via.UdpAddr)
|
||||
if err != nil {
|
||||
f.l.Error("Failed to send handshake message", append(fields, "error", err)...)
|
||||
} else {
|
||||
f.l.Info("Handshake message sent", fields...)
|
||||
}
|
||||
} else {
|
||||
if via.relay == nil {
|
||||
f.l.Error("Handshake send failed: both addr and via.relay are nil.")
|
||||
return
|
||||
}
|
||||
hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
|
||||
// We received a valid handshake on this relay, so make sure the relay
|
||||
// state reflects that, in case it had been marked Disestablished.
|
||||
via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established)
|
||||
f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false)
|
||||
f.l.Info("Handshake message sent", append(logFields, "relay", via.relayHI.vpnAddrs[0])...)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCheckAndCompleteError handles errors from CheckAndComplete.
|
||||
// This only fires from the responder-side beginHandshake path, after the
|
||||
// peer cert has been validated and ConnectionState populated, so peerCert
|
||||
// is always non-nil for the cases that log it.
|
||||
func (hm *HandshakeManager) handleCheckAndCompleteError(err error, existing, hostinfo *HostInfo, via ViaSender) {
|
||||
f := hm.f
|
||||
peerCert := hostinfo.ConnectionState.peerCert
|
||||
hsFields := m{"stage": uint64(1), "style": header.SubTypeName(header.Handshake, header.HandshakeIXPSK0)}
|
||||
|
||||
switch err {
|
||||
case ErrAlreadySeen:
|
||||
if existing.SetRemoteIfPreferred(f.hostMap, via) {
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
}
|
||||
// Resend the original response. The peer is committed to that response's
|
||||
// ephemeral keys; a freshly-built one would have different keys and break
|
||||
// the tunnel even though both sides "completed" the handshake.
|
||||
if msg := existing.HandshakePacket[handshakePacketStage2]; msg != nil {
|
||||
hm.sendHandshakeResponse(via, msg, existing, true)
|
||||
}
|
||||
|
||||
case ErrExistingHostInfo:
|
||||
f.l.Info("Handshake too old",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", peerCert.Certificate.Name(),
|
||||
"certVersion", peerCert.Certificate.Version(),
|
||||
"fingerprint", peerCert.Fingerprint,
|
||||
"issuer", peerCert.Certificate.Issuer(),
|
||||
"oldHandshakeTime", existing.lastHandshakeTime,
|
||||
"newHandshakeTime", hostinfo.lastHandshakeTime,
|
||||
"initiatorIndex", hostinfo.remoteIndexId,
|
||||
"responderIndex", hostinfo.localIndexId,
|
||||
"handshake", hsFields,
|
||||
)
|
||||
f.SendMessageToVpnAddr(header.Test, header.TestRequest, hostinfo.vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
||||
|
||||
case ErrLocalIndexCollision:
|
||||
f.l.Error("Failed to add HostInfo due to localIndex collision",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"certName", peerCert.Certificate.Name(),
|
||||
"certVersion", peerCert.Certificate.Version(),
|
||||
"fingerprint", peerCert.Fingerprint,
|
||||
"issuer", peerCert.Certificate.Issuer(),
|
||||
"localIndex", hostinfo.localIndexId,
|
||||
"initiatorIndex", hostinfo.remoteIndexId,
|
||||
"responderIndex", hostinfo.localIndexId,
|
||||
"handshake", hsFields,
|
||||
)
|
||||
|
||||
default:
|
||||
f.l.Error("Failed to add HostInfo to HostMap",
|
||||
"vpnAddrs", hostinfo.vpnAddrs,
|
||||
"from", via,
|
||||
"error", err,
|
||||
"certName", peerCert.Certificate.Name(),
|
||||
"certVersion", peerCert.Certificate.Version(),
|
||||
"fingerprint", peerCert.Fingerprint,
|
||||
"issuer", peerCert.Certificate.Issuer(),
|
||||
"initiatorIndex", hostinfo.remoteIndexId,
|
||||
"responderIndex", hostinfo.localIndexId,
|
||||
"handshake", hsFields,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// certVerifier returns a CertVerifier that validates certs against the current CA pool.
|
||||
func (hm *HandshakeManager) certVerifier() handshake.CertVerifier {
|
||||
return func(c cert.Certificate) (*cert.CachedCertificate, error) {
|
||||
return hm.f.pki.GetCAPool().VerifyCertificate(time.Now(), c)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/test"
|
||||
@@ -27,7 +28,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
|
||||
initiatingVersion: cert.Version1,
|
||||
privateKey: []byte{},
|
||||
v1Cert: &dummyCert{version: cert.Version1},
|
||||
v1HandshakeBytes: []byte{},
|
||||
v1Credential: nil,
|
||||
}
|
||||
|
||||
blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
@@ -100,3 +101,137 @@ func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo {
|
||||
func (mw *mockEncWriter) GetCertState() *CertState {
|
||||
return &CertState{initiatingVersion: cert.Version2}
|
||||
}
|
||||
|
||||
func TestValidatePeerCert(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
myNetwork := netip.MustParsePrefix("10.0.0.1/24")
|
||||
myAddrTable := new(bart.Lite)
|
||||
myAddrTable.Insert(netip.PrefixFrom(myNetwork.Addr(), myNetwork.Addr().BitLen()))
|
||||
myNetTable := new(bart.Lite)
|
||||
myNetTable.Insert(myNetwork.Masked())
|
||||
|
||||
newHM := func() *HandshakeManager {
|
||||
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
hm.f = &Interface{
|
||||
handshakeManager: hm,
|
||||
pki: &PKI{},
|
||||
l: l,
|
||||
myVpnAddrsTable: myAddrTable,
|
||||
myVpnNetworksTable: myNetTable,
|
||||
lightHouse: hm.lightHouse,
|
||||
}
|
||||
return hm
|
||||
}
|
||||
|
||||
cached := func(networks ...netip.Prefix) *cert.CachedCertificate {
|
||||
return &cert.CachedCertificate{
|
||||
Certificate: &dummyCert{name: "peer", networks: networks},
|
||||
}
|
||||
}
|
||||
|
||||
via := ViaSender{
|
||||
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
|
||||
IsRelayed: true, // skip the remote allow list (covered separately)
|
||||
}
|
||||
|
||||
t.Run("addr inside our networks sets anyVpnAddrsInCommon", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
// 10.0.0.2 falls inside our 10.0.0.0/24
|
||||
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.2/24")))
|
||||
assert.True(t, ok)
|
||||
assert.True(t, common)
|
||||
assert.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.2")}, addrs)
|
||||
})
|
||||
|
||||
t.Run("addr outside our networks leaves anyVpnAddrsInCommon false", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("192.168.1.5/24")))
|
||||
assert.True(t, ok)
|
||||
assert.False(t, common)
|
||||
assert.Equal(t, []netip.Addr{netip.MustParseAddr("192.168.1.5")}, addrs)
|
||||
})
|
||||
|
||||
t.Run("any matching network is enough", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
addrs, common, ok := hm.validatePeerCert(via, cached(
|
||||
netip.MustParsePrefix("192.168.1.5/24"),
|
||||
netip.MustParsePrefix("10.0.0.42/24"),
|
||||
))
|
||||
assert.True(t, ok)
|
||||
assert.True(t, common)
|
||||
assert.Len(t, addrs, 2)
|
||||
})
|
||||
|
||||
t.Run("self-handshake is rejected", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
// 10.0.0.1 is in myVpnAddrsTable
|
||||
addrs, common, ok := hm.validatePeerCert(via, cached(netip.MustParsePrefix("10.0.0.1/24")))
|
||||
assert.False(t, ok)
|
||||
assert.False(t, common)
|
||||
assert.Nil(t, addrs)
|
||||
})
|
||||
|
||||
t.Run("cert with no networks is rejected", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
addrs, common, ok := hm.validatePeerCert(via, cached())
|
||||
assert.False(t, ok)
|
||||
assert.False(t, common)
|
||||
assert.Nil(t, addrs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleIncomingDispatch(t *testing.T) {
|
||||
l := test.NewLogger()
|
||||
|
||||
newHM := func() *HandshakeManager {
|
||||
hm := NewHandshakeManager(l, newHostMap(l), newTestLighthouse(), &udp.NoopConn{}, defaultHandshakeConfig)
|
||||
hm.f = &Interface{
|
||||
handshakeManager: hm,
|
||||
pki: &PKI{},
|
||||
l: l,
|
||||
}
|
||||
return hm
|
||||
}
|
||||
|
||||
via := ViaSender{
|
||||
UdpAddr: netip.MustParseAddrPort("198.51.100.7:4242"),
|
||||
IsRelayed: true, // bypass remote allow list
|
||||
}
|
||||
|
||||
// A packet body of zero length is fine for these tests: dispatch is
|
||||
// gated on header fields, and we assert that we never reach noise/cert
|
||||
// processing for any of the malformed shapes here.
|
||||
pkt := make([]byte, header.Len)
|
||||
|
||||
t.Run("unsupported subtype dropped", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
h := &header.H{Type: header.Handshake, Subtype: header.MessageSubType(99), MessageCounter: 1}
|
||||
hm.HandleIncoming(via, pkt, h)
|
||||
assert.Empty(t, hm.indexes, "no pending handshake should be created")
|
||||
})
|
||||
|
||||
t.Run("stage-1 with non-zero RemoteIndex dropped", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
h := &header.H{
|
||||
Type: header.Handshake,
|
||||
Subtype: header.HandshakeIXPSK0,
|
||||
RemoteIndex: 0xdeadbeef,
|
||||
MessageCounter: 1,
|
||||
}
|
||||
hm.HandleIncoming(via, pkt, h)
|
||||
assert.Empty(t, hm.indexes, "spoofed stage-1 must not create a pending machine")
|
||||
})
|
||||
|
||||
t.Run("continuation with no matching pending index dropped", func(t *testing.T) {
|
||||
hm := newHM()
|
||||
h := &header.H{
|
||||
Type: header.Handshake,
|
||||
Subtype: header.HandshakeIXPSK0,
|
||||
RemoteIndex: 0xcafef00d,
|
||||
MessageCounter: 2,
|
||||
}
|
||||
hm.HandleIncoming(via, pkt, h)
|
||||
assert.Empty(t, hm.indexes, "orphan stage-2 must not create state")
|
||||
})
|
||||
}
|
||||
|
||||
677
nebula.pb.go
677
nebula.pb.go
@@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string {
|
||||
}
|
||||
|
||||
func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{8, 0}
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{6, 0}
|
||||
}
|
||||
|
||||
type NebulaMeta struct {
|
||||
@@ -489,142 +489,6 @@ func (m *NebulaPing) GetTime() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
type NebulaHandshake struct {
|
||||
Details *NebulaHandshakeDetails `protobuf:"bytes,1,opt,name=Details,proto3" json:"Details,omitempty"`
|
||||
Hmac []byte `protobuf:"bytes,2,opt,name=Hmac,proto3" json:"Hmac,omitempty"`
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} }
|
||||
func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) }
|
||||
func (*NebulaHandshake) ProtoMessage() {}
|
||||
func (*NebulaHandshake) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{6}
|
||||
}
|
||||
func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
}
|
||||
func (m *NebulaHandshake) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
if deterministic {
|
||||
return xxx_messageInfo_NebulaHandshake.Marshal(b, m, deterministic)
|
||||
} else {
|
||||
b = b[:cap(b)]
|
||||
n, err := m.MarshalToSizedBuffer(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b[:n], nil
|
||||
}
|
||||
}
|
||||
func (m *NebulaHandshake) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_NebulaHandshake.Merge(m, src)
|
||||
}
|
||||
func (m *NebulaHandshake) XXX_Size() int {
|
||||
return m.Size()
|
||||
}
|
||||
func (m *NebulaHandshake) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_NebulaHandshake.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_NebulaHandshake proto.InternalMessageInfo
|
||||
|
||||
func (m *NebulaHandshake) GetDetails() *NebulaHandshakeDetails {
|
||||
if m != nil {
|
||||
return m.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) GetHmac() []byte {
|
||||
if m != nil {
|
||||
return m.Hmac
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type NebulaHandshakeDetails struct {
|
||||
Cert []byte `protobuf:"bytes,1,opt,name=Cert,proto3" json:"Cert,omitempty"`
|
||||
InitiatorIndex uint32 `protobuf:"varint,2,opt,name=InitiatorIndex,proto3" json:"InitiatorIndex,omitempty"`
|
||||
ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"`
|
||||
Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"`
|
||||
Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"`
|
||||
CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"`
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} }
|
||||
func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) }
|
||||
func (*NebulaHandshakeDetails) ProtoMessage() {}
|
||||
func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{7}
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
if deterministic {
|
||||
return xxx_messageInfo_NebulaHandshakeDetails.Marshal(b, m, deterministic)
|
||||
} else {
|
||||
b = b[:cap(b)]
|
||||
n, err := m.MarshalToSizedBuffer(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b[:n], nil
|
||||
}
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_NebulaHandshakeDetails.Merge(m, src)
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_Size() int {
|
||||
return m.Size()
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_NebulaHandshakeDetails.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_NebulaHandshakeDetails proto.InternalMessageInfo
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetCert() []byte {
|
||||
if m != nil {
|
||||
return m.Cert
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetInitiatorIndex() uint32 {
|
||||
if m != nil {
|
||||
return m.InitiatorIndex
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetResponderIndex() uint32 {
|
||||
if m != nil {
|
||||
return m.ResponderIndex
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetCookie() uint64 {
|
||||
if m != nil {
|
||||
return m.Cookie
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetTime() uint64 {
|
||||
if m != nil {
|
||||
return m.Time
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) GetCertVersion() uint32 {
|
||||
if m != nil {
|
||||
return m.CertVersion
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type NebulaControl struct {
|
||||
Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"`
|
||||
InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"`
|
||||
@@ -639,7 +503,7 @@ func (m *NebulaControl) Reset() { *m = NebulaControl{} }
|
||||
func (m *NebulaControl) String() string { return proto.CompactTextString(m) }
|
||||
func (*NebulaControl) ProtoMessage() {}
|
||||
func (*NebulaControl) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{8}
|
||||
return fileDescriptor_2d65afa7693df5ef, []int{6}
|
||||
}
|
||||
func (m *NebulaControl) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
@@ -729,65 +593,55 @@ func init() {
|
||||
proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort")
|
||||
proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort")
|
||||
proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing")
|
||||
proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake")
|
||||
proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails")
|
||||
proto.RegisterType((*NebulaControl)(nil), "nebula.NebulaControl")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) }
|
||||
|
||||
var fileDescriptor_2d65afa7693df5ef = []byte{
|
||||
// 785 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44,
|
||||
0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54,
|
||||
0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a,
|
||||
0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c,
|
||||
0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb,
|
||||
0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed,
|
||||
0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc,
|
||||
0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d,
|
||||
0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6,
|
||||
0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65,
|
||||
0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70,
|
||||
0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb,
|
||||
0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f,
|
||||
0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b,
|
||||
0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b,
|
||||
0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6,
|
||||
0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0,
|
||||
0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b,
|
||||
0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c,
|
||||
0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18,
|
||||
0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86,
|
||||
0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88,
|
||||
0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7,
|
||||
0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a,
|
||||
0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba,
|
||||
0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00,
|
||||
0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8,
|
||||
0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1,
|
||||
0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68,
|
||||
0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50,
|
||||
0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f,
|
||||
0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91,
|
||||
0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee,
|
||||
0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b,
|
||||
0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7,
|
||||
0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab,
|
||||
0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60,
|
||||
0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e,
|
||||
0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90,
|
||||
0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64,
|
||||
0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9,
|
||||
0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef,
|
||||
0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf,
|
||||
0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f,
|
||||
0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8,
|
||||
0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65,
|
||||
0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9,
|
||||
0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94,
|
||||
0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00,
|
||||
0x00,
|
||||
// 665 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x54, 0xcd, 0x6e, 0xd3, 0x5c,
|
||||
0x10, 0x8d, 0x1d, 0x27, 0x69, 0x27, 0x4d, 0x3e, 0x7f, 0x53, 0x51, 0x12, 0x24, 0xac, 0xe0, 0x45,
|
||||
0x55, 0xb1, 0x48, 0x51, 0x5a, 0xba, 0xa6, 0x2d, 0x42, 0xa9, 0xd4, 0x9f, 0x70, 0x55, 0x8a, 0xc4,
|
||||
0xce, 0xb5, 0x2f, 0x8d, 0x55, 0xc7, 0x37, 0xb5, 0x6f, 0x50, 0xf3, 0x16, 0x3c, 0x0c, 0x0f, 0x01,
|
||||
0xbb, 0x2e, 0x59, 0xa2, 0x66, 0xc9, 0x92, 0x17, 0x40, 0xf7, 0xfa, 0xbf, 0x31, 0xb0, 0xbb, 0x33,
|
||||
0xe7, 0x9c, 0x99, 0xc9, 0xc9, 0x8c, 0x61, 0xcd, 0xa7, 0x97, 0x33, 0xcf, 0xea, 0x4f, 0x03, 0xc6,
|
||||
0x19, 0xd6, 0xa3, 0xc8, 0xfc, 0xa9, 0x02, 0x9c, 0xca, 0xe7, 0x09, 0xe5, 0x16, 0x0e, 0x40, 0x3b,
|
||||
0x9f, 0x4f, 0x69, 0x47, 0xe9, 0x29, 0x5b, 0xed, 0x81, 0xd1, 0x8f, 0x35, 0x19, 0xa3, 0x7f, 0x42,
|
||||
0xc3, 0xd0, 0xba, 0xa2, 0x82, 0x45, 0x24, 0x17, 0x77, 0xa0, 0xf1, 0x9a, 0x72, 0xcb, 0xf5, 0xc2,
|
||||
0x8e, 0xda, 0x53, 0xb6, 0x9a, 0x83, 0xee, 0xb2, 0x2c, 0x26, 0x90, 0x84, 0x69, 0xfe, 0x52, 0xa0,
|
||||
0x99, 0x2b, 0x85, 0x2b, 0xa0, 0x9d, 0x32, 0x9f, 0xea, 0x15, 0x6c, 0xc1, 0xea, 0x90, 0x85, 0xfc,
|
||||
0xed, 0x8c, 0x06, 0x73, 0x5d, 0x41, 0x84, 0x76, 0x1a, 0x12, 0x3a, 0xf5, 0xe6, 0xba, 0x8a, 0x4f,
|
||||
0x60, 0x43, 0xe4, 0xde, 0x4d, 0x1d, 0x8b, 0xd3, 0x53, 0xc6, 0xdd, 0x8f, 0xae, 0x6d, 0x71, 0x97,
|
||||
0xf9, 0x7a, 0x15, 0xbb, 0xf0, 0x48, 0x60, 0x27, 0xec, 0x13, 0x75, 0x0a, 0x90, 0x96, 0x40, 0xa3,
|
||||
0x99, 0x6f, 0x8f, 0x0b, 0x50, 0x0d, 0xdb, 0x00, 0x02, 0x7a, 0x3f, 0x66, 0xd6, 0xc4, 0xd5, 0xeb,
|
||||
0xb8, 0x0e, 0xff, 0x65, 0x71, 0xd4, 0xb6, 0x21, 0x26, 0x1b, 0x59, 0x7c, 0x7c, 0x38, 0xa6, 0xf6,
|
||||
0xb5, 0xbe, 0x22, 0x26, 0x4b, 0xc3, 0x88, 0xb2, 0x8a, 0x4f, 0xa1, 0x5b, 0x3e, 0xd9, 0xbe, 0x7d,
|
||||
0xad, 0x83, 0xf9, 0x4d, 0x85, 0xff, 0x97, 0x4c, 0x41, 0x13, 0xe0, 0xcc, 0x73, 0x2e, 0xa6, 0xfe,
|
||||
0xbe, 0xe3, 0x04, 0xd2, 0xfa, 0xd6, 0x81, 0xda, 0x51, 0x48, 0x2e, 0x8b, 0x9b, 0xd0, 0x48, 0x08,
|
||||
0x75, 0x69, 0xf2, 0x5a, 0x62, 0xb2, 0xc8, 0x91, 0x04, 0xc4, 0x3e, 0xe8, 0x67, 0x9e, 0x43, 0xa8,
|
||||
0x67, 0xcd, 0xe3, 0x54, 0xd8, 0xa9, 0xf5, 0xaa, 0x71, 0xc5, 0x25, 0x0c, 0x07, 0xd0, 0x2a, 0x92,
|
||||
0x1b, 0xbd, 0xea, 0x52, 0xf5, 0x22, 0x05, 0x77, 0xa1, 0x79, 0xb1, 0x2b, 0x9e, 0x23, 0x16, 0x70,
|
||||
0xf1, 0xa7, 0x0b, 0x05, 0x26, 0x8a, 0x0c, 0x22, 0x79, 0x9a, 0x54, 0xed, 0x65, 0x2a, 0xed, 0x81,
|
||||
0x6a, 0x2f, 0xa7, 0xca, 0x68, 0xd8, 0x81, 0x86, 0xcd, 0x66, 0x3e, 0xa7, 0x41, 0xa7, 0x2a, 0x8c,
|
||||
0x21, 0x49, 0x68, 0x6e, 0x82, 0x26, 0x7f, 0x71, 0x1b, 0xd4, 0xa1, 0x2b, 0x5d, 0xd3, 0x88, 0x3a,
|
||||
0x74, 0x45, 0x7c, 0xcc, 0xe4, 0x26, 0x6a, 0x44, 0x3d, 0x66, 0xe6, 0x2e, 0x40, 0x36, 0x06, 0x62,
|
||||
0xa4, 0x8a, 0x5c, 0x26, 0x51, 0x05, 0x04, 0x4d, 0x60, 0x52, 0xd3, 0x22, 0xf2, 0x6d, 0xbe, 0x02,
|
||||
0xc8, 0xc6, 0xf8, 0x57, 0x8f, 0xb4, 0x42, 0x35, 0x57, 0xe1, 0x36, 0x39, 0xac, 0x91, 0xeb, 0x5f,
|
||||
0xfd, 0xfd, 0xb0, 0x04, 0xa3, 0xe4, 0xb0, 0x10, 0xb4, 0x73, 0x77, 0x42, 0xe3, 0x3e, 0xf2, 0x6d,
|
||||
0x9a, 0x4b, 0x67, 0x23, 0xc4, 0x7a, 0x05, 0x57, 0xa1, 0x16, 0x2d, 0xa1, 0x62, 0x7e, 0xa9, 0x42,
|
||||
0x2b, 0x2a, 0x7c, 0xc8, 0x7c, 0x1e, 0x30, 0x0f, 0x5f, 0x16, 0xba, 0x3f, 0x2b, 0x76, 0x8f, 0x49,
|
||||
0x25, 0x03, 0xbc, 0x80, 0xf5, 0x23, 0xdf, 0xe5, 0xae, 0xc5, 0x59, 0x20, 0x57, 0xe0, 0xc8, 0x77,
|
||||
0xe8, 0x6d, 0xec, 0x53, 0x19, 0x24, 0x14, 0x84, 0x86, 0x53, 0xe6, 0x3b, 0x34, 0xaf, 0x88, 0x7c,
|
||||
0x29, 0x83, 0xf0, 0x39, 0xb4, 0x93, 0xa5, 0x3c, 0x67, 0xf2, 0xaf, 0xd1, 0xd2, 0x03, 0x78, 0x80,
|
||||
0xe4, 0x97, 0xfb, 0x4d, 0xc0, 0x26, 0x92, 0x5d, 0x4b, 0xd9, 0x4b, 0x18, 0xf6, 0xa1, 0x99, 0x2f,
|
||||
0x5c, 0x76, 0x38, 0x79, 0x42, 0x7a, 0x0c, 0x69, 0xf1, 0x46, 0x89, 0xa2, 0x48, 0x31, 0x87, 0x7f,
|
||||
0xfa, 0x8e, 0x6d, 0x00, 0x1e, 0x06, 0xd4, 0xe2, 0x54, 0xf2, 0x09, 0xbd, 0x99, 0xd1, 0x90, 0xeb,
|
||||
0x0a, 0x3e, 0x86, 0xf5, 0x42, 0x5e, 0x58, 0x12, 0x52, 0x5d, 0x3d, 0xd8, 0xf9, 0x7a, 0x6f, 0x28,
|
||||
0x77, 0xf7, 0x86, 0xf2, 0xe3, 0xde, 0x50, 0x3e, 0x2f, 0x8c, 0xca, 0xdd, 0xc2, 0xa8, 0x7c, 0x5f,
|
||||
0x18, 0x95, 0x0f, 0xdd, 0x2b, 0x97, 0x8f, 0x67, 0x97, 0x7d, 0x9b, 0x4d, 0xb6, 0x43, 0xcf, 0xb2,
|
||||
0xaf, 0xc7, 0x37, 0xdb, 0xd1, 0x48, 0x97, 0x75, 0xf9, 0x39, 0xdf, 0xf9, 0x1d, 0x00, 0x00, 0xff,
|
||||
0xff, 0x51, 0x0a, 0xe3, 0xd7, 0xde, 0x05, 0x00, 0x00,
|
||||
}
|
||||
|
||||
func (m *NebulaMeta) Marshal() (dAtA []byte, err error) {
|
||||
@@ -1072,103 +926,6 @@ func (m *NebulaPing) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalToSizedBuffer(dAtA[:size])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) MarshalTo(dAtA []byte) (int, error) {
|
||||
size := m.Size()
|
||||
return m.MarshalToSizedBuffer(dAtA[:size])
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
i := len(dAtA)
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if len(m.Hmac) > 0 {
|
||||
i -= len(m.Hmac)
|
||||
copy(dAtA[i:], m.Hmac)
|
||||
i = encodeVarintNebula(dAtA, i, uint64(len(m.Hmac)))
|
||||
i--
|
||||
dAtA[i] = 0x12
|
||||
}
|
||||
if m.Details != nil {
|
||||
{
|
||||
size, err := m.Details.MarshalToSizedBuffer(dAtA[:i])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
i -= size
|
||||
i = encodeVarintNebula(dAtA, i, uint64(size))
|
||||
}
|
||||
i--
|
||||
dAtA[i] = 0xa
|
||||
}
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalToSizedBuffer(dAtA[:size])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) MarshalTo(dAtA []byte) (int, error) {
|
||||
size := m.Size()
|
||||
return m.MarshalToSizedBuffer(dAtA[:size])
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
i := len(dAtA)
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if m.CertVersion != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion))
|
||||
i--
|
||||
dAtA[i] = 0x40
|
||||
}
|
||||
if m.Time != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.Time))
|
||||
i--
|
||||
dAtA[i] = 0x28
|
||||
}
|
||||
if m.Cookie != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.Cookie))
|
||||
i--
|
||||
dAtA[i] = 0x20
|
||||
}
|
||||
if m.ResponderIndex != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.ResponderIndex))
|
||||
i--
|
||||
dAtA[i] = 0x18
|
||||
}
|
||||
if m.InitiatorIndex != 0 {
|
||||
i = encodeVarintNebula(dAtA, i, uint64(m.InitiatorIndex))
|
||||
i--
|
||||
dAtA[i] = 0x10
|
||||
}
|
||||
if len(m.Cert) > 0 {
|
||||
i -= len(m.Cert)
|
||||
copy(dAtA[i:], m.Cert)
|
||||
i = encodeVarintNebula(dAtA, i, uint64(len(m.Cert)))
|
||||
i--
|
||||
dAtA[i] = 0xa
|
||||
}
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *NebulaControl) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
@@ -1375,51 +1132,6 @@ func (m *NebulaPing) Size() (n int) {
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *NebulaHandshake) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var l int
|
||||
_ = l
|
||||
if m.Details != nil {
|
||||
l = m.Details.Size()
|
||||
n += 1 + l + sovNebula(uint64(l))
|
||||
}
|
||||
l = len(m.Hmac)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovNebula(uint64(l))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *NebulaHandshakeDetails) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var l int
|
||||
_ = l
|
||||
l = len(m.Cert)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovNebula(uint64(l))
|
||||
}
|
||||
if m.InitiatorIndex != 0 {
|
||||
n += 1 + sovNebula(uint64(m.InitiatorIndex))
|
||||
}
|
||||
if m.ResponderIndex != 0 {
|
||||
n += 1 + sovNebula(uint64(m.ResponderIndex))
|
||||
}
|
||||
if m.Cookie != 0 {
|
||||
n += 1 + sovNebula(uint64(m.Cookie))
|
||||
}
|
||||
if m.Time != 0 {
|
||||
n += 1 + sovNebula(uint64(m.Time))
|
||||
}
|
||||
if m.CertVersion != 0 {
|
||||
n += 1 + sovNebula(uint64(m.CertVersion))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *NebulaControl) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
@@ -2236,305 +1948,6 @@ func (m *NebulaPing) Unmarshal(dAtA []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *NebulaHandshake) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: NebulaHandshake: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: NebulaHandshake: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Details", wireType)
|
||||
}
|
||||
var msglen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
msglen |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if msglen < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
postIndex := iNdEx + msglen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if m.Details == nil {
|
||||
m.Details = &NebulaHandshakeDetails{}
|
||||
}
|
||||
if err := m.Details.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
|
||||
return err
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Hmac", wireType)
|
||||
}
|
||||
var byteLen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
byteLen |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if byteLen < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
postIndex := iNdEx + byteLen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Hmac = append(m.Hmac[:0], dAtA[iNdEx:postIndex]...)
|
||||
if m.Hmac == nil {
|
||||
m.Hmac = []byte{}
|
||||
}
|
||||
iNdEx = postIndex
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipNebula(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (skippy < 0) || (iNdEx+skippy) < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: NebulaHandshakeDetails: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: NebulaHandshakeDetails: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Cert", wireType)
|
||||
}
|
||||
var byteLen int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
byteLen |= int(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if byteLen < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
postIndex := iNdEx + byteLen
|
||||
if postIndex < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Cert = append(m.Cert[:0], dAtA[iNdEx:postIndex]...)
|
||||
if m.Cert == nil {
|
||||
m.Cert = []byte{}
|
||||
}
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field InitiatorIndex", wireType)
|
||||
}
|
||||
m.InitiatorIndex = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.InitiatorIndex |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 3:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field ResponderIndex", wireType)
|
||||
}
|
||||
m.ResponderIndex = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.ResponderIndex |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 4:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Cookie", wireType)
|
||||
}
|
||||
m.Cookie = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.Cookie |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 5:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Time", wireType)
|
||||
}
|
||||
m.Time = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.Time |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
case 8:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType)
|
||||
}
|
||||
m.CertVersion = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowNebula
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.CertVersion |= uint32(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipNebula(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (skippy < 0) || (iNdEx+skippy) < 0 {
|
||||
return ErrInvalidLengthNebula
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *NebulaControl) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
|
||||
18
nebula.proto
18
nebula.proto
@@ -60,21 +60,9 @@ message NebulaPing {
|
||||
uint64 Time = 2;
|
||||
}
|
||||
|
||||
message NebulaHandshake {
|
||||
NebulaHandshakeDetails Details = 1;
|
||||
bytes Hmac = 2;
|
||||
}
|
||||
|
||||
message NebulaHandshakeDetails {
|
||||
bytes Cert = 1;
|
||||
uint32 InitiatorIndex = 2;
|
||||
uint32 ResponderIndex = 3;
|
||||
uint64 Cookie = 4;
|
||||
uint64 Time = 5;
|
||||
uint32 CertVersion = 8;
|
||||
// reserved for WIP multiport
|
||||
reserved 6, 7;
|
||||
}
|
||||
// NebulaHandshake / NebulaHandshakeDetails moved to
|
||||
// handshake/handshake.proto. The handshake package speaks that wire format
|
||||
// directly via a hand-written encoder/decoder.
|
||||
|
||||
message NebulaControl {
|
||||
enum MessageType {
|
||||
|
||||
119
pki.go
119
pki.go
@@ -15,9 +15,12 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -29,10 +32,10 @@ type PKI struct {
|
||||
|
||||
type CertState struct {
|
||||
v1Cert cert.Certificate
|
||||
v1HandshakeBytes []byte
|
||||
v1Credential *handshake.Credential
|
||||
|
||||
v2Cert cert.Certificate
|
||||
v2HandshakeBytes []byte
|
||||
v2Credential *handshake.Credential
|
||||
|
||||
initiatingVersion cert.Version
|
||||
privateKey []byte
|
||||
@@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
newState, err := newCertStateFromConfig(c)
|
||||
var cipher string
|
||||
var currentState *CertState
|
||||
if initial {
|
||||
cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"unknown cipher",
|
||||
m{"cipher": cipher},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
currentState = p.cs.Load()
|
||||
cipher = currentState.cipher
|
||||
}
|
||||
|
||||
newState, err := newCertStateFromConfig(c, cipher)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Could not load client cert", nil, err)
|
||||
}
|
||||
|
||||
if !initial {
|
||||
currentState := p.cs.Load()
|
||||
if currentState != nil {
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
@@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
newState.cipher = currentState.cipher
|
||||
|
||||
} else {
|
||||
newState.cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch newState.cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"unknown cipher",
|
||||
m{"cipher": newState.cipher},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.cs.Store(newState)
|
||||
@@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||
return c
|
||||
}
|
||||
|
||||
// DefaultVersion returns the preferred cert version for initiating handshakes.
|
||||
func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion }
|
||||
|
||||
// GetCredential returns the pre-computed handshake credential for the given version, or nil.
|
||||
func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1Credential
|
||||
case cert.Version2:
|
||||
return cs.v2Credential
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
@@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
|
||||
// Callers must check if the return []byte is nil.
|
||||
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1HandshakeBytes
|
||||
case cert.Version2:
|
||||
return cs.v2HandshakeBytes
|
||||
default:
|
||||
return nil
|
||||
func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
if pkcs11backed {
|
||||
dhFunc = noiseutil.DHP256PKCS11
|
||||
} else {
|
||||
dhFunc = noiseutil.DHP256
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported curve: %s", curve)
|
||||
}
|
||||
|
||||
if cipher == "chachapoly" {
|
||||
return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil
|
||||
}
|
||||
return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil
|
||||
}
|
||||
|
||||
func (cs *CertState) String() string {
|
||||
@@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) {
|
||||
var err error
|
||||
|
||||
privPathOrPEM := c.GetString("pki.key", "")
|
||||
@@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
||||
}
|
||||
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher)
|
||||
}
|
||||
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) {
|
||||
cs := CertState{
|
||||
privateKey: privateKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
cipher: cipher,
|
||||
myVpnNetworksTable: new(bart.Lite),
|
||||
myVpnAddrsTable: new(bart.Lite),
|
||||
myVpnBroadcastAddrsTable: new(bart.Lite),
|
||||
@@ -384,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
v1hs, err := v1.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err)
|
||||
}
|
||||
ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.v1Cert = v1
|
||||
cs.v1HandshakeBytes = v1hs
|
||||
cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs)
|
||||
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version1
|
||||
@@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
v2hs, err := v2.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err)
|
||||
}
|
||||
ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.v2Cert = v2
|
||||
cs.v2HandshakeBytes = v2hs
|
||||
cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs)
|
||||
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version2
|
||||
|
||||
Reference in New Issue
Block a user