diff --git a/connection_manager.go b/connection_manager.go
index 5c9b3a5..0e5287c 100644
--- a/connection_manager.go
+++ b/connection_manager.go
@@ -11,6 +11,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
 
@@ -26,6 +27,12 @@ const (
 	sendTestPacket trafficDecision = 6
 )
 
+// LastCommunication tracks when we last communicated with a host
+type LastCommunication struct {
+	timestamp time.Time
+	vpnIp     netip.Addr // To help with logging
+}
+
 type connectionManager struct {
 	in     map[uint32]struct{}
 	inLock *sync.RWMutex
@@ -37,6 +44,12 @@ type connectionManager struct {
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex
 
+	// Track last communication with hosts
+	lastCommMap       map[uint32]*LastCommunication
+	lastCommLock      *sync.RWMutex
+	inactivityTimer   *LockingTimerWheel[uint32]
+	inactivityTimeout time.Duration
+
 	hostMap      *HostMap
 	trafficTimer *LockingTimerWheel[uint32]
 	intf         *Interface
@@ -65,6 +78,9 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		outLock:       &sync.RWMutex{},
 		relayUsed:     make(map[uint32]struct{}),
 		relayUsedLock: &sync.RWMutex{},
+		lastCommMap:       make(map[uint32]*LastCommunication),
+		lastCommLock:      &sync.RWMutex{},
+		inactivityTimeout: 10 * time.Minute, // Default inactivity timeout: 10 minutes
 		trafficTimer:  NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:          intf,
 		pendingDeletion: make(map[uint32]struct{}),
@@ -75,10 +91,42 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		l:               l,
 	}
 
+	// Initialize the inactivity timer wheel - make wheel duration slightly longer than the timeout
+	nc.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, nc.inactivityTimeout+time.Minute)
+
 	nc.Start(ctx)
 	return nc
 }
 
+func (n *connectionManager) updateLastCommunication(localIndex uint32) {
+	// Get host info to record VPN IP for better logging
+	hostInfo := n.hostMap.QueryIndex(localIndex)
+	if hostInfo == nil {
+		return
+	}
+
+	now := time.Now()
+	n.lastCommLock.Lock()
+	lastComm, exists := n.lastCommMap[localIndex]
+	if !exists {
+		// First time we've seen this host
+		lastComm = &LastCommunication{
+			timestamp: now,
+			vpnIp:     hostInfo.vpnIp,
+		}
+		n.lastCommMap[localIndex] = lastComm
+	} else {
+		// Update existing record
+		lastComm.timestamp = now
+	}
+	n.lastCommLock.Unlock()
+
+	// Reset the inactivity timer for this host
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Add(localIndex, n.inactivityTimeout)
+	n.inactivityTimer.m.Unlock()
+}
+
 func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.RLock()
 	// If this already exists, return
@@ -90,6 +138,9 @@ func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.Lock()
 	n.in[localIndex] = struct{}{}
 	n.inLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }
 
 func (n *connectionManager) Out(localIndex uint32) {
@@ -103,6 +154,9 @@ func (n *connectionManager) Out(localIndex uint32) {
 	n.outLock.Lock()
 	n.out[localIndex] = struct{}{}
 	n.outLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }
 
 func (n *connectionManager) RelayUsed(localIndex uint32) {
@@ -144,6 +198,134 @@ func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
 	n.outLock.Unlock()
 }
 
+// checkInactiveTunnels checks for tunnels that have been inactive for too long and drops them
+func (n *connectionManager) checkInactiveTunnels() {
+	now := time.Now()
+
+	// First, advance the timer wheel to the current time
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Advance(now)
+	n.inactivityTimer.m.Unlock()
+
+	// Check for expired timers (inactive connections)
+	for {
+		// Get the next expired tunnel
+		n.inactivityTimer.m.Lock()
+		localIndex, ok := n.inactivityTimer.t.Purge()
+		n.inactivityTimer.m.Unlock()
+
+		if !ok {
+			// No more expired timers
+			break
+		}
+
+		n.lastCommLock.RLock()
+		lastComm, exists := n.lastCommMap[localIndex]
+		n.lastCommLock.RUnlock()
+
+		if !exists {
+			// No last communication record, odd but skip
+			continue
+		}
+
+		// Calculate inactivity duration
+		inactiveDuration := now.Sub(lastComm.timestamp)
+
+		// Check if we've exceeded the inactivity timeout
+		if inactiveDuration >= n.inactivityTimeout {
+			// Get the host info (if it still exists)
+			hostInfo := n.hostMap.QueryIndex(localIndex)
+			if hostInfo == nil {
+				// Host info is gone, remove from our tracking map
+				n.lastCommLock.Lock()
+				delete(n.lastCommMap, localIndex)
+				n.lastCommLock.Unlock()
+				continue
+			}
+
+			// Log the inactivity and drop the tunnel
+			n.l.WithField("vpnIp", lastComm.vpnIp).
+				WithField("localIndex", localIndex).
+				WithField("inactiveDuration", inactiveDuration).
+				WithField("timeout", n.inactivityTimeout).
+				Info("Dropping tunnel due to inactivity")
+
+			// Close the tunnel using the existing mechanism
+			n.intf.closeTunnel(hostInfo)
+
+			// Clean up our tracking map
+			n.lastCommLock.Lock()
+			delete(n.lastCommMap, localIndex)
+			n.lastCommLock.Unlock()
+		} else {
+			// Re-add to the timer wheel with the remaining time
+			remainingTime := n.inactivityTimeout - inactiveDuration
+			n.inactivityTimer.m.Lock()
+			n.inactivityTimer.t.Add(localIndex, remainingTime)
+			n.inactivityTimer.m.Unlock()
+		}
+	}
+}
+
+// CleanupDeletedHostInfos removes entries from our lastCommMap for hosts that no longer exist
+func (n *connectionManager) CleanupDeletedHostInfos() {
+	n.lastCommLock.Lock()
+	defer n.lastCommLock.Unlock()
+
+	// Find indexes to delete
+	var toDelete []uint32
+	for localIndex := range n.lastCommMap {
+		if n.hostMap.QueryIndex(localIndex) == nil {
+			toDelete = append(toDelete, localIndex)
+		}
+	}
+
+	// Delete them
+	for _, localIndex := range toDelete {
+		delete(n.lastCommMap, localIndex)
+	}
+
+	if len(toDelete) > 0 && n.l.Level >= logrus.DebugLevel {
+		n.l.WithField("count", len(toDelete)).Debug("Cleaned up deleted host entries from lastCommMap")
+	}
+}
+
+// ReloadConfig updates the connection manager configuration
+func (n *connectionManager) ReloadConfig(c *config.C) {
+	// Get the inactivity timeout from config
+	inactivityTimeout := c.GetDuration("timers.inactivity_timeout", 10*time.Minute)
+
+	// Only update if different
+	if inactivityTimeout != n.inactivityTimeout {
+		n.l.WithField("old", n.inactivityTimeout).
+			WithField("new", inactivityTimeout).
+			Info("Updating inactivity timeout")
+
+		n.inactivityTimeout = inactivityTimeout
+
+		// Recreate the inactivity timer wheel with the new timeout
+		n.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, n.inactivityTimeout+time.Minute)
+
+		// Re-add all existing hosts to the new timer wheel
+		n.lastCommLock.RLock()
+		for localIndex, lastComm := range n.lastCommMap {
+			// Calculate remaining time based on last communication
+			now := time.Now()
+			elapsed := now.Sub(lastComm.timestamp)
+
+			// If the elapsed time exceeds the new timeout, this will be caught
+			// in the next inactivity check. Otherwise, add with remaining time.
+			if elapsed < n.inactivityTimeout {
+				remainingTime := n.inactivityTimeout - elapsed
+				n.inactivityTimer.m.Lock()
+				n.inactivityTimer.t.Add(localIndex, remainingTime)
+				n.inactivityTimer.m.Unlock()
+			}
+		}
+		n.lastCommLock.RUnlock()
+	}
+}
+
 func (n *connectionManager) Start(ctx context.Context) {
 	go n.Run(ctx)
 }
@@ -153,6 +335,14 @@ func (n *connectionManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(500 * time.Millisecond)
 	defer clockSource.Stop()
 
+	// Create ticker for inactivity checks (every minute)
+	inactivityTicker := time.NewTicker(time.Minute)
+	defer inactivityTicker.Stop()
+
+	// Create ticker for cleanup (every 5 minutes)
+	cleanupTicker := time.NewTicker(5 * time.Minute)
+	defer cleanupTicker.Stop()
+
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -172,6 +362,14 @@ func (n *connectionManager) Run(ctx context.Context) {
 
 				n.doTrafficCheck(localIndex, p, nb, out, now)
 			}
+
+		case <-inactivityTicker.C:
+			// Check for inactive tunnels
+			n.checkInactiveTunnels()
+
+		case <-cleanupTicker.C:
+			// Periodically clean up deleted hosts
+			n.CleanupDeletedHostInfos()
 		}
 	}
 }
@@ -367,7 +565,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 
 	if !outTraffic {
 		// Send a punch packet to keep the NAT state alive
-		n.sendPunch(hostinfo)
+		//n.sendPunch(hostinfo)
 	}
 
 	return decision, hostinfo, primary
@@ -388,7 +586,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 	if !outTraffic {
 		// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
 		// Just maintain NAT state if configured to do so.
-		n.sendPunch(hostinfo)
+		//n.sendPunch(hostinfo)
 
 		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
 		return doNothing, nil, nil
@@ -398,7 +596,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		// This is similar to the old punchy behavior with a slight optimization.
 		// We aren't receiving traffic but we are sending it, punch on all known
 		// ips in case we need to re-prime NAT state
-		n.sendPunch(hostinfo)
+		//n.sendPunch(hostinfo)
 	}
 
 	if n.l.Level >= logrus.DebugLevel {
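
The core of checkInactiveTunnels is a lazy-expiry pattern: the wheel only says a timer is due, and the handler then compares against the recorded last-communication time, dropping the tunnel when the idle period really exceeds the timeout and otherwise re-arming the timer with just the remaining time. Below is a minimal, self-contained sketch of that pattern; the miniWheel type and the sample hosts are made up for illustration and stand in for Nebula's LockingTimerWheel, they are not its real API.

```go
// Sketch of the expiry pattern used by checkInactiveTunnels: a timer may fire
// even though the peer spoke recently, so on expiry we consult the recorded
// last-communication time and either drop the entry or re-arm the timer.
package main

import (
	"fmt"
	"sync"
	"time"
)

// miniWheel is a toy stand-in for the timer wheel; it only needs Add, Advance
// and Purge to demonstrate the pattern.
type miniWheel struct {
	mu      sync.Mutex
	fireAt  map[uint32]time.Time
	now     time.Time
	expired []uint32
}

func newMiniWheel() *miniWheel {
	return &miniWheel{fireAt: make(map[uint32]time.Time), now: time.Now()}
}

func (w *miniWheel) Add(i uint32, d time.Duration) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.fireAt[i] = w.now.Add(d)
}

// Advance moves the wheel's notion of "now" forward and collects due items.
func (w *miniWheel) Advance(now time.Time) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.now = now
	for i, t := range w.fireAt {
		if !t.After(now) {
			w.expired = append(w.expired, i)
			delete(w.fireAt, i)
		}
	}
}

// Purge pops one due item, mirroring the for { ... Purge() ... } loop in the diff.
func (w *miniWheel) Purge() (uint32, bool) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if len(w.expired) == 0 {
		return 0, false
	}
	i := w.expired[0]
	w.expired = w.expired[1:]
	return i, true
}

func main() {
	const timeout = 2 * time.Second
	now := time.Now()

	// Made-up last-communication times: host 1 is idle past the timeout,
	// host 2 spoke recently.
	lastComm := map[uint32]time.Time{
		1: now.Add(-3 * time.Second),
		2: now.Add(-1 * time.Second),
	}

	w := newMiniWheel()
	w.Add(1, 0) // both timers are already due
	w.Add(2, 0)

	w.Advance(now)
	for {
		idx, ok := w.Purge()
		if !ok {
			break
		}
		idle := now.Sub(lastComm[idx])
		if idle >= timeout {
			fmt.Printf("host %d: idle %s, dropping tunnel\n", idx, idle)
			delete(lastComm, idx)
			continue
		}
		// Activity happened since the timer was armed: re-arm with the remaining time.
		w.Add(idx, timeout-idle)
		fmt.Printf("host %d: idle %s, re-armed for %s\n", idx, idle, timeout-idle)
	}
}
```

The re-arm branch matters because a wheel entry can fire even though the host has spoken since it was added; the recorded timestamp, not the timer itself, is the source of truth for whether a tunnel is actually idle.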
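ReloadConfig reads the timeout from the timers.inactivity_timeout key with a 10 minute default. A hedged, test-style sketch of how that reload path could be exercised follows; it assumes a _test.go file inside the nebula package next to connection_manager.go, and newTestConnectionManager is a hypothetical helper that builds a *connectionManager for tests. Only config.NewC, (*config.C).LoadString, GetDuration and ReloadConfig come from the existing code and this diff.

```go
// Hypothetical test sketch; assumes the nebula package, where the imports
// "testing", "time", "github.com/sirupsen/logrus" and
// "github.com/slackhq/nebula/config" are already available.
func TestInactivityTimeoutReload(t *testing.T) {
	l := logrus.New()
	c := config.NewC(l)
	if err := c.LoadString("timers:\n  inactivity_timeout: 30m\n"); err != nil {
		t.Fatal(err)
	}

	// newTestConnectionManager is a made-up helper standing in for whatever
	// fixture the existing connection manager tests use.
	cm := newTestConnectionManager(t)
	cm.ReloadConfig(c)

	if cm.inactivityTimeout != 30*time.Minute {
		t.Fatalf("expected inactivity timeout of 30m, got %s", cm.inactivityTimeout)
	}
}
```

The same YAML shape is what an operator would set in the node config: a timers block containing inactivity_timeout with any Go duration string; when the key is absent, the 10 minute default passed to GetDuration applies.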