mirror of
https://github.com/slackhq/nebula.git
synced 2026-02-15 09:14:23 +01:00
Refactor CA pool handling to use streaming
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
package cert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -29,21 +31,55 @@ func NewCAPool() *CAPool {
|
||||
// If the pool contains any expired certificates, an ErrExpired will be
|
||||
// returned along with the pool. The caller must handle any such errors.
|
||||
func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
|
||||
return NewCAPoolFromPEMReader(bytes.NewReader(caPEMs))
|
||||
}
|
||||
|
||||
// NewCAPoolFromPEMReader will create a new CA pool from the provided reader.
|
||||
// The reader must contain a PEM-encoded set of nebula certificates.
|
||||
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
|
||||
pool := NewCAPool()
|
||||
var err error
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
tmp := make([]byte, 32*1024)
|
||||
var expired bool
|
||||
|
||||
for {
|
||||
caPEMs, err = pool.AddCAFromPEM(caPEMs)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
err = nil
|
||||
n, err := r.Read(tmp)
|
||||
if n > 0 {
|
||||
buf = append(buf, tmp[:n]...)
|
||||
|
||||
for {
|
||||
var block *pem.Block
|
||||
block, buf = pem.Decode(buf)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
|
||||
c, err := unmarshalCertificateBlock(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = pool.AddCA(c)
|
||||
if errors.Is(err, ErrExpired) {
|
||||
expired = true
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(buf)) > 0 {
|
||||
return nil, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
if expired {
|
||||
|
||||
28
cert/pem.go
28
cert/pem.go
@@ -37,19 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
return nil, r, ErrInvalidPEMBlock
|
||||
}
|
||||
|
||||
var c Certificate
|
||||
var err error
|
||||
|
||||
switch p.Type {
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
case CertificateBanner:
|
||||
c, err = unmarshalCertificateV1(p.Bytes, nil)
|
||||
case CertificateV2Banner:
|
||||
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
|
||||
default:
|
||||
return nil, r, ErrInvalidPEMCertificateBanner
|
||||
}
|
||||
|
||||
c, err := unmarshalCertificateBlock(p)
|
||||
if err != nil {
|
||||
return nil, r, err
|
||||
}
|
||||
@@ -58,6 +46,20 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
|
||||
|
||||
}
|
||||
|
||||
// unmarshalCertificateBlock decodes a single PEM block into a certificate.
|
||||
// It expects a Nebula certificate banner and returns ErrInvalidPEMCertificateBanner otherwise.
|
||||
func unmarshalCertificateBlock(block *pem.Block) (Certificate, error) {
|
||||
switch block.Type {
|
||||
// Implementations must validate the resulting certificate contains valid information
|
||||
case CertificateBanner:
|
||||
return unmarshalCertificateV1(block.Bytes, nil)
|
||||
case CertificateV2Banner:
|
||||
return unmarshalCertificateV2(block.Bytes, nil, Curve_CURVE25519)
|
||||
default:
|
||||
return nil, ErrInvalidPEMCertificateBanner
|
||||
}
|
||||
}
|
||||
|
||||
func marshalCertPublicKeyToPEM(c Certificate) []byte {
|
||||
if c.IsCA() {
|
||||
return MarshalSigningPublicKeyToPEM(c.Curve(), c.PublicKey())
|
||||
|
||||
Reference in New Issue
Block a user