mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 04:47:38 +02:00
Refactor CA pool handling to use streaming (#1644)
Some checks failed
gofmt / Run gofmt (push) Failing after 3s
smoke-extra / Run extra smoke tests (push) Failing after 3s
smoke / Run multi node smoke test (push) Failing after 3s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Some checks failed
gofmt / Run gofmt (push) Failing after 3s
smoke-extra / Run extra smoke tests (push) Failing after 3s
smoke / Run multi node smoke test (push) Failing after 3s
Build and test / Build all and test on ubuntu-linux (push) Failing after 2s
Build and test / Build and test on linux with boringcrypto (push) Failing after 3s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled
Co-authored-by: maggie44 <64841595+maggie44@users.noreply.github.com> Co-authored-by: JackDoan <me@jackdoan.com>
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -29,22 +32,46 @@ func NewCAPool() *CAPool {
|
||||
// If the pool contains any expired certificates, an ErrExpired will be
|
||||
// returned along with the pool. The caller must handle any such errors.
|
||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
|
||||
}
|
||||
|
||||
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
|
||||
// The reader must contain a PEM-encoded set of nebula certificates.
|
||||
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
|
||||
pool := NewCAPool()
|
||||
var err error
|
||||
|
||||
var expired bool
|
||||
for {
|
||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
err = nil
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Split(SplitPEM)
|
||||
|
||||
for scanner.Scan() {
|
||||
pemBytes := scanner.Bytes()
|
||||
|
||||
block, rest := pem.Decode(pemBytes)
|
||||
if len(bytes.TrimSpace(rest)) > 0 {
|
||||
return nil, ErrInvalidPEMBlock
|
||||
}
|
||||
if block == nil {
|
||||
return nil, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
c, err := unmarshalCertificateBlock(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
||||
break
|
||||
|
||||
err = pool.AddCA(c)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
if expired {
|
||||
return pool, ErrExpired
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -112,6 +115,60 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
|
||||
assert.Len(t, ppppp.CAs, 1)
|
||||
}
|
||||
|
||||
// oneByteReader wraps a reader to return at most 1 byte per Read call,
|
||||
// exercising the streaming accumulation logic in NewCAPoolFromPEMReader.
|
||||
type oneByteReader struct {
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (o *oneByteReader) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return o.r.Read(p[:1])
|
||||
}
|
||||
|
||||
func TestNewCAPoolFromPEMReader_EmptyReader(t *testing.T) {
|
||||
pool, err := NewCAPoolFromPEMReader(bytes.NewReader(nil))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, pool.CAs)
|
||||
|
||||
pool, err = NewCAPoolFromPEMReader(strings.NewReader(" \n\t\n "))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, pool.CAs)
|
||||
}
|
||||
|
||||
func TestNewCAPoolFromPEMReader_OneByteReads(t *testing.T) {
|
||||
ca1, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
ca2, _, _, pem2 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
|
||||
bundle := append(pem1, pem2...)
|
||||
pool, err := NewCAPoolFromPEMReader(&oneByteReader{r: bytes.NewReader(bundle)})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pool.CAs, 2)
|
||||
|
||||
fp1, err := ca1.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
fp2, err := ca2.Fingerprint()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, pool.CAs, fp1)
|
||||
assert.Contains(t, pool.CAs, fp2)
|
||||
}
|
||||
|
||||
func TestNewCAPoolFromPEMReader_TruncatedPEM(t *testing.T) {
|
||||
_, err := NewCAPoolFromPEMReader(strings.NewReader("-----BEGIN NEBULA CERTIFICATE-----\npartialdata"))
|
||||
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||
}
|
||||
|
||||
func TestNewCAPoolFromPEMReader_TrailingGarbage(t *testing.T) {
|
||||
_, _, _, pem1 := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(time.Hour), nil, nil, nil)
|
||||
|
||||
bundle := append(pem1, []byte("some trailing garbage")...)
|
||||
_, err := NewCAPoolFromPEMReader(bytes.NewReader(bundle))
|
||||
assert.ErrorIs(t, err, ErrInvalidPEMBlock)
|
||||
}
|
||||
|
||||
func TestCertificateV1_Verify(t *testing.T) {
|
||||
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil)
|
||||
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
|
||||
|
||||
82
cert/pem.go
82
cert/pem.go
@@ -1,12 +1,66 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ed25519"
|
||||
)
|
||||
|
||||
var ErrTruncatedPEMBlock = errors.New("truncated PEM block")
|
||||
|
||||
// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
|
||||
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
// Look for the start of a PEM block
|
||||
start := bytes.Index(data, []byte("-----BEGIN "))
|
||||
if start == -1 {
|
||||
if atEOF && len(bytes.TrimSpace(data)) > 0 {
|
||||
// Non-whitespace content with no PEM block
|
||||
return 0, nil, ErrTruncatedPEMBlock
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), nil, nil
|
||||
}
|
||||
// Request more data
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
// Look for the end marker
|
||||
endMarkerStart := bytes.Index(data[start:], []byte("-----END "))
|
||||
if endMarkerStart == -1 {
|
||||
if atEOF {
|
||||
// Incomplete PEM block at EOF
|
||||
return 0, nil, ErrTruncatedPEMBlock
|
||||
}
|
||||
// Need more data to find the end
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
// Find the actual end of the END line (after the newline)
|
||||
endMarkerStart += start
|
||||
endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n')
|
||||
var end int
|
||||
if endLineEnd == -1 {
|
||||
if atEOF {
|
||||
// END marker without newline at EOF - take it anyway
|
||||
end = len(data)
|
||||
} else {
|
||||
// Need more data
|
||||
return 0, nil, nil
|
||||
}
|
||||
} else {
|
||||
end = endMarkerStart + endLineEnd + 1
|
||||
}
|
||||
|
||||
// Extract the PEM block
|
||||
pemBlock := data[start:end]
|
||||
|
||||
// Return the valid PEM block
|
||||
return end, pemBlock, nil
|
||||
}
|
||||
|
||||
const ( //cert banners
|
||||
CertificateBanner = "NEBULA CERTIFICATE"
|
||||
CertificateV2Banner = "NEBULA CERTIFICATE V2"
|
||||
@@ -37,19 +91,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
return nil, r, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
var c Certificate
|
||||
var err error
|
||||
|
||||
switch p.Type {
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
case CertificateBanner:
|
||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
||||
case CertificateV2Banner:
|
||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
||||
default:
|
||||
return nil, r, ErrInvalidPEMCertificateBanner
|
||||
}
|
||||
|
||||
c, err := unmarshalCertificateBlock(p)
|
||||
if err != nil {
|
||||
return nil, r, err
|
||||
}
|
||||
@@ -58,6 +100,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
|
||||
}
|
||||
|
||||
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
|
||||
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
|
||||
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
|
||||
switch block.Type {
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
case CertificateBanner:
|
||||
return unmarshalCertificateV1(block.Bytes, nil)
|
||||
case CertificateV2Banner:
|
||||
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
|
||||
default:
|
||||
return nil, ErrInvalidPEMCertificateBanner
|
||||
}
|
||||
}
|
||||
|
||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||
if c.IsCA() {
|
||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
|
||||
@@ -1,12 +1,88 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func scanAll(t *testing.T, input string) ([]string, error) {
|
||||
t.Helper()
|
||||
scanner := bufio.NewScanner(strings.NewReader(input))
|
||||
scanner.Split(SplitPEM)
|
||||
var blocks []string
|
||||
for scanner.Scan() {
|
||||
blocks = append(blocks, scanner.Text())
|
||||
}
|
||||
return blocks, scanner.Err()
|
||||
}
|
||||
|
||||
func TestSplitPEM_Single(t *testing.T) {
|
||||
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\n"
|
||||
blocks, err := scanAll(t, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blocks, 1)
|
||||
require.Equal(t, input, blocks[0])
|
||||
}
|
||||
|
||||
func TestSplitPEM_Multiple(t *testing.T) {
|
||||
block1 := "-----BEGIN TEST-----\naaa\n-----END TEST-----\n"
|
||||
block2 := "-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||
blocks, err := scanAll(t, block1+block2)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blocks, 2)
|
||||
require.Equal(t, block1, blocks[0])
|
||||
require.Equal(t, block2, blocks[1])
|
||||
}
|
||||
|
||||
func TestSplitPEM_CommentsAndWhitespaceBetweenBlocks(t *testing.T) {
|
||||
input := "# comment\n\n-----BEGIN TEST-----\naaa\n-----END TEST-----\n\n# another comment\n\n-----BEGIN TEST-----\nbbb\n-----END TEST-----\n"
|
||||
blocks, err := scanAll(t, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blocks, 2)
|
||||
}
|
||||
|
||||
func TestSplitPEM_Empty(t *testing.T) {
|
||||
blocks, err := scanAll(t, "")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, blocks)
|
||||
}
|
||||
|
||||
func TestSplitPEM_WhitespaceOnly(t *testing.T) {
|
||||
blocks, err := scanAll(t, " \n\t\n ")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, blocks)
|
||||
}
|
||||
|
||||
func TestSplitPEM_TrailingGarbage(t *testing.T) {
|
||||
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----\ngarbage"
|
||||
blocks, err := scanAll(t, input)
|
||||
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||
require.Len(t, blocks, 1)
|
||||
}
|
||||
|
||||
func TestSplitPEM_TruncatedBlock(t *testing.T) {
|
||||
input := "-----BEGIN TEST-----\npartial data with no end"
|
||||
_, err := scanAll(t, input)
|
||||
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||
}
|
||||
|
||||
func TestSplitPEM_NoEndNewline(t *testing.T) {
|
||||
input := "-----BEGIN TEST-----\ndata\n-----END TEST-----"
|
||||
blocks, err := scanAll(t, input)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blocks, 1)
|
||||
require.Equal(t, input, blocks[0])
|
||||
}
|
||||
|
||||
func TestSplitPEM_GarbageOnly(t *testing.T) {
|
||||
_, err := scanAll(t, "this is not PEM data")
|
||||
require.ErrorIs(t, err, ErrTruncatedPEMBlock)
|
||||
}
|
||||
|
||||
func TestUnmarshalCertificateFromPEM(t *testing.T) {
|
||||
goodCert := []byte(`
|
||||
# A good cert
|
||||
|
||||
Reference in New Issue
Block a user