[cert-v2] nebula-cert should verify all certs (#1291)

This commit is contained in:
Jack Doan 2025-01-06 16:07:55 -05:00 committed by GitHub
parent 21a117a156
commit 3f31517018
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 32 additions and 18 deletions

View File

@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
return signer, nil return signer, nil
} }
return nil, fmt.Errorf("could not find ca for the certificate") return nil, ErrCaNotFound
} }
// GetFingerprints returns an array of trusted CA fingerprints // GetFingerprints returns an array of trusted CA fingerprints

View File

@ -17,6 +17,7 @@ var (
ErrInvalidPrivateKey = errors.New("invalid private key") ErrInvalidPrivateKey = errors.New("invalid private key")
ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair")
ErrCaNotFound = errors.New("could not find ca for the certificate")
ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCACert, err := os.ReadFile(*vf.caPath) rawCACert, err := os.ReadFile(*vf.caPath)
if err != nil { if err != nil {
return fmt.Errorf("error while reading ca: %s", err) return fmt.Errorf("error while reading ca: %w", err)
} }
caPool := cert.NewCAPool() caPool := cert.NewCAPool()
for { for {
rawCACert, err = caPool.AddCAFromPEM(rawCACert) rawCACert, err = caPool.AddCAFromPEM(rawCACert)
if err != nil { if err != nil {
return fmt.Errorf("error while adding ca cert to pool: %s", err) return fmt.Errorf("error while adding ca cert to pool: %w", err)
} }
if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
rawCert, err := os.ReadFile(*vf.certPath) rawCert, err := os.ReadFile(*vf.certPath)
if err != nil { if err != nil {
return fmt.Errorf("unable to read crt; %s", err) return fmt.Errorf("unable to read crt: %w", err)
}
var errs []error
for {
if len(rawCert) == 0 {
break
}
c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
if err != nil {
return fmt.Errorf("error while parsing crt: %w", err)
}
rawCert = extra
_, err = caPool.VerifyCertificate(time.Now(), c)
if err != nil {
switch {
case errors.Is(err, cert.ErrCaNotFound):
errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
default:
errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
}
}
} }
c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) return errors.Join(errs...)
if err != nil {
return fmt.Errorf("error while parsing crt: %s", err)
}
_, err = caPool.VerifyCertificate(time.Now(), c)
if err != nil {
return err
}
return nil
} }
func verifySummary() string { func verifySummary() string {
@ -80,7 +91,7 @@ func verifySummary() string {
func verifyHelp(out io.Writer) { func verifyHelp(out io.Writer) {
vf := newVerifyFlags() vf := newVerifyFlags()
out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
vf.set.SetOutput(out) vf.set.SetOutput(out)
vf.set.PrintDefaults() vf.set.PrintDefaults()
} }

View File

@ -3,10 +3,12 @@ package main
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"errors"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -76,7 +78,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path // invalid crt at path
ob.Reset() ob.Reset()
@ -106,7 +108,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "certificate signature did not match") assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
// verified cert at path // verified cert at path
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)