mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-14 08:44:24 +01:00
Refactor CA pool handling to use streaming
This commit is contained in:
@@ -1,11 +1,13 @@
|
|||||||
package cert
|
package cert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,21 +31,55 @@ func NewCAPool() *CAPool {
|
|||||||
// If the pool contains any expired certificates, an ErrExpired will be
|
// If the pool contains any expired certificates, an ErrExpired will be
|
||||||
// returned along with the pool. The caller must handle any such errors.
|
// returned along with the pool. The caller must handle any such errors.
|
||||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||||
|
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
|
||||||
|
// The reader must contain a PEM-encoded set of nebula certificates.
|
||||||
|
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
|
||||||
pool := NewCAPool()
|
pool := NewCAPool()
|
||||||
var err error
|
buf := make([]byte, 0, 64*1024)
|
||||||
|
tmp := make([]byte, 32*1024)
|
||||||
var expired bool
|
var expired bool
|
||||||
|
|
||||||
for {
|
for {
|
||||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
n, err := r.Read(tmp)
|
||||||
if errors.Is(err, ErrExpired) {
|
if n > 0 {
|
||||||
expired = true
|
buf = append(buf, tmp[:n]...)
|
||||||
err = nil
|
|
||||||
|
for {
|
||||||
|
var block *pem.Block
|
||||||
|
block, buf = pem.Decode(buf)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := unmarshalCertificateBlock(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pool.AddCA(c)
|
||||||
|
if errors.Is(err, ErrExpired) {
|
||||||
|
expired = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
}
|
||||||
break
|
|
||||||
}
|
if len(bytes.TrimSpace(buf)) > 0 {
|
||||||
|
return nil, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
if expired {
|
if expired {
|
||||||
|
|||||||
28
cert/pem.go
28
cert/pem.go
@@ -37,19 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
return nil, r, ErrInvalidPEMBlock
|
return nil, r, ErrInvalidPEMBlock
|
||||||
}
|
}
|
||||||
|
|
||||||
var c Certificate
|
c, err := unmarshalCertificateBlock(p)
|
||||||
var err error
|
|
||||||
|
|
||||||
switch p.Type {
|
|
||||||
// Implementations must validate the resulting certificate contains valid information
|
|
||||||
case CertificateBanner:
|
|
||||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
|
||||||
case CertificateV2Banner:
|
|
||||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
|
||||||
default:
|
|
||||||
return nil, r, ErrInvalidPEMCertificateBanner
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, r, err
|
return nil, r, err
|
||||||
}
|
}
|
||||||
@@ -58,6 +46,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
|
||||||
|
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
|
||||||
|
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
|
||||||
|
switch block.Type {
|
||||||
|
// Implementations must validate the resulting certificate contains valid information
|
||||||
|
case CertificateBanner:
|
||||||
|
return unmarshalCertificateV1(block.Bytes, nil)
|
||||||
|
case CertificateV2Banner:
|
||||||
|
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
|
||||||
|
default:
|
||||||
|
return nil, ErrInvalidPEMCertificateBanner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||||
if c.IsCA() {
|
if c.IsCA() {
|
||||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||||
|
|||||||
15
pki.go
15
pki.go
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
@@ -487,25 +488,25 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
|
||||||
var rawCA []byte
|
|
||||||
var err error
|
|
||||||
|
|
||||||
caPathOrPEM := c.GetString("pki.ca", "")
|
caPathOrPEM := c.GetString("pki.ca", "")
|
||||||
if caPathOrPEM == "" {
|
if caPathOrPEM == "" {
|
||||||
return nil, errors.New("no pki.ca path or PEM data provided")
|
return nil, errors.New("no pki.ca path or PEM data provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
var caReader io.ReadCloser
|
||||||
rawCA = []byte(caPathOrPEM)
|
var err error
|
||||||
|
|
||||||
|
if strings.Contains(caPathOrPEM, "-----BEGIN") {
|
||||||
|
caReader = io.NopCloser(strings.NewReader(caPathOrPEM))
|
||||||
} else {
|
} else {
|
||||||
rawCA, err = os.ReadFile(caPathOrPEM)
|
caReader, err = os.Open(caPathOrPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
defer caReader.Close()
|
||||||
|
|
||||||
caPool, err := cert.NewCAPoolFromPEM(rawCA)
|
caPool, err := cert.NewCAPoolFromPEMReader(caReader)
|
||||||
if errors.Is(err, cert.ErrExpired) {
|
if errors.Is(err, cert.ErrExpired) {
|
||||||
var expired int
|
var expired int
|
||||||
for _, crt := range caPool.CAs {
|
for _, crt := range caPool.CAs {
|
||||||
|
|||||||
120
pki_hup_benchmark_test.go
Normal file
120
pki_hup_benchmark_test.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slackhq/nebula/cert"
|
||||||
|
cert_test "github.com/slackhq/nebula/cert_test"
|
||||||
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/test"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkReloadConfigWithCAs(b *testing.B) {
|
||||||
|
prevProcs := runtime.GOMAXPROCS(1)
|
||||||
|
b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })
|
||||||
|
|
||||||
|
for _, size := range []int{100, 250, 500, 1000, 5000} {
|
||||||
|
b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) {
|
||||||
|
l := test.NewLogger()
|
||||||
|
dir := b.TempDir()
|
||||||
|
|
||||||
|
ca, caKey, caBundle := buildCABundle(b, size)
|
||||||
|
caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle)
|
||||||
|
|
||||||
|
configBody := fmt.Sprintf(`pki:
|
||||||
|
ca: %s
|
||||||
|
cert: %s
|
||||||
|
key: %s
|
||||||
|
`, caPath, certPath, keyPath)
|
||||||
|
|
||||||
|
configPath := filepath.Join(dir, "config.yml")
|
||||||
|
require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600))
|
||||||
|
|
||||||
|
c := config.NewC(l)
|
||||||
|
require.NoError(b, c.Load(dir))
|
||||||
|
|
||||||
|
_, err := NewPKIFromConfig(l, c)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
c.ReloadConfig()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
|
||||||
|
b.Helper()
|
||||||
|
require.GreaterOrEqual(b, count, 1)
|
||||||
|
|
||||||
|
before := time.Now().Add(-24 * time.Hour)
|
||||||
|
after := time.Now().Add(24 * time.Hour)
|
||||||
|
|
||||||
|
ca, _, caKey, pem := cert_test.NewTestCaCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
before,
|
||||||
|
after,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(pem)
|
||||||
|
|
||||||
|
for i := 1; i < count; i++ {
|
||||||
|
_, _, _, extraPEM := cert_test.NewTestCaCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
time.Now(),
|
||||||
|
time.Now().Add(time.Hour),
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
buf.Write(extraPEM)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ca, caKey, buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
|
||||||
|
|
||||||
|
_, _, keyPEM, certPEM := cert_test.NewTestCert(
|
||||||
|
cert.Version2,
|
||||||
|
cert.Curve_CURVE25519,
|
||||||
|
ca,
|
||||||
|
caKey,
|
||||||
|
"reload-benchmark",
|
||||||
|
time.Now(),
|
||||||
|
time.Now().Add(time.Hour),
|
||||||
|
networks,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
caPath := filepath.Join(dir, "ca.pem")
|
||||||
|
certPath := filepath.Join(dir, "cert.pem")
|
||||||
|
keyPath := filepath.Join(dir, "key.pem")
|
||||||
|
|
||||||
|
require.NoError(b, os.WriteFile(caPath, caBundle, 0o600))
|
||||||
|
require.NoError(b, os.WriteFile(certPath, certPEM, 0o600))
|
||||||
|
require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600))
|
||||||
|
|
||||||
|
return caPath, certPath, keyPath
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user