diff --git a/connection_manager.go b/connection_manager.go index 5a5a87e..b95750e 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if mainHostInfo { decision = tryRehandshake - } else { if cm.shouldSwapPrimary(hostinfo) { decision = swapPrimary @@ -554,8 +553,33 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert - myCrt := cs.getCertificate(curCrt.Version()) - if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + curCrtVersion := curCrt.Version() + myCrt := cs.getCertificate(curCrtVersion) + if myCrt == nil { + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("reason", "local certificate removed"). + Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + return + } + peerCrt := hostinfo.ConnectionState.peerCert + if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() { + // if our certificate version is less than theirs, and we have a matching version available, rehandshake? + if cs.getCertificate(peerCrt.Certificate.Version()) != nil { + //todo trigger rehandshake with specific cert? + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("version", curCrtVersion). + WithField("peerVersion", peerCrt.Certificate.Version()). + WithField("reason", "local certificate version mismatch with peer, correcting"). + Info("Re-handshaking with remote") + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) { + hh.initiatingVersionOverride = peerCrt.Certificate.Version() + }) + return + } + } + if curCrtVersion >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { // The current tunnel is using the latest certificate and version, no need to rehandshake. return } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index d8c8d36..27ea3d1 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -224,11 +224,13 @@ func TestCertDowngrade(t *testing.T) { for { assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) - if c == nil { + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { r.Log("nil") } else { version := c.Cert.Version() - r.Logf("version %d", version) + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) if version == cert.Version1 { break } @@ -249,3 +251,70 @@ func TestCertDowngrade(t *testing.T) { myControl.Stop() theirControl.Stop() } + +func TestCertMismatchCorrection(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + + myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{}) + myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2) + + theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{}) + theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2) + + myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + defer r.RenderFlow() + + r.Log("Assert the tunnel between me and them works") + //assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + //r.Log("yay") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + r.Log("yay") + //todo ??? + time.Sleep(1 * time.Second) + r.FlushAll() + + waitStart := time.Now() + for { + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + if c == nil || c2 == nil { + r.Log("nil") + } else { + version := c.Cert.Version() + theirVersion := c2.Cert.Version() + r.Logf("version %d,%d", version, theirVersion) + if version == theirVersion { + break + } + } + since := time.Since(waitStart) + if since > time.Second*5 { + r.Log("wtf") + } + if since > time.Second*10 { + r.Log("wtf") + t.Fatal("Cert should be new by now") + } + time.Sleep(time.Second) + } + + r.RenderHostmaps("Final hostmaps", myControl, theirControl) + + myControl.Stop() + theirControl.Stop() +} diff --git a/handshake_ix.go b/handshake_ix.go index 003bb1e..27f4b0f 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -23,9 +23,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { return false } - // If we're connecting to a v6 address we must use a v2 cert cs := f.pki.getCertState() v := cs.initiatingVersion + if hh.initiatingVersionOverride != cert.VersionPre1 { + v = hh.initiatingVersionOverride + } + // If we're connecting to a v6 address we must use a v2 cert for _, a := range hh.hostinfo.vpnAddrs { if a.Is6() { v = cert.Version2 diff --git a/handshake_manager.go b/handshake_manager.go index f92e72d..5f2c254 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -74,7 +74,8 @@ type HandshakeHostInfo struct { lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes - hostinfo *HostInfo + hostinfo *HostInfo + initiatingVersionOverride cert.Version } func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {