diff --git a/connection_manager.go b/connection_manager.go index 08937e6..5a5a87e 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -461,6 +461,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool { } crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + if crt == nil { + //my cert was reloaded away. We should definitely swap from this tunnel + return true + } // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) @@ -551,7 +555,7 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert myCrt := cs.getCertificate(curCrt.Version()) - if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { // The current tunnel is using the latest certificate and version, no need to rehandshake. return } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index af2c561..76d22f8 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -153,3 +153,99 @@ func TestCertUpgrade(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCertDowngrade(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + caB, err := ca.MarshalPEM() + if err != nil { + panic(err) + } + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + ca2B, err := ca2.MarshalPEM() + if err != nil { + panic(err) + } + caStr := fmt.Sprintf("%s\n%s", caB, ca2B) + + myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? + time.Sleep(1 * time.Second) + r.FlushAll() + + mc := m{ + "pki": m{ + "ca": caStr, + "cert": string(myCertPem), + "key": string(myPrivKey), + }, + "firewall": myC.Settings["firewall"], + "listen": myC.Settings["listen"], + "logging": myC.Settings["logging"], + "timers": myC.Settings["timers"], + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + r.Logf("reload new v1-only config") + err = myC.ReloadConfigString(string(cb)) + assert.NoError(t, err) + r.Log("yay, spin until their sees it") + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + if c == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + r.Logf("version %d", version) + if version == cert.Version1 { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("wtf") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +}