V2 certificate format (#1216)

Co-authored-by: Nate Brown <nbrown.us@gmail.com>
Co-authored-by: Jack Doan <jackdoan@rivian.com>
Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com>
Co-authored-by: Jack Doan <me@jackdoan.com>
This commit is contained in:
Nate Brown
2025-03-06 11:28:26 -06:00
committed by GitHub
parent 2b427a7e89
commit d97ed57a19
105 changed files with 8276 additions and 4528 deletions

414
pki.go
View File

@@ -1,13 +1,19 @@
package nebula
import (
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
"strings"
"sync/atomic"
"time"
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
@@ -21,12 +27,22 @@ type PKI struct {
}
type CertState struct {
Certificate cert.Certificate
RawCertificate []byte
RawCertificateNoKey []byte
PublicKey []byte
PrivateKey []byte
pkcs11Backed bool
v1Cert cert.Certificate
v1HandshakeBytes []byte
v2Cert cert.Certificate
v2HandshakeBytes []byte
defaultVersion cert.Version
privateKey []byte
pkcs11Backed bool
cipher string
myVpnNetworks []netip.Prefix
myVpnNetworksTable *bart.Table[struct{}]
myVpnAddrs []netip.Addr
myVpnAddrsTable *bart.Table[struct{}]
myVpnBroadcastAddrsTable *bart.Table[struct{}]
}
func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
@@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
return pki, nil
}
func (p *PKI) GetCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) GetCAPool() *cert.CAPool {
return p.caPool.Load()
}
func (p *PKI) getCertState() *CertState {
return p.cs.Load()
}
func (p *PKI) reload(c *config.C, initial bool) error {
err := p.reloadCert(c, initial)
err := p.reloadCerts(c, initial)
if err != nil {
if initial {
return err
@@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error {
return nil
}
func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
cs, err := newCertStateFromConfig(c)
func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
newState, err := newCertStateFromConfig(c)
if err != nil {
return util.NewContextualError("Could not load client cert", nil, err)
}
if !initial {
//TODO: include check for mask equality as well
currentState := p.cs.Load()
if newState.v1Cert != nil {
if currentState.v1Cert == nil {
return util.NewContextualError("v1 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
currentCert := p.cs.Load().Certificate
oldIPs := currentCert.Networks()
newIPs := cs.Certificate.Networks()
if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
nil,
)
}
if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
nil,
)
}
} else if currentState.v1Cert != nil {
//TODO: CERT-V2 we should be able to tear this down
return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
}
if newState.v2Cert != nil {
if currentState.v2Cert == nil {
return util.NewContextualError("v2 certificate was added, restart required", nil, err)
}
// did IP in cert change? if so, don't set
if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
nil,
)
}
if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
return util.NewContextualError(
"Curve in new cert was different from old",
m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
nil,
)
}
} else if currentState.v2Cert != nil {
return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
}
// Cipher cant be hot swapped so just leave it at what it was before
newState.cipher = currentState.cipher
} else {
newState.cipher = c.GetString("cipher", "aes")
//TODO: this sucks and we should make it not a global
switch newState.cipher {
case "aes":
noiseEndianness = binary.BigEndian
case "chachapoly":
noiseEndianness = binary.LittleEndian
default:
return util.NewContextualError(
"Networks in new cert was different from old",
m{"new_network": newIPs[0], "old_network": oldIPs[0]},
"unknown cipher",
m{"cipher": newState.cipher},
nil,
)
}
}
p.cs.Store(cs)
p.cs.Store(newState)
//TODO: CERT-V2 newState needs a stringer that does json
if initial {
p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
p.l.WithField("cert", newState).Debug("Client nebula certificate(s)")
} else {
p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk")
}
return nil
}
@@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
return nil
}
func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) {
// Marshal the certificate to ensure it is valid
rawCertificate, err := certificate.Marshal()
if err != nil {
return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
func (cs *CertState) GetDefaultCertificate() cert.Certificate {
c := cs.getCertificate(cs.defaultVersion)
if c == nil {
panic("No default certificate found")
}
publicKey := certificate.PublicKey()
cs := &CertState{
RawCertificate: rawCertificate,
Certificate: certificate,
PrivateKey: privateKey,
PublicKey: publicKey,
pkcs11Backed: pkcs11backed,
}
rawCertNoKey, err := cs.Certificate.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
}
cs.RawCertificateNoKey = rawCertNoKey
return cs, nil
return c
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
func (cs *CertState) getCertificate(v cert.Version) cert.Certificate {
switch v {
case cert.Version1:
return cs.v1Cert
case cert.Version2:
return cs.v2Cert
}
return
return nil
}
// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version.
// Callers must check if the return []byte is nil.
func (cs *CertState) getHandshakeBytes(v cert.Version) []byte {
switch v {
case cert.Version1:
return cs.v1HandshakeBytes
case cert.Version2:
return cs.v2HandshakeBytes
default:
return nil
}
}
func (cs *CertState) String() string {
b, err := cs.MarshalJSON()
if err != nil {
return fmt.Sprintf("error marshaling certificate state: %v", err)
}
return string(b)
}
func (cs *CertState) MarshalJSON() ([]byte, error) {
msg := []json.RawMessage{}
if cs.v1Cert != nil {
b, err := cs.v1Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
if cs.v2Cert != nil {
b, err := cs.v2Cert.MarshalJSON()
if err != nil {
return nil, err
}
msg = append(msg, b)
}
return json.Marshal(msg)
}
func newCertStateFromConfig(c *config.C) (*CertState, error) {
@@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
}
}
nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
var crt, v1, v2 cert.Certificate
for {
// Load the certificate
crt, rawCert, err = loadCertificate(rawCert)
if err != nil {
return nil, err
}
switch crt.Version() {
case cert.Version1:
if v1 != nil {
return nil, fmt.Errorf("v1 certificate already found in pki.cert")
}
v1 = crt
case cert.Version2:
if v2 != nil {
return nil, fmt.Errorf("v2 certificate already found in pki.cert")
}
v2 = crt
default:
return nil, fmt.Errorf("unknown certificate version %v", crt.Version())
}
if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" {
break
}
}
if v1 == nil && v2 == nil {
return nil, errors.New("no certificates found in pki.cert")
}
useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2
}
rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
defaultVersion = cert.Version1
case 2:
defaultVersion = cert.Version2
default:
return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion)
}
return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey)
}
func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) {
cs := CertState{
privateKey: privateKey,
pkcs11Backed: pkcs11backed,
myVpnNetworksTable: new(bart.Table[struct{}]),
myVpnAddrsTable: new(bart.Table[struct{}]),
myVpnBroadcastAddrsTable: new(bart.Table[struct{}]),
}
if v1 != nil && v2 != nil {
if !slices.Equal(v1.PublicKey(), v2.PublicKey()) {
return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil)
}
if v1.Curve() != v2.Curve() {
return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil)
}
//TODO: CERT-V2 make sure v2 has v1s address
cs.defaultVersion = dv
}
if v1 != nil {
if pkcs11backed {
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v1hs, err := v1.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v1Cert = v1
cs.v1HandshakeBytes = v1hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version1
}
}
if v2 != nil {
if pkcs11backed {
//NOTE: We do not currently have a method to verify a public private key pair when the private key is in an hsm
} else {
if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
}
}
v2hs, err := v2.MarshalForHandshakes()
if err != nil {
return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err)
}
cs.v2Cert = v2
cs.v2HandshakeBytes = v2hs
if cs.defaultVersion == 0 {
cs.defaultVersion = cert.Version2
}
}
var crt cert.Certificate
crt = cs.getCertificate(cert.Version2)
if crt == nil {
// v2 certificates are a superset, only look at v1 if its all we have
crt = cs.getCertificate(cert.Version1)
}
for _, network := range crt.Networks() {
cs.myVpnNetworks = append(cs.myVpnNetworks, network)
cs.myVpnNetworksTable.Insert(network, struct{}{})
cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr())
cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{})
if network.Addr().Is4() {
addr := network.Masked().Addr().As4()
mask := net.CIDRMask(network.Bits(), network.Addr().BitLen())
binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask))
cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{})
}
}
return &cs, nil
}
func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) {
var pemPrivateKey []byte
if strings.Contains(privPathOrPEM, "-----BEGIN") {
pemPrivateKey = []byte(privPathOrPEM)
privPathOrPEM = "<inline>"
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
} else if strings.HasPrefix(privPathOrPEM, "pkcs11:") {
rawKey = []byte(privPathOrPEM)
return rawKey, cert.Curve_P256, true, nil
} else {
pemPrivateKey, err = os.ReadFile(privPathOrPEM)
if err != nil {
return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
}
rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey)
if err != nil {
return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
}
}
return
}
func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
c, b, err := cert.UnmarshalCertificateFromPEM(b)
if err != nil {
return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err)
}
if nebulaCert.Expired(time.Now()) {
return nil, fmt.Errorf("nebula certificate for this host is expired")
if c.Expired(time.Now()) {
return nil, b, fmt.Errorf("nebula certificate for this host is expired")
}
if len(nebulaCert.Networks()) == 0 {
return nil, fmt.Errorf("no networks encoded in certificate")
if len(c.Networks()) == 0 {
return nil, b, fmt.Errorf("no networks encoded in certificate")
}
if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
if c.IsCA() {
return nil, b, fmt.Errorf("host certificate is a CA certificate")
}
return newCertState(nebulaCert, isPkcs11, rawKey)
return c, b, nil
}
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {