Rehandshaking (#838)

Co-authored-by: Brad Higgins <brad@defined.net>
Co-authored-by: Wade Simmons <wadey@slack-corp.com>
This commit is contained in:
Nate Brown
2023-05-04 15:16:37 -05:00
committed by GitHub
parent 0b67b19771
commit 03e4a7f988
15 changed files with 761 additions and 172 deletions

View File

@@ -1,6 +1,7 @@
package nebula
import (
"bytes"
"context"
"sync"
"time"
@@ -8,9 +9,20 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
type trafficDecision int
const (
doNothing trafficDecision = 0
deleteTunnel trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
closeTunnel trafficDecision = 2 // delete the hostinfo and notify the remote
swapPrimary trafficDecision = 3
migrateRelays trafficDecision = 4
)
type connectionManager struct {
in map[uint32]struct{}
inLock *sync.RWMutex
@@ -18,6 +30,10 @@ type connectionManager struct {
out map[uint32]struct{}
outLock *sync.RWMutex
// relayUsed holds which relay localIndexs are in use
relayUsed map[uint32]struct{}
relayUsedLock *sync.RWMutex
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
@@ -44,6 +60,8 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
relayUsed: make(map[uint32]struct{}),
relayUsedLock: &sync.RWMutex{},
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]struct{}),
@@ -84,6 +102,19 @@ func (n *connectionManager) Out(localIndex uint32) {
n.outLock.Unlock()
}
func (n *connectionManager) RelayUsed(localIndex uint32) {
n.relayUsedLock.RLock()
// If this already exists, return
if _, ok := n.relayUsed[localIndex]; ok {
n.relayUsedLock.RUnlock()
return
}
n.relayUsedLock.RUnlock()
n.relayUsedLock.Lock()
n.relayUsed[localIndex] = struct{}{}
n.relayUsedLock.Unlock()
}
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
@@ -99,8 +130,15 @@ func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bo
}
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
n.Out(localIndex)
// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
n.outLock.Lock()
if _, ok := n.out[localIndex]; ok {
n.outLock.Unlock()
return
}
n.out[localIndex] = struct{}{}
n.trafficTimer.Add(localIndex, n.checkInterval)
n.outLock.Unlock()
}
func (n *connectionManager) Start(ctx context.Context) {
@@ -136,18 +174,130 @@ func (n *connectionManager) Run(ctx context.Context) {
}
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
hostinfo, err := n.hostMap.QueryIndex(localIndex)
if err != nil {
decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
switch decision {
case deleteTunnel:
n.hostMap.DeleteHostInfo(hostinfo)
case closeTunnel:
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
case swapPrimary:
n.swapPrimary(hostinfo, primary)
case migrateRelays:
n.migrateRelayUsed(hostinfo, primary)
}
n.resetRelayTrafficCheck(hostinfo)
}
func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
if hostinfo != nil {
n.relayUsedLock.Lock()
defer n.relayUsedLock.Unlock()
// No need to migrate any relays, delete usage info now.
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
delete(n.relayUsed, idx)
}
}
}
func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
relayFor := oldhostinfo.relayState.CopyAllRelayFor()
for _, r := range relayFor {
existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
var index uint32
var relayFrom iputil.VpnIp
var relayTo iputil.VpnIp
switch {
case ok && existing.State == Established:
// This relay already exists in newhostinfo, then do nothing.
continue
case ok && existing.State == Requested:
// The relay exists in a Requested state; re-send the request
index = existing.LocalIndex
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayTo = existing.PeerIp
case ForwardingType:
relayFrom = existing.PeerIp
relayTo = newhostinfo.vpnIp
default:
// should never happen
}
case !ok:
n.relayUsedLock.RLock()
if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
// The relay hasn't been used; don't migrate it.
n.relayUsedLock.RUnlock()
continue
}
n.relayUsedLock.RUnlock()
// The relay doesn't exist at all; create some relay state and send the request.
var err error
index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
if err != nil {
n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
continue
}
switch r.Type {
case TerminalType:
relayFrom = newhostinfo.vpnIp
relayTo = r.PeerIp
case ForwardingType:
relayFrom = r.PeerIp
relayTo = newhostinfo.vpnIp
default:
// should never happen
}
}
// Send a CreateRelayRequest to the peer.
req := NebulaControl{
Type: NebulaControl_CreateRelayRequest,
InitiatorRelayIndex: index,
RelayFromIp: uint32(relayFrom),
RelayToIp: uint32(relayTo),
}
msg, err := req.Marshal()
if err != nil {
n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
} else {
n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
n.l.WithFields(logrus.Fields{
"relayFrom": iputil.VpnIp(req.RelayFromIp),
"relayTo": iputil.VpnIp(req.RelayToIp),
"initiatorRelayIndex": req.InitiatorRelayIndex,
"responderRelayIndex": req.ResponderRelayIndex,
"vpnIp": newhostinfo.vpnIp}).
Info("send CreateRelayRequest")
}
}
}
func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
n.hostMap.RLock()
defer n.hostMap.RUnlock()
hostinfo := n.hostMap.Indexes[localIndex]
if hostinfo == nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return
return doNothing, nil, nil
}
if n.handleInvalidCertificate(now, hostinfo) {
return
if n.isInvalidCertificate(now, hostinfo) {
delete(n.pendingDeletion, hostinfo.localIndexId)
return closeTunnel, hostinfo, nil
}
primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
primary := n.hostMap.Hosts[hostinfo.vpnIp]
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
@@ -158,6 +308,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
decision := doNothing
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
@@ -165,11 +316,14 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
}
delete(n.pendingDeletion, hostinfo.localIndexId)
if !mainHostInfo {
if hostinfo.vpnIp > n.intf.myVpnIp {
// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
// This the primary and prime the old primary hostinfo for testing
n.hostMap.MakePrimary(hostinfo)
if mainHostInfo {
n.tryRehandshake(hostinfo)
} else {
if n.shouldSwapPrimary(hostinfo, primary) {
decision = swapPrimary
} else {
// migrate the relays to the primary, if in use.
decision = migrateRelays
}
}
@@ -180,7 +334,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
n.sendPunch(hostinfo)
}
return
return decision, hostinfo, primary
}
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
@@ -189,22 +343,17 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
n.hostMap.DeleteHostInfo(hostinfo)
delete(n.pendingDeletion, hostinfo.localIndexId)
return
return deleteTunnel, hostinfo, nil
}
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if !outTraffic {
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
// Just maintain NAT state if configured to do so.
n.sendPunch(hostinfo)
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return
return doNothing, nil, nil
}
@@ -218,22 +367,58 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return
return doNothing, nil, nil
}
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
}
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
} else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
}
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
return doNothing, nil, nil
}
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
// Let's sort this out.
if current.vpnIp < n.intf.myVpnIp {
// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
// The remotes vpn ip is lower than mine. I will not flip.
return false
}
certState := n.intf.certState.Load()
return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
}
func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
n.hostMap.Lock()
// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
if n.hostMap.Hosts[current.vpnIp] == primary {
n.hostMap.unlockedMakePrimary(current)
}
n.hostMap.Unlock()
}
// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
// the certificate is no longer valid
func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
if !n.intf.disconnectInvalid {
return false
}
@@ -253,10 +438,6 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
WithField("fingerprint", fingerprint).
Info("Remote certificate is no longer valid, tearing down the tunnel")
// Inform the remote and close the tunnel locally
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
delete(n.pendingDeletion, hostinfo.localIndexId)
return true
}
@@ -277,3 +458,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}
func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
certState := n.intf.certState.Load()
if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
return
}
n.l.WithField("vpnIp", hostinfo.vpnIp).
WithField("reason", "local certificate is not current").
Info("Re-handshaking with remote")
//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
if !newHostinfo.HandshakeReady {
ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
}
//If this is a static host, we don't need to wait for the HostQueryReply
//We can trigger the handshake right now
if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
select {
case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
default:
}
}
}