diff --git a/cert/ca_pool.go b/cert/ca_pool.go index e9903e1f..792f8e66 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -1,11 +1,14 @@ package cert import ( + "bufio" + "bytes" + "encoding/pem" "errors" "fmt" + "io" "net/netip" "slices" - "strings" "time" ) @@ -29,22 +32,46 @@ func NewCAPool() *CAPool { // If the pool contains any expired certificates, an ErrExpired will be // returned along with the pool. The caller must handle any such errors. func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { + return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs)) +} + +// NewCAPoolFromPEMReader will create a new CA pool from the provided reader. +// The reader must contain a PEM-encoded set of nebula certificates. +func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { pool := NewCAPool() - var err error + var expired bool - for { - caPEMs, err = pool.AddCAFromPEM(caPEMs) - if errors.Is(err, ErrExpired) { - expired = true - err = nil + + scanner := bufio.NewScanner(r) + scanner.Split(SplitPEM) + + for scanner.Scan() { + pemBytes := scanner.Bytes() + + block, rest := pem.Decode(pemBytes) + if len(bytes.TrimSpace(rest)) > 0 { + return nil, ErrInvalidPEMBlock } + if block == nil { + return nil, ErrInvalidPEMBlock + } + + c, err := unmarshalCertificateBlock(block) if err != nil { return nil, err } - if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" { - break + + err = pool.AddCA(c) + if errors.Is(err, ErrExpired) { + expired = true + continue + } else if err != nil { + return nil, err } } + if err := scanner.Err(); err != nil { + return nil, ErrInvalidPEMBlock + } if expired { return pool, ErrExpired diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index e872c7d4..ab173228 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,7 +1,10 @@ package cert import ( + "bytes" + "io" "net/netip" + "strings" "testing" "time" @@ -112,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, ppppp.CAs, 1) } +// oneByteReader wraps a reader to return at most 1 byte per Read call, +// exercising the streaming accumulation logic in NewCAPoolFromPEMReader. +type oneByteReader struct { + r io.Reader +} + +func (o *oneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return o.r.Read(p[:1]) +} + +func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) { + pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil)) + require.NoError(t, err) + assert.Empty(t, pool.CAs) + + pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n ")) + require.NoError(t, err) + assert.Empty(t, pool.CAs) +} + +func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) { + ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, pem2...) + pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)}) + require.NoError(t, err) + assert.Len(t, pool.CAs, 2) + + fp1, err := ca1.Fingerprint() + require.NoError(t, err) + fp2, err := ca2.Fingerprint() + require.NoError(t, err) + + assert.Contains(t, pool.CAs, fp1) + assert.Contains(t, pool.CAs, fp2) +} + +func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) { + _, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata")) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + +func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) { + _, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil) + + bundle := append(pem1, []byte("some trailing garbage")...) + _, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle)) + assert.ErrorIs(t, err, ErrInvalidPEMBlock) +} + func TestCertificateV1_Verify(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) diff --git a/cert/pem.go b/cert/pem.go index 8942c23a..84221b22 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -1,12 +1,66 @@ package cert import ( + "bytes" "encoding/pem" + "errors" "fmt" "golang.org/x/crypto/ed25519" ) +var ErrTruncatedPEMBlock = errors.New("truncated PEM block") + +// SplitPEM is a split function for bufio.Scanner that returns each PEM block. +func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Look for the start of a PEM block + start := bytes.Index(data, []byte("-----BEGIN ")) + if start == -1 { + if atEOF && len(bytes.TrimSpace(data)) > 0 { + // Non-whitespace content with no PEM block + return 0, nil, ErrTruncatedPEMBlock + } + if atEOF { + return len(data), nil, nil + } + // Request more data + return 0, nil, nil + } + + // Look for the end marker + endMarkerStart := bytes.Index(data[start:], []byte("-----END ")) + if endMarkerStart == -1 { + if atEOF { + // Incomplete PEM block at EOF + return 0, nil, ErrTruncatedPEMBlock + } + // Need more data to find the end + return 0, nil, nil + } + + // Find the actual end of the END line (after the newline) + endMarkerStart += start + endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n') + var end int + if endLineEnd == -1 { + if atEOF { + // END marker without newline at EOF - take it anyway + end = len(data) + } else { + // Need more data + return 0, nil, nil + } + } else { + end = endMarkerStart + endLineEnd + 1 + } + + // Extract the PEM block + pemBlock := data[start:end] + + // Return the valid PEM block + return end, pemBlock, nil +} + const ( //cert banners CertificateBanner = "NEBULA CERTIFICATE" CertificateV2Banner = "NEBULA CERTIFICATE V2" @@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } - var c Certificate - var err error - - switch p.Type { - // Implementations must validate the resulting certificate contains valid information - case CertificateBanner: - c, err = unmarshalCertificateV1(p.Bytes, nil) - case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) - default: - return nil, r, ErrInvalidPEMCertificateBanner - } - + c, err := unmarshalCertificateBlock(p) if err != nil { return nil, r, err } @@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { } +// unmarshalCertificateBlock decodes a single PEM block into a certificate. +// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise. +func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) { + switch block.Type { + // Implementations must validate the resulting certificate contains valid information + case CertificateBanner: + return unmarshalCertificateV1(block.Bytes, nil) + case CertificateV2Banner: + return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519) + default: + return nil, ErrInvalidPEMCertificateBanner + } +} + func marshalCertPublicKeyToPEM(c Certificate) []byte { if c.IsCA() { return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey()) diff --git a/cert/pem_test.go b/cert/pem_test.go index 310c57a3..ff623541 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -1,12 +1,88 @@ package cert import ( + "bufio" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func scanAll(t *testing.T, input string) ([]string, error) { + t.Helper() + scanner := bufio.NewScanner(strings.NewReader(input)) + scanner.Split(SplitPEM) + var blocks []string + for scanner.Scan() { + blocks = append(blocks, scanner.Text()) + } + return blocks, scanner.Err() +} + +func TestSplitPEM_Single(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_Multiple(t *testing.T) { + block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n" + block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, block1+block2) + require.NoError(t, err) + require.Len(t, blocks, 2) + require.Equal(t, block1, blocks[0]) + require.Equal(t, block2, blocks[1]) +} + +func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) { + input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 2) +} + +func TestSplitPEM_Empty(t *testing.T) { + blocks, err := scanAll(t, "") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_WhitespaceOnly(t *testing.T) { + blocks, err := scanAll(t, " \n\t\n ") + require.NoError(t, err) + require.Empty(t, blocks) +} + +func TestSplitPEM_TrailingGarbage(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage" + blocks, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) + require.Len(t, blocks, 1) +} + +func TestSplitPEM_TruncatedBlock(t *testing.T) { + input := "-----BEGIN TEST-----\npartial data with no end" + _, err := scanAll(t, input) + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + +func TestSplitPEM_NoEndNewline(t *testing.T) { + input := "-----BEGIN TEST-----\ndata\n-----END TEST-----" + blocks, err := scanAll(t, input) + require.NoError(t, err) + require.Len(t, blocks, 1) + require.Equal(t, input, blocks[0]) +} + +func TestSplitPEM_GarbageOnly(t *testing.T) { + _, err := scanAll(t, "this is not PEM data") + require.ErrorIs(t, err, ErrTruncatedPEMBlock) +} + func TestUnmarshalCertificateFromPEM(t *testing.T) { goodCert := []byte(` # A good cert diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index bea4d1d9..36258dd8 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "os" - "strings" "time" "github.com/slackhq/nebula/cert" @@ -40,21 +39,15 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCACert, err := os.ReadFile(*vf.caPath) + caFile, err := os.Open(*vf.caPath) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } + defer caFile.Close() - caPool := cert.NewCAPool() - for { - rawCACert, err = caPool.AddCAFromPEM(rawCACert) - if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %w", err) - } - - if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { - break - } + caPool, err := cert.NewCAPoolFromPEMReader(caFile) + if err != nil && !errors.Is(err, cert.ErrExpired) { + return fmt.Errorf("error while adding ca cert to pool: %w", err) } rawCert, err := os.ReadFile(*vf.certPath) diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index f555e5f5..1aa5e8e6 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -64,7 +64,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) - require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.ErrorIs(t, err, cert.ErrInvalidPEMBlock) // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) diff --git a/pki.go b/pki.go index 19869d58..0639fd3d 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/netip" "os" @@ -487,25 +488,25 @@ func loadCertificate(b []byte) (cert.Certificate, []byte, error) { } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { - var rawCA []byte - var err error - caPathOrPEM := c.GetString("pki.ca", "") if caPathOrPEM == "" { return nil, errors.New("no pki.ca path or PEM data provided") } - if strings.Contains(caPathOrPEM, "-----BEGIN") { - rawCA = []byte(caPathOrPEM) + var caReader io.ReadCloser + var err error + if strings.Contains(caPathOrPEM, "-----BEGIN") { + caReader = io.NopCloser(strings.NewReader(caPathOrPEM)) } else { - rawCA, err = os.ReadFile(caPathOrPEM) + caReader, err = os.Open(caPathOrPEM) if err != nil { return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err) } } + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEM(rawCA) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { diff --git a/pki_hup_benchmark_test.go b/pki_hup_benchmark_test.go new file mode 100644 index 00000000..39f648ff --- /dev/null +++ b/pki_hup_benchmark_test.go @@ -0,0 +1,121 @@ +package nebula + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + cert_test "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/require" +) + +func BenchmarkReloadConfigWithCAs(b *testing.B) { + prevProcs := runtime.GOMAXPROCS(1) + b.Cleanup(func() { runtime.GOMAXPROCS(prevProcs) }) + + for _, size := range []int{100, 250, 500, 1000, 5000} { + b.Run(fmt.Sprintf("%dCAs", size), func(b *testing.B) { + l := test.NewLogger() + dir := b.TempDir() + + ca, caKey, caBundle := buildCABundle(b, size) + caPath, certPath, keyPath := writePKIFiles(b, dir, ca, caKey, caBundle) + + configBody := fmt.Sprintf(`pki: + ca: %s + cert: %s + key: %s +`, caPath, certPath, keyPath) + + configPath := filepath.Join(dir, "config.yml") + require.NoError(b, os.WriteFile(configPath, []byte(configBody), 0o600)) + + c := config.NewC(l) + require.NoError(b, c.Load(dir)) + + _, err := NewPKIFromConfig(l, c) + require.NoError(b, err) + + b.ReportAllocs() + b.ResetTimer() + + for b.Loop() { + c.ReloadConfig() + } + }) + } +} + +func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) { + b.Helper() + require.GreaterOrEqual(b, count, 1) + + before := time.Now().Add(-24 * time.Hour) + after := time.Now().Add(24 * time.Hour) + + ca, _, caKey, pem := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + before, + after, + nil, + nil, + nil, + ) + + buf := bytes.NewBuffer(pem) + buf.Write([]byte("\n# a comment!\n")) + + for i := 1; i < count; i++ { + _, _, _, extraPEM := cert_test.NewTestCaCert( + cert.Version2, + cert.Curve_CURVE25519, + time.Now(), + time.Now().Add(time.Hour), + nil, + nil, + nil, + ) + buf.Write([]byte("\n# a comment!\n")) + buf.Write(extraPEM) + } + + return ca, caKey, buf.Bytes() +} + +func writePKIFiles(b *testing.B, dir string, ca cert.Certificate, caKey []byte, caBundle []byte) (string, string, string) { + b.Helper() + + networks := []netip.Prefix{netip.MustParsePrefix("10.0.0.1/24")} + + _, _, keyPEM, certPEM := cert_test.NewTestCert( + cert.Version2, + cert.Curve_CURVE25519, + ca, + caKey, + "reload-benchmark", + time.Now(), + time.Now().Add(time.Hour), + networks, + nil, + nil, + ) + + caPath := filepath.Join(dir, "ca.pem") + certPath := filepath.Join(dir, "cert.pem") + keyPath := filepath.Join(dir, "key.pem") + + require.NoError(b, os.WriteFile(caPath, caBundle, 0o600)) + require.NoError(b, os.WriteFile(certPath, certPEM, 0o600)) + require.NoError(b, os.WriteFile(keyPath, keyPEM, 0o600)) + + return caPath, certPath, keyPath +}