diff --git a/cert_test/cert.go b/cert_test/cert.go index ebc6f52..7513431 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -114,6 +114,33 @@ func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []by return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem } +func NewTestCertDifferentVersion(c cert.Certificate, v cert.Version, ca cert.Certificate, key []byte) (cert.Certificate, []byte) { + nc := &cert.TBSCertificate{ + Version: v, + Curve: c.Curve(), + Name: c.Name(), + Networks: c.Networks(), + UnsafeNetworks: c.UnsafeNetworks(), + Groups: c.Groups(), + NotBefore: time.Unix(c.NotBefore().Unix(), 0), + NotAfter: time.Unix(c.NotAfter().Unix(), 0), + PublicKey: c.PublicKey(), + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pem +} + func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { diff --git a/connection_manager.go b/connection_manager.go index 7242c72..4c2f26e 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if mainHostInfo { decision = tryRehandshake - } else { if cm.shouldSwapPrimary(hostinfo) { decision = swapPrimary @@ -461,6 +460,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool { } crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + if crt == nil { + //my cert was reloaded away. We should definitely swap from this tunnel + return true + } // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. 
return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) @@ -475,31 +478,34 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) { cm.hostMap.Unlock() } -// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and -// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid -// check and return true. +// isInvalidCertificate decides if we should destroy a tunnel. +// returns true if pki.disconnect_invalid is true and the certificate is no longer valid. +// Blocklisted certificates will skip the pki.disconnect_invalid check and return true. func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { - return false + return false //don't tear down tunnels for handshakes in progress } caPool := cm.intf.pki.GetCAPool() err := caPool.VerifyCachedCertificate(now, remoteCert) if err == nil { - return false - } - - if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { + return false //cert is still valid! yay! + } else if err == cert.ErrBlockListed { //avoiding errors.Is for speed // Block listed certificates should always be disconnected + hostinfo.logger(cm.l).WithError(err). + WithField("fingerprint", remoteCert.Fingerprint). + Info("Remote certificate is blocked, tearing down the tunnel") + return true + } else if cm.intf.disconnectInvalid.Load() { + hostinfo.logger(cm.l).WithError(err). + WithField("fingerprint", remoteCert.Fingerprint). + Info("Remote certificate is no longer valid, tearing down the tunnel") + return true + } else { + //if we reach here, the cert is no longer valid, but we're configured to keep tunnels from now-invalid certs open return false } - - hostinfo.logger(cm.l).WithError(err). - WithField("fingerprint", remoteCert.Fingerprint). 
- Info("Remote certificate is no longer valid, tearing down the tunnel") - - return true } func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { @@ -530,15 +536,45 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert - myCrt := cs.getCertificate(curCrt.Version()) - if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { - // The current tunnel is using the latest certificate and version, no need to rehandshake. + curCrtVersion := curCrt.Version() + myCrt := cs.getCertificate(curCrtVersion) + if myCrt == nil { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("reason", "local certificate removed"). + Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) return } + peerCrt := hostinfo.ConnectionState.peerCert + if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { + // if our certificate version is less than theirs, and we have a matching version available, rehandshake? + if cs.getCertificate(peerCrt.Certificate.Version()) != nil { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("peerVersion", peerCrt.Certificate.Version()). + WithField("reason", "local certificate version lower than peer, attempting to correct"). + Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { + hh.initiatingVersionOverride = peerCrt.Certificate.Version() + }) + return + } + } + if !bytes.Equal(curCrt.Signature(), myCrt.Signature()) { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("reason", "local certificate is not current"). + Info("Re-handshaking with remote") - cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). 
- WithField("reason", "local certificate is not current"). - Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + return + } + if curCrtVersion < cs.initiatingVersion { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("reason", "current cert version < pki.initiatingVersion"). + Info("Re-handshaking with remote") - cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + return + } } diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index a63b3d0..3014096 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -129,6 +129,109 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name return control, vpnNetworks, udpAddr, c } +// newServer creates a nebula instance with fewer assumptions +func newServer(caCrt []cert.Certificate, certs []cert.Certificate, key []byte, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { + l := NewTestLogger() + + vpnNetworks := certs[len(certs)-1].Networks() + + var udpAddr netip.AddrPort + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() + budpIp[1] -= 128 + udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) + } else { + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 + udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) + } + + caStr := "" + for _, ca := range caCrt { + x, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + caStr += string(x) + } + certStr := "" + for _, c := range certs { + x, err := c.MarshalPEM() + if err != nil { + panic(err) + } + certStr += string(x) + } + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": certStr, + "key": string(key), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": 
"any", + "host": "any", + }}, + }, + //"handshakes": m{ + // "try_interval": "1s", + //}, + "listen": m{ + "host": udpAddr.Addr().String(), + "port": udpAddr.Port(), + }, + "logging": m{ + "timestamp_format": fmt.Sprintf("%v 15:04:05.000000", certs[0].Name()), + "level": l.Level.String(), + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + } + + if overrides != nil { + final := m{} + err := mergo.Merge(&final, overrides, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + c := config.NewC(l) + cStr := string(cb) + c.LoadString(cStr) + + control, err := nebula.Main(c, false, "e2e-test", l, nil) + + if err != nil { + panic(err) + } + + return control, vpnNetworks, udpAddr, c +} + type doneCb func() func deadline(t *testing.T, seconds time.Duration) doneCb { diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index 55974f0..f1e9ca7 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -4,12 +4,16 @@ package e2e import ( + "fmt" + "net/netip" "testing" "time" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" ) func TestDropInactiveTunnels(t *testing.T) { @@ -55,3 +59,262 @@ func TestDropInactiveTunnels(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCertUpgrade(t *testing.T) { + // The goal of this test is to ensure that after reloading our config from a v1 certificate to a v2-only certificate + // the tunnel keeps working and the peer ends up seeing our v2 certificate + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + caB, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, 
cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + ca2B, err := ca2.MarshalPEM() + if err != nil { + panic(err) + } + caStr := fmt.Sprintf("%s\n%s", caB, ca2B) + + myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + _, myCert2Pem := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? 
+ time.Sleep(1 * time.Second) + r.FlushAll() + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": string(myCert2Pem), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": myC.Settings["firewall"], + //"handshakes": m{ + // "try_interval": "1s", + //}, + "listen": myC.Settings["listen"], + "logging": myC.Settings["logging"], + "timers": myC.Settings["timers"], + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + r.Logf("reload new v2-only config") + err = myC.ReloadConfigString(string(cb)) + assert.NoError(t, err) + r.Log("yay, spin until their sees it") + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if c == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + r.Logf("version %d", version) + if version == cert.Version2 { + break + } + } + since := time.Since(waitStart) + if since > time.Second*10 { + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestCertDowngrade(t *testing.T) { + // The goal of this test is to ensure that after reloading our config from a v2 certificate to a v1-only certificate + // the tunnel keeps working and the peer ends up seeing our v1 certificate + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + caB, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + ca2B, err := ca2.MarshalPEM() + if err != nil { + panic(err) + } + caStr := fmt.Sprintf("%s\n%s", caB, ca2B) + + myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), 
time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + //r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? 
+ time.Sleep(1 * time.Second) + r.FlushAll() + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": string(myCertPem), + "key": string(myPrivKey), + }, + "firewall": myC.Settings["firewall"], + "listen": myC.Settings["listen"], + "logging": myC.Settings["logging"], + "timers": myC.Settings["timers"], + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + r.Logf("reload new v1-only config") + err = myC.ReloadConfigString(string(cb)) + assert.NoError(t, err) + r.Log("yay, spin until their sees it") + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) + if version == cert.Version1 { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("it is unusual that the cert is not new yet, but not a failure yet") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} + +func TestCertMismatchCorrection(t *testing.T) { + // The goal of this test is to ensure that when we only have a v2 certificate and the peer has both v1 and v2 + // certificates, the tunnel keeps working and both sides end up presenting the same certificate version + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), 
time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + //r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? 
+ time.Sleep(1 * time.Second) + r.FlushAll() + + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) + if version == theirVersion { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("wtf") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} diff --git a/handshake_ix.go b/handshake_ix.go index 026bfbd..00b1d40 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -23,13 +23,17 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return false } - // If we're connecting to a v6 address we must use a v2 cert cs := f.pki.getCertState() v := cs.initiatingVersion - for _, a := range hh.hostinfo.vpnAddrs { - if a.Is6() { - v = cert.Version2 - break + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } else if v < cert.Version2 { + // If we're connecting to a v6 address we should encourage use of a V2 cert + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } } } @@ -48,6 +52,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", v). 
Error("Unable to handshake with host because no certificate handshake bytes is available") + return false } ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) @@ -103,6 +108,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). WithField("certVersion", cs.initiatingVersion). Error("Unable to handshake with host because no certificate is available") + return } ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) @@ -143,8 +149,8 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) if err != nil { - fp, err := rc.Fingerprint() - if err != nil { + fp, fperr := rc.Fingerprint() + if fperr != nil { fp = "" } @@ -163,16 +169,19 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if remoteCert.Certificate.Version() != ci.myCert.Version() { // We started off using the wrong certificate version, lets see if we can match the version that was sent to us - rc := cs.getCertificate(remoteCert.Certificate.Version()) - if rc == nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). 
- Info("Unable to handshake with host due to missing certificate version") - return + myCertOtherVersion := cs.getCertificate(remoteCert.Certificate.Version()) + if myCertOtherVersion == nil { + if f.l.Level >= logrus.DebugLevel { + f.l.WithError(err).WithFields(m{ + "udpAddr": addr, + "handshake": m{"stage": 1, "style": "ix_psk0"}, + "cert": remoteCert, + }).Debug("Might be unable to handshake with host due to missing certificate version") + } + } else { + // Record the certificate we are actually using + ci.myCert = myCertOtherVersion } - - // Record the certificate we are actually using - ci.myCert = rc } if len(remoteCert.Certificate.Networks()) == 0 { diff --git a/handshake_manager.go b/handshake_manager.go index f92e72d..ee72d71 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -68,11 +68,12 @@ type HandshakeManager struct { type HandshakeHostInfo struct { sync.Mutex - startTime time.Time // Time that we first started trying with this handshake - ready bool // Is the handshake ready - counter int64 // How many attempts have we made so far - lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt - packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes + startTime time.Time // Time that we first started trying with this handshake + ready bool // Is the handshake ready + initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? 
+ counter int64 // How many attempts have we made so far + lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo } diff --git a/pki.go b/pki.go index e71d326..e6e2839 100644 --- a/pki.go +++ b/pki.go @@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { currentState := p.cs.Load() if newState.v1Cert != nil { if currentState.v1Cert == nil { - return util.NewContextualError("v1 certificate was added, restart required", nil, err) - } + //adding certs is fine, actually. Networks-in-common confirmed in newCertState(). + } else { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks(), "cert_version": cert.Version1}, + nil, + ) + } - // did IP in cert change? 
if so, don't set - if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { - return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, - nil, - ) + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new v1 cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve(), "cert_version": cert.Version1}, + nil, + ) + } } - - if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { - return util.NewContextualError( - "Curve in new cert was different from old", - m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, - nil, - ) - } - - } else if currentState.v1Cert != nil { - //TODO: CERT-V2 we should be able to tear this down - return util.NewContextualError("v1 certificate was removed, restart required", nil, err) } if newState.v2Cert != nil { if currentState.v2Cert == nil { - return util.NewContextualError("v2 certificate was added, restart required", nil, err) - } + //adding certs is fine, actually + } else { + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks(), "cert_version": cert.Version2}, + nil, + ) + } - // did IP in cert change? 
if so, don't set - if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { - return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, - nil, - ) - } - - if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { - return util.NewContextualError( - "Curve in new cert was different from old", - m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, - nil, - ) + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve(), "cert_version": cert.Version2}, + nil, + ) + } } } else if currentState.v2Cert != nil { - return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + //newState.v1Cert is non-nil bc empty certstates aren't permitted + if newState.v1Cert == nil { + return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err) + } + //if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs + if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert", + m{"new_v1_networks": newState.v1Cert.Networks(), "old_v2_networks": currentState.v2Cert.Networks()}, + nil, + ) + } } // Cipher cant be hot swapped so just leave it at what it was before