Refactor CA pool handling to use streaming (#1644)
Some checks failed
gofmt / Run gofmt (push) Failing after 3s
smoke-extra / Run extra smoke tests (push) Failing after 3s
smoke / Run multi node smoke test (push) Failing after 3s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled

Co-authored-by: maggie44 <64841595+maggie44@users.noreply.github.com>
Co-authored-by: JackDoan <me@jackdoan.com>
This commit is contained in:
John Maguire
2026-04-13 13:19:55 -04:00
committed by GitHub
parent 6727113b2b
commit 0ad5c771e9
8 changed files with 373 additions and 42 deletions

View File

@@ -1,11 +1,14 @@
package cert package cert
import ( import (
"bufio"
"bytes"
"encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io"
"net/netip" "net/netip"
"slices" "slices"
"strings"
"time" "time"
) )
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
// If the pool contains any expired certificates, an ErrExpired will be // If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors. // returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
}
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
// The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool() pool := NewCAPool()
var err error
var expired bool var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs) scanner := bufio.NewScanner(r)
if errors.Is(err, ErrExpired) { scanner.Split(SplitPEM)
expired = true
err = nil for scanner.Scan() {
pemBytes := scanner.Bytes()
block, rest := pem.Decode(pemBytes)
if len(bytes.TrimSpace(rest)) > 0 {
return nil, ErrInvalidPEMBlock
} }
if block == nil {
return nil, ErrInvalidPEMBlock
}
c, err := unmarshalCertificateBlock(block)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break err = pool.AddCA(c)
if errors.Is(err, ErrExpired) {
expired = true
continue
} else if err != nil {
return nil, err
} }
} }
if err := scanner.Err(); err != nil {
return nil, ErrInvalidPEMBlock
}
if expired { if expired {
return pool, ErrExpired return pool, ErrExpired

View File

@@ -1,7 +1,10 @@
package cert package cert
import ( import (
"bytes"
"io"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
@@ -112,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
assert.Len(t, ppppp.CAs, 1) assert.Len(t, ppppp.CAs, 1)
} }
// oneByteReader wraps a reader to return at most 1 byte per Read call,
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
type oneByteReader struct {
r io.Reader
}
func (o *oneByteReader) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
return o.r.Read(p[:1])
}
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil))
require.NoError(t, err)
assert.Empty(t, pool.CAs)
pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n "))
require.NoError(t, err)
assert.Empty(t, pool.CAs)
}
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
bundle := append(pem1, pem2...)
pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)})
require.NoError(t, err)
assert.Len(t, pool.CAs, 2)
fp1, err := ca1.Fingerprint()
require.NoError(t, err)
fp2, err := ca2.Fingerprint()
require.NoError(t, err)
assert.Contains(t, pool.CAs, fp1)
assert.Contains(t, pool.CAs, fp2)
}
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
_, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata"))
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
_, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
bundle := append(pem1, []byte("some trailing garbage")...)
_, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle))
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
}
func TestCertificateV1_Verify(t *testing.T) { func TestCertificateV1_Verify(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)

View File

@@ -1,12 +1,66 @@
package cert package cert
import ( import (
"bytes"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")
// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Look for the start of a PEM block
start := bytes.Index(data, []byte("-----BEGIN "))
if start == -1 {
if atEOF && len(bytes.TrimSpace(data)) > 0 {
// Non-whitespace content with no PEM block
return 0, nil, ErrTruncatedPEMBlock
}
if atEOF {
return len(data), nil, nil
}
// Request more data
return 0, nil, nil
}
// Look for the end marker
endMarkerStart := bytes.Index(data[start:], []byte("-----END "))
if endMarkerStart == -1 {
if atEOF {
// Incomplete PEM block at EOF
return 0, nil, ErrTruncatedPEMBlock
}
// Need more data to find the end
return 0, nil, nil
}
// Find the actual end of the END line (after the newline)
endMarkerStart += start
endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n')
var end int
if endLineEnd == -1 {
if atEOF {
// END marker without newline at EOF - take it anyway
end = len(data)
} else {
// Need more data
return 0, nil, nil
}
} else {
end = endMarkerStart + endLineEnd + 1
}
// Extract the PEM block
pemBlock := data[start:end]
// Return the valid PEM block
return end, pemBlock, nil
}
const ( //cert banners const ( //cert banners
CertificateBanner = "NEBULA CERTIFICATE" CertificateBanner = "NEBULA CERTIFICATE"
CertificateV2Banner = "NEBULA CERTIFICATE V2" CertificateV2Banner = "NEBULA CERTIFICATE V2"
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock return nil, r, ErrInvalidPEMBlock
} }
var c Certificate c, err := unmarshalCertificateBlock(p)
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
if err != nil { if err != nil {
return nil, r, err return nil, r, err
} }
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
} }
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
switch block.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
return unmarshalCertificateV1(block.Bytes, nil)
case CertificateV2Banner:
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
default:
return nil, ErrInvalidPEMCertificateBanner
}
}
func marshalCertPublicKeyToPEM(c Certificate) []byte { func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() { if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())

View File

@@ -1,12 +1,88 @@
package cert package cert
import ( import (
"bufio"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func scanAll(t *testing.T, input string) ([]string, error) {
t.Helper()
scanner := bufio.NewScanner(strings.NewReader(input))
scanner.Split(SplitPEM)
var blocks []string
for scanner.Scan() {
blocks = append(blocks, scanner.Text())
}
return blocks, scanner.Err()
}
func TestSplitPEM_Single(t *testing.T) {
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
blocks, err := scanAll(t, input)
require.NoError(t, err)
require.Len(t, blocks, 1)
require.Equal(t, input, blocks[0])
}
func TestSplitPEM_Multiple(t *testing.T) {
block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
blocks, err := scanAll(t, block1+block2)
require.NoError(t, err)
require.Len(t, blocks, 2)
require.Equal(t, block1, blocks[0])
require.Equal(t, block2, blocks[1])
}
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
blocks, err := scanAll(t, input)
require.NoError(t, err)
require.Len(t, blocks, 2)
}
func TestSplitPEM_Empty(t *testing.T) {
blocks, err := scanAll(t, "")
require.NoError(t, err)
require.Empty(t, blocks)
}
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
blocks, err := scanAll(t, " \n\t\n ")
require.NoError(t, err)
require.Empty(t, blocks)
}
func TestSplitPEM_TrailingGarbage(t *testing.T) {
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
blocks, err := scanAll(t, input)
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
require.Len(t, blocks, 1)
}
func TestSplitPEM_TruncatedBlock(t *testing.T) {
input := "-----BEGIN TEST-----\npartial data with no end"
_, err := scanAll(t, input)
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}
func TestSplitPEM_NoEndNewline(t *testing.T) {
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
blocks, err := scanAll(t, input)
require.NoError(t, err)
require.Len(t, blocks, 1)
require.Equal(t, input, blocks[0])
}
func TestSplitPEM_GarbageOnly(t *testing.T) {
_, err := scanAll(t, "this is not PEM data")
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
}
func TestUnmarshalCertificateFromPEM(t *testing.T) { func TestUnmarshalCertificateFromPEM(t *testing.T) {
goodCert := []byte(` goodCert := []byte(`
# A good cert # A good cert

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
"time" "time"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
@@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
return err return err
} }
rawCACert, err := os.ReadFile(*vf.caPath) caFile, err := os.Open(*vf.caPath)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca: %w", err) return fmt.Errorf("error while reading ca: %w", err)
} }
defer caFile.Close()
caPool := cert.NewCAPool() caPool, err := cert.NewCAPoolFromPEMReader(caFile)
for { if err != nil && !errors.Is(err, cert.ErrExpired) {
rawCACert, err = caPool.AddCAFromPEM(rawCACert) return fmt.Errorf("error while adding ca cert to pool: %w", err)
if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %w", err)
}
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
break
}
} }
rawCert, err := os.ReadFile(*vf.certPath) rawCert, err := os.ReadFile(*vf.certPath)

View File

@@ -64,7 +64,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") require.ErrorIs(t, err, cert.ErrInvalidPEMBlock)
// make a ca for later // make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)

15
pki.go
View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"os" "os"
@@ -487,25 +488,25 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) {
} }
func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) {
var rawCA []byte
var err error
caPathOrPEM := c.GetString("pki.ca", "") caPathOrPEM := c.GetString("pki.ca", "")
if caPathOrPEM == "" { if caPathOrPEM == "" {
return nil, errors.New("no pki.ca path or PEM data provided") return nil, errors.New("no pki.ca path or PEM data provided")
} }
if strings.Contains(caPathOrPEM, "-----BEGIN") { var caReader io.ReadCloser
rawCA = []byte(caPathOrPEM) var err error
if strings.Contains(caPathOrPEM, "-----BEGIN") {
caReader = io.NopCloser(strings.NewReader(caPathOrPEM))
} else { } else {
rawCA, err = os.ReadFile(caPathOrPEM) caReader, err = os.Open(caPathOrPEM)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
} }
} }
defer caReader.Close()
caPool, err := cert.NewCAPoolFromPEM(rawCA) caPool, err := cert.NewCAPoolFromPEMReader(caReader)
if errors.Is(err, cert.ErrExpired) { if errors.Is(err, cert.ErrExpired) {
var expired int var expired int
for _, crt := range caPool.CAs { for _, crt := range caPool.CAs {

121
pki_hup_benchmark_test.go Normal file
View File

@@ -0,0 +1,121 @@
package nebula
import (
"bytes"
"fmt"
"net/netip"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/slackhq/nebula/cert"
cert_test "github.com/slackhq/nebula/cert_test"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/require"
)
func BenchmarkReloadConfigWithCAs(b *testing.B) {
prevProcs := runtime.GOMAXPROCS(1)
b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) })
for _, size := range []int{100, 250, 500, 1000, 5000} {
b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) {
l := test.NewLogger()
dir := b.TempDir()
ca, caKey, caBundle := buildCABundle(b, size)
caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle)
configBody := fmt.Sprintf(`pki:
ca: %s
cert: %s
key: %s
`, caPath, certPath, keyPath)
configPath := filepath.Join(dir, "config.yml")
require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600))
c := config.NewC(l)
require.NoError(b, c.Load(dir))
_, err := NewPKIFromConfig(l, c)
require.NoError(b, err)
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
c.ReloadConfig()
}
})
}
}
func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
b.Helper()
require.GreaterOrEqual(b, count, 1)
before := time.Now().Add(-24 * time.Hour)
after := time.Now().Add(24 * time.Hour)
ca, _, caKey, pem := cert_test.NewTestCaCert(
cert.Version2,
cert.Curve_CURVE25519,
before,
after,
nil,
nil,
nil,
)
buf := bytes.NewBuffer(pem)
buf.Write([]byte("\n# a comment!\n"))
for i := 1; i < count; i++ {
_, _, _, extraPEM := cert_test.NewTestCaCert(
cert.Version2,
cert.Curve_CURVE25519,
time.Now(),
time.Now().Add(time.Hour),
nil,
nil,
nil,
)
buf.Write([]byte("\n# a comment!\n"))
buf.Write(extraPEM)
}
return ca, caKey, buf.Bytes()
}
func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) {
b.Helper()
networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")}
_, _, keyPEM, certPEM := cert_test.NewTestCert(
cert.Version2,
cert.Curve_CURVE25519,
ca,
caKey,
"reload-benchmark",
time.Now(),
time.Now().Add(time.Hour),
networks,
nil,
nil,
)
caPath := filepath.Join(dir, "ca.pem")
certPath := filepath.Join(dir, "cert.pem")
keyPath := filepath.Join(dir, "key.pem")
require.NoError(b, os.WriteFile(caPath, caBundle, 0o600))
require.NoError(b, os.WriteFile(certPath, certPEM, 0o600))
require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600))
return caPath, certPath, keyPath
}