diff --git a/cert/ca_pool.go b/cert/ca_pool.go index d525830..2bf480f 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { return signer, nil } - return nil, fmt.Errorf("could not find ca for the certificate") + return nil, ErrCaNotFound } // GetFingerprints returns an array of trusted CA fingerprints diff --git a/cert/errors.go b/cert/errors.go index ab18cf2..60273a9 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -17,6 +17,7 @@ var ( ErrInvalidPrivateKey = errors.New("invalid private key") ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrCaNotFound = errors.New("could not find ca for the certificate") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 80cfef3..bea4d1d 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { - return fmt.Errorf("error while reading ca: %s", err) + return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %s", err) + return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { @@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCert, err := os.ReadFile(*vf.certPath) if err != nil { - return fmt.Errorf("unable to read crt; %s", err) + return fmt.Errorf("unable to read crt: %w", err) + } + var errs []error + for { + if len(rawCert) == 0 { + break + } + c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %w", err) + } + rawCert = extra + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { + switch { + case errors.Is(err, cert.ErrCaNotFound): + errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) + default: + errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) + } + } } - c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) - if err != nil { - return fmt.Errorf("error while parsing crt: %s", err) - } - - _, err = caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return err - } - - return nil + return errors.Join(errs...) } func verifySummary() string { @@ -80,7 +91,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 204ff09..d94bd1f 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,10 +3,12 @@ package main import ( "bytes" "crypto/rand" + "errors" "os" "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -76,7 +78,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() @@ -106,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "certificate signature did not match") + assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)