Use connection manager to drive NAT maintenance (#835)

Co-authored-by: brad-defined <77982333+brad-defined@users.noreply.github.com>
This commit is contained in:
Nate Brown
2023-03-31 15:45:05 -05:00
committed by GitHub
parent 1a6c657451
commit ee8e1348e9
9 changed files with 233 additions and 333 deletions

View File

@@ -5,49 +5,55 @@ import (
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp"
)
// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
// and something like every 10 packets we could lock, send 10, then unlock for a moment
type connectionManager struct {
hostMap *HostMap
in map[uint32]struct{}
inLock *sync.RWMutex
out map[uint32]struct{}
outLock *sync.RWMutex
TrafficTimer *LockingTimerWheel[uint32]
intf *Interface
in map[uint32]struct{}
inLock *sync.RWMutex
pendingDeletion map[uint32]int
pendingDeletionLock *sync.RWMutex
pendingDeletionTimer *LockingTimerWheel[uint32]
out map[uint32]struct{}
outLock *sync.RWMutex
checkInterval int
pendingDeletionInterval int
hostMap *HostMap
trafficTimer *LockingTimerWheel[uint32]
intf *Interface
pendingDeletion map[uint32]struct{}
punchy *Punchy
checkInterval time.Duration
pendingDeletionInterval time.Duration
metricsTxPunchy metrics.Counter
l *logrus.Logger
// I wanted to call one matLock
}
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
var max time.Duration
if checkInterval < pendingDeletionInterval {
max = pendingDeletionInterval
} else {
max = checkInterval
}
nc := &connectionManager{
hostMap: intf.hostMap,
in: make(map[uint32]struct{}),
inLock: &sync.RWMutex{},
out: make(map[uint32]struct{}),
outLock: &sync.RWMutex{},
TrafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
trafficTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, max),
intf: intf,
pendingDeletion: make(map[uint32]int),
pendingDeletionLock: &sync.RWMutex{},
pendingDeletionTimer: NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
pendingDeletion: make(map[uint32]struct{}),
checkInterval: checkInterval,
pendingDeletionInterval: pendingDeletionInterval,
punchy: punchy,
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
l: l,
}
nc.Start(ctx)
return nc
}
@@ -74,65 +80,27 @@ func (n *connectionManager) Out(localIndex uint32) {
}
n.outLock.RUnlock()
n.outLock.Lock()
// double check since we dropped the lock temporarily
if _, ok := n.out[localIndex]; ok {
n.outLock.Unlock()
return
}
n.out[localIndex] = struct{}{}
n.AddTrafficWatch(localIndex, n.checkInterval)
n.outLock.Unlock()
}
func (n *connectionManager) CheckIn(localIndex uint32) bool {
n.inLock.RLock()
if _, ok := n.in[localIndex]; ok {
n.inLock.RUnlock()
return true
}
n.inLock.RUnlock()
return false
}
func (n *connectionManager) ClearLocalIndex(localIndex uint32) {
// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
// resets the state for this local index
func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
n.inLock.Lock()
n.outLock.Lock()
_, in := n.in[localIndex]
_, out := n.out[localIndex]
delete(n.in, localIndex)
delete(n.out, localIndex)
n.inLock.Unlock()
n.outLock.Unlock()
return in, out
}
func (n *connectionManager) ClearPendingDeletion(localIndex uint32) {
n.pendingDeletionLock.Lock()
delete(n.pendingDeletion, localIndex)
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) AddPendingDeletion(localIndex uint32) {
n.pendingDeletionLock.Lock()
if _, ok := n.pendingDeletion[localIndex]; ok {
n.pendingDeletion[localIndex] += 1
} else {
n.pendingDeletion[localIndex] = 0
}
n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval))
n.pendingDeletionLock.Unlock()
}
func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool {
n.pendingDeletionLock.RLock()
if _, ok := n.pendingDeletion[localIndex]; ok {
n.pendingDeletionLock.RUnlock()
return true
}
n.pendingDeletionLock.RUnlock()
return false
}
func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) {
n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds))
func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
n.Out(localIndex)
n.trafficTimer.Add(localIndex, n.checkInterval)
}
func (n *connectionManager) Start(ctx context.Context) {
@@ -140,6 +108,7 @@ func (n *connectionManager) Start(ctx context.Context) {
}
func (n *connectionManager) Run(ctx context.Context) {
//TODO: this tick should be based on the min wheel tick? Check firewall
clockSource := time.NewTicker(500 * time.Millisecond)
defer clockSource.Stop()
@@ -151,151 +120,106 @@ func (n *connectionManager) Run(ctx context.Context) {
select {
case <-ctx.Done():
return
case now := <-clockSource.C:
n.HandleMonitorTick(now, p, nb, out)
n.HandleDeletionTick(now)
}
}
}
func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
n.TrafficTimer.Advance(now)
for {
localIndex, has := n.TrafficTimer.Purge()
if !has {
break
}
// Check for traffic coming back in from this host.
traf := n.CheckIn(localIndex)
hostinfo, err := n.hostMap.QueryIndex(localIndex)
if err != nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
continue
}
if n.handleInvalidCertificate(now, hostinfo) {
continue
}
// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
// decide if this should be the main hostinfo if we are seeing traffic on it
primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// If we saw an incoming packets from this ip and peer's certificate is not
// expired, just ignore.
if traf {
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
if !mainHostInfo {
if hostinfo.vpnIp > n.intf.myVpnIp {
// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
// This the primary and prime the old primary hostinfo for testing
n.hostMap.MakePrimary(hostinfo)
n.Out(primary.localIndexId)
} else {
// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
// Keep tracking so that we can tear it down when it goes away
n.Out(hostinfo.localIndexId)
n.trafficTimer.Advance(now)
for {
localIndex, has := n.trafficTimer.Purge()
if !has {
break
}
n.doTrafficCheck(localIndex, p, nb, out, now)
}
continue
}
if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
// Don't probe lighthouses since recv_error should naturally catch this.
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
continue
}
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
} else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
n.AddPendingDeletion(localIndex)
}
}
func (n *connectionManager) HandleDeletionTick(now time.Time) {
n.pendingDeletionTimer.Advance(now)
for {
localIndex, has := n.pendingDeletionTimer.Purge()
if !has {
break
}
hostinfo, mainHostInfo, err := n.hostMap.QueryIndexIsPrimary(localIndex)
if err != nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
continue
}
if n.handleInvalidCertificate(now, hostinfo) {
continue
}
// If we saw an incoming packets from this ip and peer's certificate is not
// expired, just ignore.
traf := n.CheckIn(localIndex)
if traf {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
Debug("Tunnel status")
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
if !mainHostInfo {
// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
// Keep tracking so that we can tear it down when it goes away
n.Out(localIndex)
}
continue
}
// If it comes around on deletion wheel and hasn't resolved itself, delete
if n.checkPendingDeletion(localIndex) {
cn := ""
if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
cn = hostinfo.ConnectionState.peerCert.Details.Name
}
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
WithField("certName", cn).
Info("Tunnel status")
n.hostMap.DeleteHostInfo(hostinfo)
}
n.ClearLocalIndex(localIndex)
n.ClearPendingDeletion(localIndex)
func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
hostinfo, err := n.hostMap.QueryIndex(localIndex)
if err != nil {
n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
delete(n.pendingDeletion, localIndex)
return
}
if n.handleInvalidCertificate(now, hostinfo) {
return
}
primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
mainHostInfo := true
if primary != nil && primary != hostinfo {
mainHostInfo = false
}
// Check for traffic on this hostinfo
inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
// A hostinfo is determined alive if there is incoming traffic
if inTraffic {
if n.l.Level >= logrus.DebugLevel {
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
Debug("Tunnel status")
}
delete(n.pendingDeletion, hostinfo.localIndexId)
if !mainHostInfo {
if hostinfo.vpnIp > n.intf.myVpnIp {
// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
// This the primary and prime the old primary hostinfo for testing
n.hostMap.MakePrimary(hostinfo)
}
}
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
if !outTraffic {
// Send a punch packet to keep the NAT state alive
n.sendPunch(hostinfo)
}
return
}
if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
return
}
if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
// We have already sent a test packet and nothing was returned, this hostinfo is dead
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
Info("Tunnel status")
n.hostMap.DeleteHostInfo(hostinfo)
delete(n.pendingDeletion, hostinfo.localIndexId)
return
}
hostinfo.logger(n.l).
WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
Debug("Tunnel status")
if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
if n.punchy.GetTargetEverything() {
// Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all
// known remotes, go ahead and do that AND send a test packet
n.sendPunch(hostinfo)
}
// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
} else {
hostinfo.logger(n.l).Debugf("Hostinfo sadness")
}
n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
}
// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
@@ -322,8 +246,24 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
// Inform the remote and close the tunnel locally
n.intf.sendCloseTunnel(hostinfo)
n.intf.closeTunnel(hostinfo)
n.ClearLocalIndex(hostinfo.localIndexId)
n.ClearPendingDeletion(hostinfo.localIndexId)
delete(n.pendingDeletion, hostinfo.localIndexId)
return true
}
func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
if !n.punchy.GetPunch() {
// Punching is disabled
return
}
if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
} else if hostinfo.remote != nil {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
}
}