use bufio.Scanner

This commit is contained in:
JackDoan
2026-01-16 12:26:01 -06:00
parent da05932c5d
commit 6c9fa3f342
3 changed files with 84 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
package cert package cert
import ( import (
"bufio"
"bytes" "bytes"
"encoding/pem" "encoding/pem"
"errors" "errors"
@@ -9,6 +10,8 @@ import (
"net/netip" "net/netip"
"slices" "slices"
"time" "time"
"github.com/slackhq/nebula/util"
) )
type CAPool struct { type CAPool struct {
@@ -38,18 +41,26 @@ func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) {
// The reader must contain a PEM-encoded set of nebula certificates. // The reader must contain a PEM-encoded set of nebula certificates.
func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) { func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
pool := NewCAPool() pool := NewCAPool()
buf := make([]byte, 0, 64*1024)
tmp := make([]byte, 32*1024)
var expired bool var expired bool
for { scanner := bufio.NewScanner(r)
n, err := r.Read(tmp) scanner.Split(util.SplitPEM)
if n > 0 {
buf = append(buf, tmp[:n]...)
for { for {
var block *pem.Block ready := scanner.Scan()
block, buf = pem.Decode(buf) if !ready {
break
}
pemBytes := scanner.Bytes()
if scanner.Err() != nil {
return nil, scanner.Err()
}
block, rest := pem.Decode(pemBytes)
if len(bytes.TrimSpace(rest)) > 0 {
return nil, ErrInvalidPEMBlock
}
if block == nil { if block == nil {
break break
} }
@@ -63,24 +74,10 @@ func NewCAPoolFromPEMReader(r io.Reader) (*CAPool, error) {
if errors.Is(err, ErrExpired) { if errors.Is(err, ErrExpired) {
expired = true expired = true
continue continue
} } else if err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
}
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, err
}
}
if len(bytes.TrimSpace(buf)) > 0 {
return nil, ErrInvalidPEMBlock
}
if expired { if expired {
return pool, ErrExpired return pool, ErrExpired

View File

@@ -72,6 +72,7 @@ func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
) )
buf := bytes.NewBuffer(pem) buf := bytes.NewBuffer(pem)
buf.Write([]byte("\n# a comment!\n"))
for i := 1; i < count; i++ { for i := 1; i < count; i++ {
_, _, _, extraPEM := cert_test.NewTestCaCert( _, _, _, extraPEM := cert_test.NewTestCaCert(
@@ -83,7 +84,7 @@ func buildCABundle(b *testing.B, count int) (cert.Certificate, []byte, []byte) {
nil, nil,
nil, nil,
) )
buf.Write([]byte("\n# a comment!\n"))
buf.Write(extraPEM) buf.Write(extraPEM)
} }

52
util/pem.go Normal file
View File

@@ -0,0 +1,52 @@
package util
import (
"bufio"
"bytes"
)
// SplitPEM is a split function for bufio.Scanner that returns each PEM block.
func SplitPEM(data []byte, atEOF bool) (advance int, token []byte, err error) {
// Look for the start of a PEM block
start := bytes.Index(data, []byte("-----BEGIN "))
if start == -1 {
if atEOF && len(data) > 0 {
// No PEM block found, skip remaining data
return len(data), nil, nil
}
// Request more data
return 0, nil, nil
}
// Look for the end marker
endMarkerStart := bytes.Index(data[start:], []byte("-----END "))
if endMarkerStart == -1 {
if atEOF {
// Incomplete PEM block at EOF
return 0, nil, bufio.ErrFinalToken
}
// Need more data to find the end
return 0, nil, nil
}
// Find the actual end of the END line (after the newline)
endMarkerStart += start
endLineEnd := bytes.IndexByte(data[endMarkerStart:], '\n')
if endLineEnd == -1 {
if atEOF {
// END marker without newline at EOF - take it anyway
endLineEnd = len(data) - endMarkerStart
} else {
// Need more data
return 0, nil, nil
}
}
end := endMarkerStart + endLineEnd + 1
// Extract the PEM block
pemBlock := data[start:end]
// Return the valid PEM block
return end, pemBlock, nil
}