Cert interface (#1212)

This commit is contained in:
Nate Brown
2024-10-10 18:00:22 -05:00
committed by GitHub
parent 16eaae306a
commit 08ac65362e
49 changed files with 2862 additions and 2833 deletions

View File

@@ -6,6 +6,7 @@ package e2e
import (
"fmt"
"net/netip"
"slices"
"testing"
"time"
@@ -538,9 +539,9 @@ func TestRehandshakingRelays(t *testing.T) {
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
caB, err := ca.MarshalToPEM()
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
@@ -558,7 +559,7 @@ func TestRehandshakingRelays(t *testing.T) {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
break
@@ -571,7 +572,7 @@ func TestRehandshakingRelays(t *testing.T) {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
break
@@ -642,9 +643,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
r.Log("Renew relay certificate and spin until me and them sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"})
caB, err := ca.MarshalToPEM()
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
@@ -662,7 +663,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r)
c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between my and relay is updated!")
break
@@ -675,7 +676,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r)
c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
r.Log("Certificate between their and relay is updated!")
break
@@ -737,9 +738,9 @@ func TestRehandshaking(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew my certificate and spin until their sees it")
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
_, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"})
caB, err := ca.MarshalToPEM()
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
@@ -756,7 +757,7 @@ func TestRehandshaking(t *testing.T) {
for {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
if len(c.Cert.Details.Groups) != 0 {
if len(c.Cert.Groups()) != 0 {
// We have a new certificate now
break
}
@@ -764,6 +765,7 @@ func TestRehandshaking(t *testing.T) {
time.Sleep(time.Second)
}
r.Log("Got the new cert")
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
rc, err = yaml.Marshal(theirConfig.Settings)
assert.NoError(t, err)
@@ -794,7 +796,7 @@ func TestRehandshaking(t *testing.T) {
// Make sure the correct tunnel won
c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
assert.Contains(t, c.Cert.Details.Groups, "new group")
assert.Contains(t, c.Cert.Groups(), "new group")
// We should only have a single tunnel now on both sides
assert.Len(t, myFinalHostmapHosts, 1)
@@ -837,9 +839,9 @@ func TestRehandshakingLoser(t *testing.T) {
r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
r.Log("Renew their certificate and spin until mine sees it")
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
_, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"})
caB, err := ca.MarshalToPEM()
caB, err := ca.MarshalPEM()
if err != nil {
panic(err)
}
@@ -857,8 +859,7 @@ func TestRehandshakingLoser(t *testing.T) {
assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
if theirNewGroup {
if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") {
break
}
@@ -895,7 +896,7 @@ func TestRehandshakingLoser(t *testing.T) {
// Make sure the correct tunnel won
theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group")
// We should only have a single tunnel now on both sides
assert.Len(t, myFinalHostmapHosts, 1)

View File

@@ -3,7 +3,6 @@ package e2e
import (
"crypto/rand"
"io"
"net"
"net/netip"
"time"
@@ -13,7 +12,7 @@ import (
)
// NewTestCaCert will generate a CA cert
func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
@@ -22,56 +21,34 @@ func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups
after = time.Now().Add(time.Second * 60).Round(time.Second)
}
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: true,
InvertedGroups: make(map[string]struct{}),
},
t := &cert.TBSCertificate{
Version: cert.Version1,
Name: "test ca",
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
IsCA: true,
}
if len(ips) > 0 {
nc.Details.Ips = make([]*net.IPNet, len(ips))
for i, ip := range ips {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
}
if len(subnets) > 0 {
nc.Details.Subnets = make([]*net.IPNet, len(subnets))
for i, ip := range subnets {
nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}
}
}
if len(groups) > 0 {
nc.Details.Groups = groups
}
err = nc.Sign(cert.Curve_CURVE25519, priv)
c, err := t.Sign(nil, cert.Curve_CURVE25519, priv)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return nc, pub, priv, pem
return c, pub, priv, pem
}
// NewTestCert will generate a signed certificate with the provided details.
// Expiry times are defaulted if you do not pass them in
func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) {
issuer, err := ca.Sha256Sum()
if err != nil {
panic(err)
}
func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) {
if before.IsZero() {
before = time.Now().Add(time.Second * -60).Round(time.Second)
}
@@ -81,33 +58,29 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af
}
pub, rawPriv := x25519Keypair()
ipb := ip.Addr().AsSlice()
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: name,
Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}},
//Subnets: subnets,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
Issuer: issuer,
InvertedGroups: make(map[string]struct{}),
},
nc := &cert.TBSCertificate{
Version: cert.Version1,
Name: name,
Networks: networks,
UnsafeNetworks: unsafeNetworks,
Groups: groups,
NotBefore: time.Unix(before.Unix(), 0),
NotAfter: time.Unix(after.Unix(), 0),
PublicKey: pub,
IsCA: false,
}
err = nc.Sign(ca.Details.Curve, key)
c, err := nc.Sign(ca, ca.Curve(), key)
if err != nil {
panic(err)
}
pem, err := nc.MarshalToPEM()
pem, err := c.MarshalPEM()
if err != nil {
panic(err)
}
return nc, pub, cert.MarshalX25519PrivateKey(rawPriv), pem
return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem
}
func x25519Keypair() ([]byte, []byte) {

View File

@@ -26,7 +26,7 @@ import (
type m map[string]interface{}
// newSimpleServer creates a nebula instance with many assumptions
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) {
l := NewTestLogger()
vpnIpNet, err := netip.ParsePrefix(sVpnIpNet)
@@ -44,9 +44,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s
budpIp[13] -= 128
udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
}
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
_, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{})
caB, err := caCrt.MarshalToPEM()
caB, err := caCrt.MarshalPEM()
if err != nil {
panic(err)
}

View File

@@ -58,8 +58,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
var lines []string
var globalLines []*edge
clusterName := strings.Trim(c.GetCert().Details.Name, " ")
clusterVpnIp := c.GetCert().Details.Ips[0].IP
clusterName := strings.Trim(c.GetCert().Name(), " ")
clusterVpnIp := c.GetCert().Networks()[0].Addr()
r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp)
hm := c.GetHostmap()
@@ -102,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
hi, ok := hm.Indexes[idx]
if ok {
r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ")
globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
_ = hi
}