Handshake state machine (#1656)

This commit is contained in:
Nate Brown
2026-04-30 21:30:27 -05:00
committed by GitHub
parent 1ab1f71dba
commit 9ec8cf10f3
21 changed files with 3036 additions and 1593 deletions

121
pki.go
View File

@@ -15,9 +15,12 @@ import (
"sync/atomic"
"time"
"github.com/flynn/noise"
"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/handshake"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/util"
)
@@ -28,11 +31,11 @@ type PKI struct {
}
type CertState struct {
v1Cert cert.Certificate
v1HandshakeBytes []byte
v1Cert cert.Certificate
v1Credential *handshake.Credential
v2Cert cert.Certificate
v2HandshakeBytes []byte
v2Cert cert.Certificate
v2Credential *handshake.Credential
initiatingVersion cert.Version
privateKey []byte
@@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error {
}
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
newState, err := newCertStateFromConfig(c)
var cipher string
var currentState *CertState
if initial {
cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"unknown cipher",
m{"cipher": cipher},
nil,
)
}
} else {
// Cipher cant be hot swapped so just leave it at what it was before
currentState = p.cs.Load()
cipher = currentState.cipher
}
newState, err := newCertStateFromConfig(c, cipher)
if err != nil {
return util.NewContextualError("Could not load client cert", nil, err)
}
if !initial {
currentState := p.cs.Load()
if currentState != nil {
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
@@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
)
}
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"unknown cipher",
m{"cipher": newState.cipher},
nil,
)
}
}
p.cs.Store(newState)
@@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate {
return c
}
// DefaultVersion returns the preferred cert version for initiating handshakes.
func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion }
// GetCredential returns the pre-computed handshake credential for the given version, or nil.
func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential {
switch v {
case cert.Version1:
return cs.v1Credential
case cert.Version2:
return cs.v2Credential
}
return nil
}
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
switch v {
case cert.Version1:
@@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
return nil
}
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
// Callers must check if the return []byte is nil.
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
switch v {
case cert.Version1:
return cs.v1HandshakeBytes
case cert.Version2:
return cs.v2HandshakeBytes
func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) {
var dhFunc noise.DHFunc
switch curve {
case cert.Curve_CURVE25519:
dhFunc = noise.DH25519
case cert.Curve_P256:
if pkcs11backed {
dhFunc = noiseutil.DHP256PKCS11
} else {
dhFunc = noiseutil.DHP256
}
default:
return nil
return nil, fmt.Errorf("unsupported curve: %s", curve)
}
if cipher == "chachapoly" {
return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil
}
return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil
}
func (cs *CertState) String() string {
@@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
return json.Marshal(msg)
}
func newCertStateFromConfig(c *config.C) (*CertState, error) {
func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) {
var err error
privPathOrPEM := c.GetString("pki.key", "")
@@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
}
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
cipher: cipher,
myVpnNetworksTable: new(bart.Lite),
myVpnAddrsTable: new(bart.Lite),
myVpnBroadcastAddrsTable: new(bart.Lite),
@@ -384,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
v1hs, err := v1.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err)
}
ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher)
if err != nil {
return nil, err
}
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs)
if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version1
@@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
v2hs, err := v2.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err)
}
ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher)
if err != nil {
return nil, err
}
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs)
if cs.initiatingVersion == 0 {
cs.initiatingVersion = cert.Version2