mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Handshake state machine (#1656)
This commit is contained in:
121
pki.go
121
pki.go
@@ -15,9 +15,12 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/handshake"
|
||||
"github.com/slackhq/nebula/noiseutil"
|
||||
"github.com/slackhq/nebula/util"
|
||||
)
|
||||
|
||||
@@ -28,11 +31,11 @@ type PKI struct {
|
||||
}
|
||||
|
||||
type CertState struct {
|
||||
v1Cert cert.Certificate
|
||||
v1HandshakeBytes []byte
|
||||
v1Cert cert.Certificate
|
||||
v1Credential *handshake.Credential
|
||||
|
||||
v2Cert cert.Certificate
|
||||
v2HandshakeBytes []byte
|
||||
v2Cert cert.Certificate
|
||||
v2Credential *handshake.Credential
|
||||
|
||||
initiatingVersion cert.Version
|
||||
privateKey []byte
|
||||
@@ -92,13 +95,35 @@ func (p *PKI) reload(c *config.C, initial bool) error {
|
||||
}
|
||||
|
||||
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
newState, err := newCertStateFromConfig(c)
|
||||
var cipher string
|
||||
var currentState *CertState
|
||||
if initial {
|
||||
cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"unknown cipher",
|
||||
m{"cipher": cipher},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
currentState = p.cs.Load()
|
||||
cipher = currentState.cipher
|
||||
}
|
||||
|
||||
newState, err := newCertStateFromConfig(c, cipher)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Could not load client cert", nil, err)
|
||||
}
|
||||
|
||||
if !initial {
|
||||
currentState := p.cs.Load()
|
||||
if currentState != nil {
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
|
||||
@@ -158,25 +183,6 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
newState.cipher = currentState.cipher
|
||||
|
||||
} else {
|
||||
newState.cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch newState.cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"unknown cipher",
|
||||
m{"cipher": newState.cipher},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.cs.Store(newState)
|
||||
@@ -208,6 +214,20 @@ func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||
return c
|
||||
}
|
||||
|
||||
// DefaultVersion returns the preferred cert version for initiating handshakes.
|
||||
func (cs *CertState) DefaultVersion() cert.Version { return cs.initiatingVersion }
|
||||
|
||||
// GetCredential returns the pre-computed handshake credential for the given version, or nil.
|
||||
func (cs *CertState) GetCredential(v cert.Version) *handshake.Credential {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1Credential
|
||||
case cert.Version2:
|
||||
return cs.v2Credential
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
@@ -219,17 +239,25 @@ func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
|
||||
// Callers must check if the return []byte is nil.
|
||||
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1HandshakeBytes
|
||||
case cert.Version2:
|
||||
return cs.v2HandshakeBytes
|
||||
func newCipherSuite(curve cert.Curve, pkcs11backed bool, cipher string) (noise.CipherSuite, error) {
|
||||
var dhFunc noise.DHFunc
|
||||
switch curve {
|
||||
case cert.Curve_CURVE25519:
|
||||
dhFunc = noise.DH25519
|
||||
case cert.Curve_P256:
|
||||
if pkcs11backed {
|
||||
dhFunc = noiseutil.DHP256PKCS11
|
||||
} else {
|
||||
dhFunc = noiseutil.DHP256
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
return nil, fmt.Errorf("unsupported curve: %s", curve)
|
||||
}
|
||||
|
||||
if cipher == "chachapoly" {
|
||||
return noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256), nil
|
||||
}
|
||||
return noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256), nil
|
||||
}
|
||||
|
||||
func (cs *CertState) String() string {
|
||||
@@ -261,7 +289,7 @@ func (cs *CertState) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
func newCertStateFromConfig(c *config.C, cipher string) (*CertState, error) {
|
||||
var err error
|
||||
|
||||
privPathOrPEM := c.GetString("pki.key", "")
|
||||
@@ -345,13 +373,14 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
return nil, fmt.Errorf("unknown pki.initiating_version: %v", rawInitiatingVersion)
|
||||
}
|
||||
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
return newCertState(initiatingVersion, v1, v2, isPkcs11, curve, rawKey, cipher)
|
||||
}
|
||||
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte, cipher string) (*CertState, error) {
|
||||
cs := CertState{
|
||||
privateKey: privateKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
cipher: cipher,
|
||||
myVpnNetworksTable: new(bart.Lite),
|
||||
myVpnAddrsTable: new(bart.Lite),
|
||||
myVpnBroadcastAddrsTable: new(bart.Lite),
|
||||
@@ -384,10 +413,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
v1hs, err := v1.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
return nil, fmt.Errorf("error marshalling v1 certificate for handshake: %w", err)
|
||||
}
|
||||
ncs, err := newCipherSuite(v1.Curve(), pkcs11backed, cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.v1Cert = v1
|
||||
cs.v1HandshakeBytes = v1hs
|
||||
cs.v1Credential = handshake.NewCredential(v1, v1hs, privateKey, ncs)
|
||||
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version1
|
||||
@@ -405,10 +438,14 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p
|
||||
|
||||
v2hs, err := v2.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
return nil, fmt.Errorf("error marshalling v2 certificate for handshake: %w", err)
|
||||
}
|
||||
ncs, err := newCipherSuite(v2.Curve(), pkcs11backed, cipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs.v2Cert = v2
|
||||
cs.v2HandshakeBytes = v2hs
|
||||
cs.v2Credential = handshake.NewCredential(v2, v2hs, privateKey, ncs)
|
||||
|
||||
if cs.initiatingVersion == 0 {
|
||||
cs.initiatingVersion = cert.Version2
|
||||
|
||||
Reference in New Issue
Block a user