even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

This commit is contained in:
JackDoan
2025-09-10 13:32:41 -05:00
parent 0946831f88
commit 41273a94bb
4 changed files with 104 additions and 7 deletions

View File

@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
if mainHostInfo {
decision = tryRehandshake
} else {
if cm.shouldSwapPrimary(hostinfo) {
decision = swapPrimary
@@ -554,8 +553,33 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
cs := cm.intf.pki.getCertState()
curCrt := hostinfo.ConnectionState.myCert
myCrt := cs.getCertificate(curCrt.Version())
if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
curCrtVersion := curCrt.Version()
myCrt := cs.getCertificate(curCrtVersion)
if myCrt == nil {
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("reason", "local certificate removed").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
return
}
peerCrt := hostinfo.ConnectionState.peerCert
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
//todo trigger rehandshake with specific cert?
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
WithField("version", curCrtVersion).
WithField("peerVersion", peerCrt.Certificate.Version()).
WithField("reason", "local certificate version mismatch with peer, correcting").
Info("Re-handshaking with remote")
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
})
return
}
}
if curCrtVersion >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
// The current tunnel is using the latest certificate and version, no need to rehandshake.
return
}

View File

@@ -224,11 +224,13 @@ func TestCertDowngrade(t *testing.T) {
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
if c == nil {
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
r.Logf("version %d", version)
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == cert.Version1 {
break
}
@@ -249,3 +251,70 @@ func TestCertDowngrade(t *testing.T) {
myControl.Stop()
theirControl.Stop()
}
func TestCertMismatchCorrection(t *testing.T) {
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
// under ideal conditions
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
// Share our underlay information
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
// Start the servers
myControl.Start()
theirControl.Start()
r := router.NewR(t, myControl, theirControl)
defer r.RenderFlow()
r.Log("Assert the tunnel between me and them works")
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
//r.Log("yay")
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
r.Log("yay")
//todo ???
time.Sleep(1 * time.Second)
r.FlushAll()
waitStart := time.Now()
for {
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
if c == nil || c2 == nil {
r.Log("nil")
} else {
version := c.Cert.Version()
theirVersion := c2.Cert.Version()
r.Logf("version %d,%d", version, theirVersion)
if version == theirVersion {
break
}
}
since := time.Since(waitStart)
if since > time.Second*5 {
r.Log("wtf")
}
if since > time.Second*10 {
r.Log("wtf")
t.Fatal("Cert should be new by now")
}
time.Sleep(time.Second)
}
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
myControl.Stop()
theirControl.Stop()
}

View File

@@ -23,9 +23,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}
// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
v := cs.initiatingVersion
if hh.initiatingVersionOverride != cert.VersionPre1 {
v = hh.initiatingVersionOverride
}
// If we're connecting to a v6 address we must use a v2 cert
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2

View File

@@ -74,7 +74,8 @@ type HandshakeHostInfo struct {
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo
hostinfo *HostInfo
initiatingVersionOverride cert.Version
}
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {