From 52623820c2be9571bb46acd16f5afd7811fe6542 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 3 Jul 2025 09:58:37 -0500 Subject: [PATCH] Drop inactive tunnels (#1427) --- connection_manager.go | 383 ++++++++++++++++++++----------------- connection_manager_test.go | 176 +++++++++++++---- control.go | 20 +- e2e/handshakes_test.go | 5 +- e2e/router/router.go | 1 + e2e/tunnels_test.go | 57 ++++++ examples/config.yml | 12 ++ handshake_ix.go | 4 +- hostmap.go | 8 + inside.go | 4 +- interface.go | 41 ++-- main.go | 45 ++--- outside.go | 6 +- udp/udp_darwin.go | 2 - 14 files changed, 485 insertions(+), 279 deletions(-) create mode 100644 e2e/tunnels_test.go diff --git a/connection_manager.go b/connection_manager.go index f3acc92..1f9b18b 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -7,11 +7,13 @@ import ( "fmt" "net/netip" "sync" + "sync/atomic" "time" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -28,130 +30,124 @@ const ( ) type connectionManager struct { - in map[uint32]struct{} - inLock *sync.RWMutex - - out map[uint32]struct{} - outLock *sync.RWMutex - // relayUsed holds which relay localIndexs are in use relayUsed map[uint32]struct{} relayUsedLock *sync.RWMutex - hostMap *HostMap - trafficTimer *LockingTimerWheel[uint32] - intf *Interface - pendingDeletion map[uint32]struct{} - punchy *Punchy + hostMap *HostMap + trafficTimer *LockingTimerWheel[uint32] + intf *Interface + punchy *Punchy + + // Configuration settings checkInterval time.Duration pendingDeletionInterval time.Duration - metricsTxPunchy metrics.Counter + inactivityTimeout atomic.Int64 + dropInactive atomic.Bool + + metricsTxPunchy metrics.Counter l *logrus.Logger } -func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager { - var max time.Duration - if checkInterval < pendingDeletionInterval { - max = pendingDeletionInterval - } else { - max = checkInterval +func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { + cm := &connectionManager{ + hostMap: hm, + l: l, + punchy: p, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, + metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), } - nc := &connectionManager{ - hostMap: intf.hostMap, - in: make(map[uint32]struct{}), - inLock: &sync.RWMutex{}, - out: make(map[uint32]struct{}), - outLock: &sync.RWMutex{}, - relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, - trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max), - intf: intf, - pendingDeletion: make(map[uint32]struct{}), - checkInterval: checkInterval, - pendingDeletionInterval: pendingDeletionInterval, - punchy: punchy, - metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), - l: l, - } + cm.reload(c, true) + c.RegisterReloadCallback(func(c *config.C) { + cm.reload(c, false) + }) - nc.Start(ctx) - return nc + return cm } -func (n *connectionManager) In(localIndex uint32) { - n.inLock.RLock() - // If this already exists, return - if _, ok := n.in[localIndex]; ok { - n.inLock.RUnlock() - return +func (cm *connectionManager) reload(c *config.C, initial bool) { + if initial { + cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second + cm.pendingDeletionInterval = 
time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second + + // We want at least a minimum resolution of 500ms per tick so that we can hit these intervals + // pretty close to their configured duration. + // The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it. + minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval) + maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval) + cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration) + } + + if initial || c.HasChanged("tunnels.inactivity_timeout") { + old := cm.getInactivityTimeout() + cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))) + if !initial { + cm.l.WithField("oldDuration", old). + WithField("newDuration", cm.getInactivityTimeout()). + Info("Inactivity timeout has changed") + } + } + + if initial || c.HasChanged("tunnels.drop_inactive") { + old := cm.dropInactive.Load() + cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false)) + if !initial { + cm.l.WithField("oldBool", old). + WithField("newBool", cm.dropInactive.Load()). + Info("Drop inactive setting has changed") + } } - n.inLock.RUnlock() - n.inLock.Lock() - n.in[localIndex] = struct{}{} - n.inLock.Unlock() } -func (n *connectionManager) Out(localIndex uint32) { - n.outLock.RLock() - // If this already exists, return - if _, ok := n.out[localIndex]; ok { - n.outLock.RUnlock() - return - } - n.outLock.RUnlock() - n.outLock.Lock() - n.out[localIndex] = struct{}{} - n.outLock.Unlock() +func (cm *connectionManager) getInactivityTimeout() time.Duration { + return (time.Duration)(cm.inactivityTimeout.Load()) } -func (n *connectionManager) RelayUsed(localIndex uint32) { - n.relayUsedLock.RLock() +func (cm *connectionManager) In(h *HostInfo) { + h.in.Store(true) +} + +func (cm *connectionManager) Out(h *HostInfo) { + h.out.Store(true) +} + +func (cm *connectionManager) RelayUsed(localIndex uint32) { + cm.relayUsedLock.RLock() // If this already exists, return - if _, ok := n.relayUsed[localIndex]; ok { - n.relayUsedLock.RUnlock() + if _, ok := cm.relayUsed[localIndex]; ok { + cm.relayUsedLock.RUnlock() return } - n.relayUsedLock.RUnlock() - n.relayUsedLock.Lock() - n.relayUsed[localIndex] = struct{}{} - n.relayUsedLock.Unlock() + cm.relayUsedLock.RUnlock() + cm.relayUsedLock.Lock() + cm.relayUsed[localIndex] = struct{}{} + cm.relayUsedLock.Unlock() } // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and // resets the state for this local index -func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) { - n.inLock.Lock() - n.outLock.Lock() - _, in := n.in[localIndex] - _, out := n.out[localIndex] - delete(n.in, localIndex) - delete(n.out, localIndex) - n.inLock.Unlock() - n.outLock.Unlock() +func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) { + in := h.in.Swap(false) + out := h.out.Swap(false) + if in || out { + h.lastUsed = now + } return in, out } -func (n *connectionManager) AddTrafficWatch(localIndex uint32) { - // Use a write lock directly because it should be incredibly rare that we are ever already tracking this index - n.outLock.Lock() - if _, ok := n.out[localIndex]; ok { - n.outLock.Unlock() - return +// AddTrafficWatch must be called for every new HostInfo. +// We will continue to monitor the HostInfo until the tunnel is dropped. 
+func (cm *connectionManager) AddTrafficWatch(h *HostInfo) { + if h.out.Swap(true) == false { + cm.trafficTimer.Add(h.localIndexId, cm.checkInterval) } - n.out[localIndex] = struct{}{} - n.trafficTimer.Add(localIndex, n.checkInterval) - n.outLock.Unlock() } -func (n *connectionManager) Start(ctx context.Context) { - go n.Run(ctx) -} - -func (n *connectionManager) Run(ctx context.Context) { - //TODO: this tick should be based on the min wheel tick? Check firewall - clockSource := time.NewTicker(500 * time.Millisecond) +func (cm *connectionManager) Start(ctx context.Context) { + clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration) defer clockSource.Stop() p := []byte("") @@ -164,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) { return case now := <-clockSource.C: - n.trafficTimer.Advance(now) + cm.trafficTimer.Advance(now) for { - localIndex, has := n.trafficTimer.Purge() + localIndex, has := cm.trafficTimer.Purge() if !has { break } - n.doTrafficCheck(localIndex, p, nb, out, now) + cm.doTrafficCheck(localIndex, p, nb, out, now) } } } } -func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { - decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) +func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { + decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now) switch decision { case deleteTunnel: - if n.hostMap.DeleteHostInfo(hostinfo) { + if cm.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) + cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: - n.intf.sendCloseTunnel(hostinfo) - n.intf.closeTunnel(hostinfo) + cm.intf.sendCloseTunnel(hostinfo) + cm.intf.closeTunnel(hostinfo) case swapPrimary: - n.swapPrimary(hostinfo, primary) + cm.swapPrimary(hostinfo, primary) case migrateRelays: - n.migrateRelayUsed(hostinfo, primary) + cm.migrateRelayUsed(hostinfo, primary) case tryRehandshake: - n.tryRehandshake(hostinfo) + cm.tryRehandshake(hostinfo) case sendTestPacket: - n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } - n.resetRelayTrafficCheck(hostinfo) + cm.resetRelayTrafficCheck(hostinfo) } -func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { +func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) { if hostinfo != nil { - n.relayUsedLock.Lock() - defer n.relayUsedLock.Unlock() + cm.relayUsedLock.Lock() + defer cm.relayUsedLock.Unlock() // No need to migrate any relays, delete usage info now. 
for _, idx := range hostinfo.relayState.CopyRelayForIdxs() { - delete(n.relayUsed, idx) + delete(cm.relayUsed, idx) } } } -func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { +func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) { relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { @@ -238,7 +234,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnAddrs[0] + relayFrom = cm.intf.myVpnAddrs[0] relayTo = existing.PeerAddr case ForwardingType: relayFrom = existing.PeerAddr @@ -249,23 +245,23 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } case !ok: - n.relayUsedLock.RLock() - if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed { + cm.relayUsedLock.RLock() + if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed { // The relay hasn't been used; don't migrate it. - n.relayUsedLock.RUnlock() + cm.relayUsedLock.RUnlock() continue } - n.relayUsedLock.RUnlock() + cm.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) + index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { - n.l.WithError(err).Error("failed to migrate relay to new hostinfo") + cm.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnAddrs[0] + relayFrom = cm.intf.myVpnAddrs[0] relayTo = r.PeerAddr case ForwardingType: relayFrom = r.PeerAddr @@ -285,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) switch newhostinfo.GetCert().Certificate.Version() { case cert.Version1: if !relayFrom.Is4() { - n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !relayTo.Is4() { - n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -302,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) req.RelayFromAddr = netAddrToProtoAddr(relayFrom) req.RelayToAddr = netAddrToProtoAddr(relayTo) default: - newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay") continue } msg, err := req.Marshal() if err != nil { - n.l.WithError(err).Error("failed to marshal Control message to migrate relay") + cm.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { - n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) - n.l.WithFields(logrus.Fields{ + cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) + cm.l.WithFields(logrus.Fields{ "relayFrom": req.RelayFromAddr, "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, @@ -322,46 +318,45 @@ func (n 
*connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } -func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { - n.hostMap.RLock() - defer n.hostMap.RUnlock() +func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { + // Read lock the main hostmap to order decisions based on tunnels being the primary tunnel + cm.hostMap.RLock() + defer cm.hostMap.RUnlock() - hostinfo := n.hostMap.Indexes[localIndex] + hostinfo := cm.hostMap.Indexes[localIndex] if hostinfo == nil { - n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap") - delete(n.pendingDeletion, localIndex) + cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap") return doNothing, nil, nil } - if n.isInvalidCertificate(now, hostinfo) { - delete(n.pendingDeletion, hostinfo.localIndexId) + if cm.isInvalidCertificate(now, hostinfo) { return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] + primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false } // Check for traffic on this hostinfo - inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex) + inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now) // A hostinfo is determined alive if there is incoming traffic if inTraffic { decision := doNothing - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "alive", "method": "passive"}). Debug("Tunnel status") } - delete(n.pendingDeletion, hostinfo.localIndexId) + hostinfo.pendingDeletion.Store(false) if mainHostInfo { decision = tryRehandshake } else { - if n.shouldSwapPrimary(hostinfo, primary) { + if cm.shouldSwapPrimary(hostinfo, primary) { decision = swapPrimary } else { // migrate the relays to the primary, if in use. @@ -369,46 +364,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time } } - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) if !outTraffic { // Send a punch packet to keep the NAT state alive - n.sendPunch(hostinfo) + cm.sendPunch(hostinfo) } return decision, hostinfo, primary } - if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok { + if hostinfo.pendingDeletion.Load() { // We have already sent a test packet and nothing was returned, this hostinfo is dead - hostinfo.logger(n.l). + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "dead", "method": "active"}). Info("Tunnel status") - delete(n.pendingDeletion, hostinfo.localIndexId) return deleteTunnel, hostinfo, nil } decision := doNothing if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { + inactiveFor, isInactive := cm.isInactive(hostinfo, now) + if isInactive { + // Tunnel is inactive, tear it down + hostinfo.logger(cm.l). + WithField("inactiveDuration", inactiveFor). + WithField("primary", mainHostInfo). + Info("Dropping tunnel due to inactivity") + + return closeTunnel, hostinfo, primary + } + // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. 
- n.sendPunch(hostinfo) - n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval) + cm.sendPunch(hostinfo) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil - } - if n.punchy.GetTargetEverything() { + if cm.punchy.GetTargetEverything() { // This is similar to the old punchy behavior with a slight optimization. // We aren't receiving traffic but we are sending it, punch on all known // ips in case we need to re-prime NAT state - n.sendPunch(hostinfo) + cm.sendPunch(hostinfo) } - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l). + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l). WithField("tunnelCheck", m{"state": "testing", "method": "active"}). Debug("Tunnel status") } @@ -417,17 +421,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time decision = sendTestPacket } else { - if n.l.Level >= logrus.DebugLevel { - hostinfo.logger(n.l).Debugf("Hostinfo sadness") + if cm.l.Level >= logrus.DebugLevel { + hostinfo.logger(cm.l).Debugf("Hostinfo sadness") } } - n.pendingDeletion[hostinfo.localIndexId] = struct{}{} - n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) + hostinfo.pendingDeletion.Store(true) + cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval) return decision, hostinfo, nil } -func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { +func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) { + if cm.dropInactive.Load() == false { + // We aren't configured to drop inactive tunnels + return 0, false + } + + inactiveDuration := now.Sub(hostinfo.lastUsed) + if inactiveDuration < cm.getInactivityTimeout() { + // It's not considered inactive + return inactiveDuration, false + } + + // The tunnel is inactive + return inactiveDuration, true +} + +func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // The primary tunnel is the most recent handshake to complete locally and should work entirely fine. // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. @@ -435,73 +455,80 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // Only one side should swap because if both swap then we may never resolve to a single tunnel. // vpn addr is static across all tunnels for this host pair so lets // use that to determine if we should consider swapping. - if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 { // Their primary vpn addr is less than mine. Do not swap. return false } - crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things // settle down. return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } -func (n *connectionManager) swapPrimary(current, primary *HostInfo) { - n.hostMap.Lock() +func (cm *connectionManager) swapPrimary(current, primary *HostInfo) { + cm.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. 
- if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { - n.hostMap.unlockedMakePrimary(current) + if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary { + cm.hostMap.unlockedMakePrimary(current) } - n.hostMap.Unlock() + cm.hostMap.Unlock() } // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid // check and return true. -func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { +func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool { remoteCert := hostinfo.GetCert() if remoteCert == nil { return false } - caPool := n.intf.pki.GetCAPool() + caPool := cm.intf.pki.GetCAPool() err := caPool.VerifyCachedCertificate(now, remoteCert) if err == nil { return false } - if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { + if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed { // Block listed certificates should always be disconnected return false } - hostinfo.logger(n.l).WithError(err). + hostinfo.logger(cm.l).WithError(err). WithField("fingerprint", remoteCert.Fingerprint). Info("Remote certificate is no longer valid, tearing down the tunnel") return true } -func (n *connectionManager) sendPunch(hostinfo *HostInfo) { - if !n.punchy.GetPunch() { +func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { + if !cm.punchy.GetPunch() { // Punching is disabled return } - if n.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { - n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, addr) + if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. + // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse + // would lose the ability to notify us and punchy.respond would become unreliable. + return + } + + if cm.punchy.GetTargetEverything() { + hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { + cm.metricsTxPunchy.Inc(1) + cm.intf.outside.WriteTo([]byte{1}, addr) }) } else if hostinfo.remote.IsValid() { - n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + cm.metricsTxPunchy.Inc(1) + cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) } } -func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - cs := n.intf.pki.getCertState() +func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { + cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert myCrt := cs.getCertificate(curCrt.Version()) if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { @@ -509,9 +536,9 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { return } - n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). + cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). 
Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) + cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index d1c5ba3..ecd2880 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -1,7 +1,6 @@ package nebula import ( - "context" "crypto/ed25519" "crypto/rand" "net/netip" @@ -64,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce p := []byte("") nb := make([]byte, 12, 12) out := make([]byte, mtu) @@ -85,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp - nc.Out(hostinfo.localIndexId) - nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) + nc.Out(hostinfo) + nc.In(hostinfo) + assert.False(t, hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.out, hostinfo.localIndexId) + assert.True(t, hostinfo.out.Load()) + assert.True(t, hostinfo.in.Load()) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now - nc.Out(hostinfo.localIndexId) + nc.Out(hostinfo) + assert.True(t, hostinfo.out.Load()) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.True(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } @@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) { ifce.pki.cs.Store(cs) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce p := []byte("") nb := 
make([]byte, 12, 12) out := make([]byte, mtu) @@ -167,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) { nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) // We saw traffic out to vpnIp - nc.Out(hostinfo.localIndexId) - nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + nc.Out(hostinfo) + nc.In(hostinfo) + assert.True(t, hostinfo.in.Load()) + assert.True(t, hostinfo.out.Load()) + assert.False(t, hostinfo.pendingDeletion.Load()) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) // Do another traffic check tick, this host should be pending deletion now - nc.Out(hostinfo.localIndexId) + nc.Out(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.True(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion - nc.In(hostinfo.localIndexId) + nc.In(hostinfo) nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) - assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.out, hostinfo.localIndexId) - assert.NotContains(t, nc.in, hostinfo.localIndexId) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) +} + +func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { + l := test.NewLogger() + localrange := netip.MustParsePrefix("10.1.1.1/24") + vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")} + preferredRanges := []netip.Prefix{localrange} + + // Very incomplete mock objects + hostMap := newHostMap(l) + hostMap.preferredRanges.Store(&preferredRanges) + + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &test.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + } + ifce.pki.cs.Store(cs) + + // Create manager + conf := config.NewC(l) + conf.Settings["tunnels"] = map[string]any{ + "drop_inactive": true, + } + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + assert.True(t, nc.dropInactive.Load()) + nc.intf = ifce + + // Add an ip we have established a connection w/ to hostmap + hostinfo := &HostInfo{ + vpnAddrs: vpnAddrs, + 
localIndexId: 1099, + remoteIndexId: 9901, + } + hostinfo.ConnectionState = &ConnectionState{ + myCert: &dummyCert{version: cert.Version1}, + H: &noise.HandshakeState{}, + } + nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) + + // Do a traffic check tick, in and out should be cleared but should not be pending deletion + nc.Out(hostinfo) + nc.In(hostinfo) + assert.True(t, hostinfo.out.Load()) + assert.True(t, hostinfo.in.Load()) + + now := time.Now() + decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now) + assert.Equal(t, tryRehandshake, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5)) + assert.Equal(t, doNothing, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + + // Do another traffic check tick, should still not be pending deletion + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10)) + assert.Equal(t, doNothing, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) + assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) + + // Finally advance beyond the inactivity timeout + decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10)) + assert.Equal(t, closeTunnel, decision) + assert.Equal(t, now, hostinfo.lastUsed) + assert.False(t, hostinfo.pendingDeletion.Load()) + assert.False(t, hostinfo.out.Load()) + assert.False(t, hostinfo.in.Load()) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } @@ -264,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.disconnectInvalid.Store(true) // Create manager - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - punchy := NewPunchyFromConfig(l, config.NewC(l)) - nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy) + conf := config.NewC(l) + punchy := NewPunchyFromConfig(l, conf) + nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy) + nc.intf = ifce ifce.connectionManager = nc hostinfo := &HostInfo{ diff --git a/control.go b/control.go index 016a79b..f8567b5 100644 --- a/control.go +++ b/control.go @@ -26,14 +26,15 @@ type controlHostLister interface { } type Control struct { - f *Interface - l *logrus.Logger - ctx context.Context - cancel context.CancelFunc - sshStart func() - statsStart func() - dnsStart func() - lighthouseStart func() + f *Interface + l *logrus.Logger + ctx context.Context + cancel context.CancelFunc + sshStart func() + statsStart func() + dnsStart func() + lighthouseStart func() + connectionManagerStart func(context.Context) } type ControlHostInfo struct { @@ -63,6 +64,9 @@ func (c *Control) Start() { if c.dnsStart != nil { go c.dnsStart() } + if c.connectionManagerStart != nil { + go c.connectionManagerStart(c.ctx) + } if c.lighthouseStart != nil { c.lighthouseStart() } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index bc080ce..53d3738 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -506,7 +506,7 @@ func TestReestablishRelays(t 
*testing.T) { curIndexes := len(myControl.GetHostmap().Indexes) for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) - r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -1052,6 +1052,9 @@ func TestRehandshakingLoser(t *testing.T) { t.Log("Stand up a tunnel between me and them") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) + r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") diff --git a/e2e/router/router.go b/e2e/router/router.go index 5e52ed7..c8264ab 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -700,6 +700,7 @@ func (r *R) FlushAll() { r.Unlock() panic("Can't FlushAll for host: " + p.To.String()) } + receiver.InjectUDPPacket(p) r.Unlock() } } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go new file mode 100644 index 0000000..55974f0 --- /dev/null +++ b/e2e/tunnels_test.go @@ -0,0 +1,57 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/e2e/router" +) + +func TestDropInactiveTunnels(t *testing.T) { + // The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides + // under ideal conditions + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}}) + + // Share our underlay information + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) + + // Start the servers + myControl.Start() + theirControl.Start() + + r := router.NewR(t, myControl, theirControl) + + r.Log("Assert the tunnel between me and them works") + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + r.Log("Go inactive and wait for the tunnels to get dropped") + waitStart := time.Now() + for { + myIndexes := len(myControl.GetHostmap().Indexes) + theirIndexes := len(theirControl.GetHostmap().Indexes) + if myIndexes == 0 && theirIndexes == 0 { + break + } + + since := time.Since(waitStart) + r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since) + if since > time.Second*30 { + t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds") + } + + time.Sleep(1 * time.Second) + r.FlushAll() + } + + r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart)) + myControl.Stop() + theirControl.Stop() +} diff --git a/examples/config.yml b/examples/config.yml 
index eec4f1c..42c32c8 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -338,6 +338,18 @@ logging: # after receiving the response for lighthouse queries #trigger_buffer: 64 +# Tunnel manager settings +#tunnels: + # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has + # elapsed. + # In general, it is a good idea to enable this setting. It will be enabled by default in a future release. + # This setting is reloadable + #drop_inactive: false + + # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered + # inactive and eligible to be dropped. + # This setting is reloadable + #inactivity_timeout: 10m # Nebula security group configuration firewall: diff --git a/handshake_ix.go b/handshake_ix.go index 0548a23..d53e5a7 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -457,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message sent") } - f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + f.connectionManager.AddTrafficWatch(hostinfo) hostinfo.remotes.ResetBlockedRemotes() @@ -652,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here f.handshakeManager.Complete(hostinfo, f) - f.connectionManager.AddTrafficWatch(hostinfo.localIndexId) + f.connectionManager.AddTrafficWatch(hostinfo) if f.l.Level >= logrus.DebugLevel { hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore)) diff --git a/hostmap.go b/hostmap.go index 7b9b8b9..7e3b1bd 100644 --- a/hostmap.go +++ b/hostmap.go @@ -256,6 +256,14 @@ type HostInfo struct { // Used to track other hostinfos for this vpn ip since only 1 can be primary // Synchronised via hostmap lock and not the hostinfo lock. next, prev *HostInfo + + //TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing + in, out, pendingDeletion atomic.Bool + + // lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use. + // This value will be behind against actual tunnel utilization in the hot path. + // This should only be used by the ConnectionManagers ticker routine. + lastUsed time.Time } type ViaSender struct { diff --git a/inside.go b/inside.go index 239ea6a..d24ed31 100644 --- a/inside.go +++ b/inside.go @@ -288,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo, c := via.ConnectionState.messageCounter.Add(1) out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c) - f.connectionManager.Out(via.localIndexId) + f.connectionManager.Out(via) // Authenticate the header and payload, but do not encrypt for this message type. // The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload. 
@@ -356,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType //l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p) out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c) - f.connectionManager.Out(hostinfo.localIndexId) + f.connectionManager.Out(hostinfo) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against // all our addrs and enable a faster roaming. diff --git a/interface.go b/interface.go index ddd0681..082906d 100644 --- a/interface.go +++ b/interface.go @@ -24,23 +24,23 @@ import ( const mtu = 9001 type InterfaceConfig struct { - HostMap *HostMap - Outside udp.Conn - Inside overlay.Device - pki *PKI - Firewall *Firewall - ServeDns bool - HandshakeManager *HandshakeManager - lightHouse *LightHouse - checkInterval time.Duration - pendingDeletionInterval time.Duration - DropLocalBroadcast bool - DropMulticast bool - routines int - MessageMetrics *MessageMetrics - version string - relayManager *relayManager - punchy *Punchy + HostMap *HostMap + Outside udp.Conn + Inside overlay.Device + pki *PKI + Cipher string + Firewall *Firewall + ServeDns bool + HandshakeManager *HandshakeManager + lightHouse *LightHouse + connectionManager *connectionManager + DropLocalBroadcast bool + DropMulticast bool + routines int + MessageMetrics *MessageMetrics + version string + relayManager *relayManager + punchy *Punchy tryPromoteEvery uint32 reQueryEvery uint32 @@ -157,6 +157,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { if c.Firewall == nil { return nil, errors.New("no firewall rules") } + if c.connectionManager == nil { + return nil, errors.New("no connection manager") + } cs := c.pki.getCertState() ifce := &Interface{ @@ -181,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { myVpnAddrsTable: cs.myVpnAddrsTable, myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, relayManager: c.relayManager, - + connectionManager: c.connectionManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)), @@ -198,7 +201,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) - ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy) + ifce.connectionManager.intf = ifce return ifce, nil } diff --git a/main.go b/main.go index b278fa6..eb296fb 100644 --- a/main.go +++ b/main.go @@ -185,6 +185,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) + connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) @@ -220,31 +221,26 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - checkInterval := c.GetInt("timers.connection_alive_interval", 5) - pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10) - ifConfig := &InterfaceConfig{ - HostMap: hostMap, - Inside: tun, - Outside: udpConns[0], - pki: pki, - Firewall: fw, - 
ServeDns: serveDns, - HandshakeManager: handshakeManager, - lightHouse: lightHouse, - checkInterval: time.Second * time.Duration(checkInterval), - pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval), - tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), - reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), - reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), - DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), - DropMulticast: c.GetBool("tun.drop_multicast", false), - routines: routines, - MessageMetrics: messageMetrics, - version: buildVersion, - relayManager: NewRelayManager(ctx, l, hostMap, c), - punchy: punchy, - + HostMap: hostMap, + Inside: tun, + Outside: udpConns[0], + pki: pki, + Firewall: fw, + ServeDns: serveDns, + HandshakeManager: handshakeManager, + connectionManager: connManager, + lightHouse: lightHouse, + tryPromoteEvery: c.GetUint32("counters.try_promote", defaultPromoteEvery), + reQueryEvery: c.GetUint32("counters.requery_every_packets", defaultReQueryEvery), + reQueryWait: c.GetDuration("timers.requery_wait_duration", defaultReQueryWait), + DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false), + DropMulticast: c.GetBool("tun.drop_multicast", false), + routines: routines, + MessageMetrics: messageMetrics, + version: buildVersion, + relayManager: NewRelayManager(ctx, l, hostMap, c), + punchy: punchy, ConntrackCacheTimeout: conntrackCacheTimeout, l: l, } @@ -296,5 +292,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg statsStart, dnsStart, lightHouse.StartUpdateWorker, + connManager.Start, }, nil } diff --git a/outside.go b/outside.go index 6d4127d..8720eef 100644 --- a/outside.go +++ b/outside.go @@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] // Pull the Roaming parts up here, and return in all call paths. f.handleHostRoaming(hostinfo, ip) // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) f.connectionManager.RelayUsed(h.RemoteIndex) relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) @@ -213,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] f.handleHostRoaming(hostinfo, ip) - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -498,7 +498,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out return false } - f.connectionManager.In(hostinfo.localIndexId) + f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.WithError(err).Error("Failed to write to tun") diff --git a/udp/udp_darwin.go b/udp/udp_darwin.go index 5e50d8b..c0c6233 100644 --- a/udp/udp_darwin.go +++ b/udp/udp_darwin.go @@ -3,8 +3,6 @@ package udp -// Darwin support is primarily implemented in udp_generic, besides NewListenConfig - import ( "context" "encoding/binary"
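A minimal sketch of the reload pattern the patch relies on: tunnels.inactivity_timeout is stored as nanoseconds in an atomic.Int64 and tunnels.drop_inactive in an atomic.Bool, so the connection manager's ticker goroutine can read both without taking a lock while a config reload writes them. The inactivitySettings type and its methods below are illustrative stand-ins, not the real connectionManager.

// Sketch of the lock-free, reloadable settings used for inactivity handling.
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type inactivitySettings struct {
	inactivityTimeout atomic.Int64 // holds a time.Duration as int64 nanoseconds
	dropInactive      atomic.Bool
}

// reload is what a config-reload callback would call; readers never block on it.
func (s *inactivitySettings) reload(timeout time.Duration, drop bool) {
	s.inactivityTimeout.Store(int64(timeout))
	s.dropInactive.Store(drop)
}

func (s *inactivitySettings) getInactivityTimeout() time.Duration {
	return time.Duration(s.inactivityTimeout.Load())
}

func main() {
	s := &inactivitySettings{}
	s.reload(10*time.Minute, false) // the defaults the patch uses
	s.reload(5*time.Second, true)   // a later reload tightens the timeout
	fmt.Println(s.getInactivityTimeout(), s.dropInactive.Load())
}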
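A minimal sketch of the per-tunnel bookkeeping the patch moves onto HostInfo: in and out become atomic flags set on the hot path, the ticker swaps them back to false on each check, bumps lastUsed whenever either was set, and a tunnel whose lastUsed is older than the configured timeout becomes eligible to be dropped. The hostInfo struct and the free-function forms of getAndResetTrafficCheck and isInactive below are simplified stand-ins for the patched methods.

// Sketch of the inactivity decision, under the assumption that lastUsed is
// only ever touched from the ticker goroutine (as the HostInfo comment notes).
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type hostInfo struct {
	in, out  atomic.Bool // set by the packet paths, cleared by the ticker
	lastUsed time.Time   // last tick at which any traffic was observed
}

// getAndResetTrafficCheck reports whether traffic was seen since the last tick
// and resets the flags for the next interval.
func getAndResetTrafficCheck(h *hostInfo, now time.Time) (bool, bool) {
	in := h.in.Swap(false)
	out := h.out.Swap(false)
	if in || out {
		h.lastUsed = now
	}
	return in, out
}

// isInactive reports whether the tunnel has gone unused for at least the
// configured timeout, and only when dropping inactive tunnels is enabled.
func isInactive(h *hostInfo, now time.Time, timeout time.Duration, dropInactive bool) bool {
	if !dropInactive {
		return false
	}
	return now.Sub(h.lastUsed) >= timeout
}

func main() {
	h := &hostInfo{lastUsed: time.Now()}
	h.out.Store(true) // a packet was sent on the hot path

	now := time.Now()
	in, out := getAndResetTrafficCheck(h, now)
	fmt.Println(in, out, isInactive(h, now, 10*time.Minute, true))

	// With no further traffic, a check past the timeout marks it droppable.
	later := now.Add(11 * time.Minute)
	fmt.Println(isInactive(h, later, 10*time.Minute, true))
}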
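A minimal sketch of how the patch derives the traffic-timer resolution: the wheel tick is the smallest of 500ms, timers.connection_alive_interval, and timers.pending_deletion_interval, the wheel span is the larger of the two intervals, and Start then ticks at the wheel's own resolution instead of a hard-coded 500ms. Assumes Go 1.21+ for the builtin min and max; the variable names are illustrative.

// Sketch of the tick/span selection for the locking timer wheel.
package main

import (
	"fmt"
	"time"
)

func main() {
	checkInterval := 5 * time.Second            // timers.connection_alive_interval default
	pendingDeletionInterval := 10 * time.Second // timers.pending_deletion_interval default

	tick := min(500*time.Millisecond, checkInterval, pendingDeletionInterval)
	span := max(checkInterval, pendingDeletionInterval)

	// The connection manager's clock source would then fire every `tick`,
	// which keeps the configured intervals accurate to within half a second.
	fmt.Println(tick, span) // 500ms 10s
}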