mirror of
https://github.com/slackhq/nebula.git
synced 2025-11-22 08:24:25 +01:00
even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available
This commit is contained in:
@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
||||
|
||||
if mainHostInfo {
|
||||
decision = tryRehandshake
|
||||
|
||||
} else {
|
||||
if cm.shouldSwapPrimary(hostinfo) {
|
||||
decision = swapPrimary
|
||||
@@ -554,8 +553,33 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||
cs := cm.intf.pki.getCertState()
|
||||
curCrt := hostinfo.ConnectionState.myCert
|
||||
myCrt := cs.getCertificate(curCrt.Version())
|
||||
if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
curCrtVersion := curCrt.Version()
|
||||
myCrt := cs.getCertificate(curCrtVersion)
|
||||
if myCrt == nil {
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("reason", "local certificate removed").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
|
||||
return
|
||||
}
|
||||
peerCrt := hostinfo.ConnectionState.peerCert
|
||||
if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
|
||||
// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
|
||||
if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
|
||||
//todo trigger rehandshake with specific cert?
|
||||
cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
|
||||
WithField("version", curCrtVersion).
|
||||
WithField("peerVersion", peerCrt.Certificate.Version()).
|
||||
WithField("reason", "local certificate version mismatch with peer, correcting").
|
||||
Info("Re-handshaking with remote")
|
||||
cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
|
||||
hh.initiatingVersionOverride = peerCrt.Certificate.Version()
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if curCrtVersion >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
|
||||
// The current tunnel is using the latest certificate and version, no need to rehandshake.
|
||||
return
|
||||
}
|
||||
|
||||
@@ -224,11 +224,13 @@ func TestCertDowngrade(t *testing.T) {
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
if c == nil {
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
r.Logf("version %d", version)
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == cert.Version1 {
|
||||
break
|
||||
}
|
||||
@@ -249,3 +251,70 @@ func TestCertDowngrade(t *testing.T) {
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
func TestCertMismatchCorrection(t *testing.T) {
|
||||
// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
|
||||
// under ideal conditions
|
||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||
|
||||
myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
|
||||
myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
|
||||
theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
|
||||
|
||||
myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
|
||||
theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
|
||||
|
||||
// Share our underlay information
|
||||
myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
|
||||
theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
r := router.NewR(t, myControl, theirControl)
|
||||
defer r.RenderFlow()
|
||||
|
||||
r.Log("Assert the tunnel between me and them works")
|
||||
//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
|
||||
//r.Log("yay")
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
r.Log("yay")
|
||||
//todo ???
|
||||
time.Sleep(1 * time.Second)
|
||||
r.FlushAll()
|
||||
|
||||
waitStart := time.Now()
|
||||
for {
|
||||
assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
|
||||
c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
|
||||
c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
|
||||
if c == nil || c2 == nil {
|
||||
r.Log("nil")
|
||||
} else {
|
||||
version := c.Cert.Version()
|
||||
theirVersion := c2.Cert.Version()
|
||||
r.Logf("version %d,%d", version, theirVersion)
|
||||
if version == theirVersion {
|
||||
break
|
||||
}
|
||||
}
|
||||
since := time.Since(waitStart)
|
||||
if since > time.Second*5 {
|
||||
r.Log("wtf")
|
||||
}
|
||||
if since > time.Second*10 {
|
||||
r.Log("wtf")
|
||||
t.Fatal("Cert should be new by now")
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
r.RenderHostmaps("Final hostmaps", myControl, theirControl)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
}
|
||||
|
||||
@@ -23,9 +23,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
cs := f.pki.getCertState()
|
||||
v := cs.initiatingVersion
|
||||
if hh.initiatingVersionOverride != cert.VersionPre1 {
|
||||
v = hh.initiatingVersionOverride
|
||||
}
|
||||
// If we're connecting to a v6 address we must use a v2 cert
|
||||
for _, a := range hh.hostinfo.vpnAddrs {
|
||||
if a.Is6() {
|
||||
v = cert.Version2
|
||||
|
||||
@@ -74,7 +74,8 @@ type HandshakeHostInfo struct {
|
||||
lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
|
||||
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
|
||||
|
||||
hostinfo *HostInfo
|
||||
hostinfo *HostInfo
|
||||
initiatingVersionOverride cert.Version
|
||||
}
|
||||
|
||||
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
|
||||
|
||||
Reference in New Issue
Block a user