Refactor CA pool handling to use streaming

This commit is contained in:
maggie44
2025-12-11 19:51:40 +00:00
parent 59e24b98bd
commit da05932c5d
4 changed files with 188 additions and 29 deletions

View File

@@ -1,11 +1,13 @@
package cert
import (
"bytes"
"encoding/pem"
"errors"
"fmt"
"io"
"net/netip"
"slices"
"strings"
"time"
)
@@ -29,21 +31,55 @@ func NewCAPool() *CAPool {
// If the pool contains any expired certificates, an ErrExpired will be
// returned along with the pool. The caller must handle any such errors.
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
}
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
// The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool()
var err error
buf := make([]byte, 0, 64*1024)
tmp := make([]byte, 32*1024)
var expired bool
for {
caPEMs, err = pool.AddCAFromPEM(caPEMs)
if errors.Is(err, ErrExpired) {
expired = true
err = nil
n, err := r.Read(tmp)
if n > 0 {
buf = append(buf, tmp[:n]...)
for {
var block *pem.Block
block, buf = pem.Decode(buf)
if block == nil {
break
}
c, err := unmarshalCertificateBlock(block)
if err != nil {
return nil, err
}
err = pool.AddCA(c)
if errors.Is(err, ErrExpired) {
expired = true
continue
}
if err != nil {
return nil, err
}
}
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
break
}
}
if len(bytes.TrimSpace(buf)) > 0 {
return nil, ErrInvalidPEMBlock
}
if expired {

View File

@@ -37,19 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
return nil, r, ErrInvalidPEMBlock
}
var c Certificate
var err error
switch p.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
c, err := unmarshalCertificateBlock(p)
if err != nil {
return nil, r, err
}
@@ -58,6 +46,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
}
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
switch block.Type {
// Implementations must validate the resulting certificate contains valid information
case CertificateBanner:
return unmarshalCertificateV1(block.Bytes, nil)
case CertificateV2Banner:
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
default:
return nil, ErrInvalidPEMCertificateBanner
}
}
func marshalCertPublicKeyToPEM(c Certificate) []byte {
if c.IsCA() {
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())