mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 16:34:25 +01:00
V2 certificate format (#1216)
Co-authored-by: Nate Brown <nbrown.us@gmail.com> Co-authored-by: Jack Doan <jackdoan@rivian.com> Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com> Co-authored-by: Jack Doan <me@jackdoan.com>
This commit is contained in:
414
pki.go
414
pki.go
@@ -1,13 +1,19 @@
|
||||
package nebula
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/slackhq/nebula/config"
|
||||
@@ -21,12 +27,22 @@ type PKI struct {
|
||||
}
|
||||
|
||||
type CertState struct {
|
||||
Certificate cert.Certificate
|
||||
RawCertificate []byte
|
||||
RawCertificateNoKey []byte
|
||||
PublicKey []byte
|
||||
PrivateKey []byte
|
||||
pkcs11Backed bool
|
||||
v1Cert cert.Certificate
|
||||
v1HandshakeBytes []byte
|
||||
|
||||
v2Cert cert.Certificate
|
||||
v2HandshakeBytes []byte
|
||||
|
||||
defaultVersion cert.Version
|
||||
privateKey []byte
|
||||
pkcs11Backed bool
|
||||
cipher string
|
||||
|
||||
myVpnNetworks []netip.Prefix
|
||||
myVpnNetworksTable *bart.Table[struct{}]
|
||||
myVpnAddrs []netip.Addr
|
||||
myVpnAddrsTable *bart.Table[struct{}]
|
||||
myVpnBroadcastAddrsTable *bart.Table[struct{}]
|
||||
}
|
||||
|
||||
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||
@@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
|
||||
return pki, nil
|
||||
}
|
||||
|
||||
func (p *PKI) GetCertState() *CertState {
|
||||
return p.cs.Load()
|
||||
}
|
||||
|
||||
func (p *PKI) GetCAPool() *cert.CAPool {
|
||||
return p.caPool.Load()
|
||||
}
|
||||
|
||||
func (p *PKI) getCertState() *CertState {
|
||||
return p.cs.Load()
|
||||
}
|
||||
|
||||
func (p *PKI) reload(c *config.C, initial bool) error {
|
||||
err := p.reloadCert(c, initial)
|
||||
err := p.reloadCerts(c, initial)
|
||||
if err != nil {
|
||||
if initial {
|
||||
return err
|
||||
@@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
|
||||
cs, err := newCertStateFromConfig(c)
|
||||
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
|
||||
newState, err := newCertStateFromConfig(c)
|
||||
if err != nil {
|
||||
return util.NewContextualError("Could not load client cert", nil, err)
|
||||
}
|
||||
|
||||
if !initial {
|
||||
//TODO: include check for mask equality as well
|
||||
currentState := p.cs.Load()
|
||||
if newState.v1Cert != nil {
|
||||
if currentState.v1Cert == nil {
|
||||
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
currentCert := p.cs.Load().Certificate
|
||||
oldIPs := currentCert.Networks()
|
||||
newIPs := cs.Certificate.Networks()
|
||||
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v1Cert != nil {
|
||||
//TODO: CERT-V2 we should be able to tear this down
|
||||
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
if newState.v2Cert != nil {
|
||||
if currentState.v2Cert == nil {
|
||||
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
|
||||
}
|
||||
|
||||
// did IP in cert change? if so, don't set
|
||||
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
|
||||
return util.NewContextualError(
|
||||
"Curve in new cert was different from old",
|
||||
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
} else if currentState.v2Cert != nil {
|
||||
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
|
||||
}
|
||||
|
||||
// Cipher cant be hot swapped so just leave it at what it was before
|
||||
newState.cipher = currentState.cipher
|
||||
|
||||
} else {
|
||||
newState.cipher = c.GetString("cipher", "aes")
|
||||
//TODO: this sucks and we should make it not a global
|
||||
switch newState.cipher {
|
||||
case "aes":
|
||||
noiseEndianness = binary.BigEndian
|
||||
case "chachapoly":
|
||||
noiseEndianness = binary.LittleEndian
|
||||
default:
|
||||
return util.NewContextualError(
|
||||
"Networks in new cert was different from old",
|
||||
m{"new_network": newIPs[0], "old_network": oldIPs[0]},
|
||||
"unknown cipher",
|
||||
m{"cipher": newState.cipher},
|
||||
nil,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
p.cs.Store(cs)
|
||||
p.cs.Store(newState)
|
||||
|
||||
//TODO: CERT-V2 newState needs a stringer that does json
|
||||
if initial {
|
||||
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
|
||||
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
|
||||
} else {
|
||||
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
|
||||
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
|
||||
// Marshal the certificate to ensure it is valid
|
||||
rawCertificate, err := certificate.Marshal()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
|
||||
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
|
||||
c := cs.getCertificate(cs.defaultVersion)
|
||||
if c == nil {
|
||||
panic("No default certificate found")
|
||||
}
|
||||
|
||||
publicKey := certificate.PublicKey()
|
||||
cs := &CertState{
|
||||
RawCertificate: rawCertificate,
|
||||
Certificate: certificate,
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
}
|
||||
|
||||
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
|
||||
}
|
||||
cs.RawCertificateNoKey = rawCertNoKey
|
||||
|
||||
return cs, nil
|
||||
return c
|
||||
}
|
||||
|
||||
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
|
||||
var pemPrivateKey []byte
|
||||
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
||||
pemPrivateKey = []byte(privPathOrPEM)
|
||||
privPathOrPEM = "<inline>"
|
||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
|
||||
rawKey = []byte(privPathOrPEM)
|
||||
return rawKey, cert.Curve_P256, true, nil
|
||||
} else {
|
||||
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1Cert
|
||||
case cert.Version2:
|
||||
return cs.v2Cert
|
||||
}
|
||||
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
|
||||
// Callers must check if the return []byte is nil.
|
||||
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
|
||||
switch v {
|
||||
case cert.Version1:
|
||||
return cs.v1HandshakeBytes
|
||||
case cert.Version2:
|
||||
return cs.v2HandshakeBytes
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CertState) String() string {
|
||||
b, err := cs.MarshalJSON()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("error marshaling certificate state: %v", err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (cs *CertState) MarshalJSON() ([]byte, error) {
|
||||
msg := []json.RawMessage{}
|
||||
if cs.v1Cert != nil {
|
||||
b, err := cs.v1Cert.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg = append(msg, b)
|
||||
}
|
||||
|
||||
if cs.v2Cert != nil {
|
||||
b, err := cs.v2Cert.MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg = append(msg, b)
|
||||
}
|
||||
|
||||
return json.Marshal(msg)
|
||||
}
|
||||
|
||||
func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
@@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
|
||||
}
|
||||
}
|
||||
|
||||
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
|
||||
var crt, v1, v2 cert.Certificate
|
||||
for {
|
||||
// Load the certificate
|
||||
crt, rawCert, err = loadCertificate(rawCert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch crt.Version() {
|
||||
case cert.Version1:
|
||||
if v1 != nil {
|
||||
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
|
||||
}
|
||||
v1 = crt
|
||||
case cert.Version2:
|
||||
if v2 != nil {
|
||||
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
|
||||
}
|
||||
v2 = crt
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
|
||||
}
|
||||
|
||||
if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if v1 == nil && v2 == nil {
|
||||
return nil, errors.New("no certificates found in pki.cert")
|
||||
}
|
||||
|
||||
useDefaultVersion := uint32(1)
|
||||
if v1 == nil {
|
||||
// The only condition that requires v2 as the default is if only a v2 certificate is present
|
||||
// We do this to avoid having to configure it specifically in the config file
|
||||
useDefaultVersion = 2
|
||||
}
|
||||
|
||||
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
|
||||
var defaultVersion cert.Version
|
||||
switch rawDefaultVersion {
|
||||
case 1:
|
||||
if v1 == nil {
|
||||
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
|
||||
}
|
||||
defaultVersion = cert.Version1
|
||||
case 2:
|
||||
defaultVersion = cert.Version2
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
|
||||
}
|
||||
|
||||
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
|
||||
}
|
||||
|
||||
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
|
||||
cs := CertState{
|
||||
privateKey: privateKey,
|
||||
pkcs11Backed: pkcs11backed,
|
||||
myVpnNetworksTable: new(bart.Table[struct{}]),
|
||||
myVpnAddrsTable: new(bart.Table[struct{}]),
|
||||
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
|
||||
}
|
||||
|
||||
if v1 != nil && v2 != nil {
|
||||
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
|
||||
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
|
||||
}
|
||||
|
||||
if v1.Curve() != v2.Curve() {
|
||||
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
|
||||
}
|
||||
|
||||
//TODO: CERT-V2 make sure v2 has v1s address
|
||||
|
||||
cs.defaultVersion = dv
|
||||
}
|
||||
|
||||
if v1 != nil {
|
||||
if pkcs11backed {
|
||||
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
|
||||
} else {
|
||||
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
|
||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||
}
|
||||
}
|
||||
|
||||
v1hs, err := v1.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
}
|
||||
cs.v1Cert = v1
|
||||
cs.v1HandshakeBytes = v1hs
|
||||
|
||||
if cs.defaultVersion == 0 {
|
||||
cs.defaultVersion = cert.Version1
|
||||
}
|
||||
}
|
||||
|
||||
if v2 != nil {
|
||||
if pkcs11backed {
|
||||
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
|
||||
} else {
|
||||
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
|
||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||
}
|
||||
}
|
||||
|
||||
v2hs, err := v2.MarshalForHandshakes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
|
||||
}
|
||||
cs.v2Cert = v2
|
||||
cs.v2HandshakeBytes = v2hs
|
||||
|
||||
if cs.defaultVersion == 0 {
|
||||
cs.defaultVersion = cert.Version2
|
||||
}
|
||||
}
|
||||
|
||||
var crt cert.Certificate
|
||||
crt = cs.getCertificate(cert.Version2)
|
||||
if crt == nil {
|
||||
// v2 certificates are a superset, only look at v1 if its all we have
|
||||
crt = cs.getCertificate(cert.Version1)
|
||||
}
|
||||
|
||||
for _, network := range crt.Networks() {
|
||||
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
|
||||
cs.myVpnNetworksTable.Insert(network, struct{}{})
|
||||
|
||||
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
|
||||
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
|
||||
|
||||
if network.Addr().Is4() {
|
||||
addr := network.Masked().Addr().As4()
|
||||
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
|
||||
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
|
||||
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
|
||||
}
|
||||
}
|
||||
|
||||
return &cs, nil
|
||||
}
|
||||
|
||||
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
|
||||
var pemPrivateKey []byte
|
||||
if strings.Contains(privPathOrPEM, "-----BEGIN") {
|
||||
pemPrivateKey = []byte(privPathOrPEM)
|
||||
privPathOrPEM = "<inline>"
|
||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
|
||||
rawKey = []byte(privPathOrPEM)
|
||||
return rawKey, cert.Curve_P256, true, nil
|
||||
} else {
|
||||
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
|
||||
if err != nil {
|
||||
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
||||
c, b, err := cert.UnmarshalCertificateFromPEM(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
|
||||
return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
|
||||
}
|
||||
|
||||
if nebulaCert.Expired(time.Now()) {
|
||||
return nil, fmt.Errorf("nebula certificate for this host is expired")
|
||||
if c.Expired(time.Now()) {
|
||||
return nil, b, fmt.Errorf("nebula certificate for this host is expired")
|
||||
}
|
||||
|
||||
if len(nebulaCert.Networks()) == 0 {
|
||||
return nil, fmt.Errorf("no networks encoded in certificate")
|
||||
if len(c.Networks()) == 0 {
|
||||
return nil, b, fmt.Errorf("no networks encoded in certificate")
|
||||
}
|
||||
|
||||
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
|
||||
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
|
||||
if c.IsCA() {
|
||||
return nil, b, fmt.Errorf("host certificate is a CA certificate")
|
||||
}
|
||||
|
||||
return newCertState(nebulaCert, isPkcs11, rawKey)
|
||||
return c, b, nil
|
||||
}
|
||||
|
||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||
|
||||
Reference in New Issue
Block a user