diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 3631c50..561138c 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -43,7 +43,7 @@ type signFlags struct { func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} - sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") @@ -167,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("ca certificate is expired") } + if version == 0 { + version = caCert.Version() + } + // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 @@ -279,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) notBefore := time.Now() notAfter := notBefore.Add(*sf.duration) - if version == 0 || version == cert.Version1 { - // Make sure we at least have an ip + switch version { + case cert.Version1: + // Make sure we have only one ipv4 address if len(v4Networks) != 1 { return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } - if version == cert.Version1 { - // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses - if len(v6Networks) > 0 { - return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") - } + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses") + } - if len(v6UnsafeNetworks) > 0 { - return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") - } + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") } t := &cert.TBSCertificate{ @@ -323,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } crts = append(crts, nc) - } - if version == 0 || version == cert.Version2 { + case cert.Version2: t := &cert.TBSCertificate{ Version: cert.Version2, Name: *sf.name, @@ -353,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } crts = append(crts, nc) + default: + // this should be unreachable + return fmt.Errorf("invalid version: %d", version) } if !isP11 && *sf.inPubPath == "" { diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 12eddf6..f5f8cbb 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) { " -unsafe-networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " -version uint\n"+ - " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", + " \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n", ob.String(), ) } @@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses") assert.Empty(t, ob.String()) assert.Empty(t, eb.String())