mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-15 20:37:36 +02:00
Stop leaking goroutines past Control.Stop, consolidate punching in Punchy (#1708)
This commit is contained in:
@@ -11,7 +11,6 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rcrowley/go-metrics"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -45,8 +44,6 @@ type connectionManager struct {
|
|||||||
inactivityTimeout atomic.Int64
|
inactivityTimeout atomic.Int64
|
||||||
dropInactive atomic.Bool
|
dropInactive atomic.Bool
|
||||||
|
|
||||||
metricsTxPunchy metrics.Counter
|
|
||||||
|
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +54,6 @@ func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p
|
|||||||
punchy: p,
|
punchy: p,
|
||||||
relayUsed: make(map[uint32]struct{}),
|
relayUsed: make(map[uint32]struct{}),
|
||||||
relayUsedLock: &sync.RWMutex{},
|
relayUsedLock: &sync.RWMutex{},
|
||||||
metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cm.reload(c, true)
|
cm.reload(c, true)
|
||||||
@@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
if !outTraffic {
|
if !outTraffic {
|
||||||
// Send a punch packet to keep the NAT state alive
|
// Send a punch packet to keep the NAT state alive
|
||||||
cm.sendPunch(hostinfo)
|
cm.punchy.SendPunch(hostinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
return decision, hostinfo, primary
|
return decision, hostinfo, primary
|
||||||
@@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
|
|||||||
|
|
||||||
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
|
||||||
// Just maintain NAT state if configured to do so.
|
// Just maintain NAT state if configured to do so.
|
||||||
cm.sendPunch(hostinfo)
|
cm.punchy.SendPunch(hostinfo)
|
||||||
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
|
||||||
return doNothing, nil, nil
|
return doNothing, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm.punchy.GetTargetEverything() {
|
// We aren't receiving traffic but we are sending it. The outbound
|
||||||
// This is similar to the old punchy behavior with a slight optimization.
|
// traffic itself refreshes the primary remote's NAT state; this
|
||||||
// We aren't receiving traffic but we are sending it, punch on all known
|
// fans out to non-primary remotes, but only if target_all_remotes
|
||||||
// ips in case we need to re-prime NAT state
|
// is configured.
|
||||||
cm.sendPunch(hostinfo)
|
cm.punchy.SendPunchToAll(hostinfo)
|
||||||
}
|
|
||||||
|
|
||||||
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
if cm.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
hostinfo.logger(cm.l).Debug("Tunnel status",
|
hostinfo.logger(cm.l).Debug("Tunnel status",
|
||||||
@@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
|
|
||||||
if !cm.punchy.GetPunch() {
|
|
||||||
// Punching is disabled
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
|
||||||
// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
|
|
||||||
// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
|
|
||||||
// would lose the ability to notify us and punchy.respond would become unreliable.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if cm.punchy.GetTargetEverything() {
|
|
||||||
hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
|
||||||
cm.metricsTxPunchy.Inc(1)
|
|
||||||
cm.intf.outside.WriteTo([]byte{1}, addr)
|
|
||||||
})
|
|
||||||
|
|
||||||
} else if hostinfo.remote.IsValid() {
|
|
||||||
cm.metricsTxPunchy.Inc(1)
|
|
||||||
cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
|
||||||
cs := cm.intf.pki.getCertState()
|
cs := cm.intf.pki.getCertState()
|
||||||
curCrt := hostinfo.ConnectionState.myCert
|
curCrt := hostinfo.ConnectionState.myCert
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
|
|||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(test.NewLogger())
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
@@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
|
|||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(test.NewLogger())
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
p := []byte("")
|
p := []byte("")
|
||||||
@@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
|
|||||||
conf.Settings["tunnels"] = map[string]any{
|
conf.Settings["tunnels"] = map[string]any{
|
||||||
"drop_inactive": true,
|
"drop_inactive": true,
|
||||||
}
|
}
|
||||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
assert.True(t, nc.dropInactive.Load())
|
assert.True(t, nc.dropInactive.Load())
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
@@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
|
|||||||
|
|
||||||
// Create manager
|
// Create manager
|
||||||
conf := config.NewC(test.NewLogger())
|
conf := config.NewC(test.NewLogger())
|
||||||
punchy := NewPunchyFromConfig(test.NewLogger(), conf)
|
punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil)
|
||||||
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy)
|
||||||
nc.intf = ifce
|
nc.intf = ifce
|
||||||
ifce.connectionManager = nc
|
ifce.connectionManager = nc
|
||||||
|
|||||||
@@ -18,14 +18,10 @@ import (
|
|||||||
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
|
// retry mechanism gives the wg.Wait()-driven goroutines a moment to drain
|
||||||
// before failing the assertion.
|
// before failing the assertion.
|
||||||
//
|
//
|
||||||
// IgnoreCurrent is necessary in the parallelized suite: other tests can
|
// Intentionally NOT t.Parallel()'d: concurrent tests would have their own
|
||||||
// leave goroutines mid-shutdown when this one runs (Stop is async, the
|
// goroutines running and trip the assertion.
|
||||||
// wg.Wait() drain is not blocking on test return). We're checking that
|
|
||||||
// *this* test's setup tears down cleanly, not that the whole suite is
|
|
||||||
// idle at this moment. Intentionally NOT t.Parallel()'d for the same
|
|
||||||
// reason — concurrent test goroutines would always show up.
|
|
||||||
func TestNoGoroutineLeaks(t *testing.T) {
|
func TestNoGoroutineLeaks(t *testing.T) {
|
||||||
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
|
defer goleak.VerifyNone(t)
|
||||||
|
|
||||||
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
|
||||||
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
|
||||||
|
|||||||
@@ -163,17 +163,21 @@ listen:
|
|||||||
|
|
||||||
punchy:
|
punchy:
|
||||||
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
|
# Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings
|
||||||
|
# This setting is reloadable.
|
||||||
punch: true
|
punch: true
|
||||||
|
|
||||||
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
|
# respond means that a node you are trying to reach will connect back out to you if your hole punching fails
|
||||||
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
|
# this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT
|
||||||
# Default is false
|
# Default is false
|
||||||
|
# This setting is reloadable.
|
||||||
#respond: true
|
#respond: true
|
||||||
|
|
||||||
# delays a punch response for misbehaving NATs, default is 1 second.
|
# delays a punch response for misbehaving NATs, default is 1 second.
|
||||||
|
# This setting is reloadable.
|
||||||
#delay: 1s
|
#delay: 1s
|
||||||
|
|
||||||
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
|
# set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect.
|
||||||
|
# This setting is reloadable.
|
||||||
#respond_delay: 5s
|
#respond_delay: 5s
|
||||||
|
|
||||||
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
|
# Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/rcrowley/go-metrics"
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/slackhq/nebula/header"
|
"github.com/slackhq/nebula/header"
|
||||||
@@ -35,7 +34,6 @@ type LightHouse struct {
|
|||||||
|
|
||||||
myVpnNetworks []netip.Prefix
|
myVpnNetworks []netip.Prefix
|
||||||
myVpnNetworksTable *bart.Lite
|
myVpnNetworksTable *bart.Lite
|
||||||
punchConn udp.Conn
|
|
||||||
punchy *Punchy
|
punchy *Punchy
|
||||||
|
|
||||||
// Local cache of answers from light houses
|
// Local cache of answers from light houses
|
||||||
@@ -76,7 +74,6 @@ type LightHouse struct {
|
|||||||
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote
|
||||||
|
|
||||||
metrics *MessageMetrics
|
metrics *MessageMetrics
|
||||||
metricHolepunchTx metrics.Counter
|
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
|
|||||||
myVpnNetworksTable: cs.myVpnNetworksTable,
|
myVpnNetworksTable: cs.myVpnNetworksTable,
|
||||||
addrMap: make(map[netip.Addr]*RemoteList),
|
addrMap: make(map[netip.Addr]*RemoteList),
|
||||||
nebulaPort: nebulaPort,
|
nebulaPort: nebulaPort,
|
||||||
punchConn: pc,
|
|
||||||
punchy: p,
|
punchy: p,
|
||||||
updateTrigger: make(chan struct{}, 1),
|
updateTrigger: make(chan struct{}, 1),
|
||||||
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)),
|
||||||
@@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c
|
|||||||
|
|
||||||
if c.GetBool("stats.lighthouse_metrics", false) {
|
if c.GetBool("stats.lighthouse_metrics", false) {
|
||||||
h.metrics = newLighthouseMetrics()
|
h.metrics = newLighthouseMetrics()
|
||||||
h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
|
||||||
} else {
|
|
||||||
h.metricHolepunchTx = metrics.NilCounter{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.reload(c, true)
|
err := h.reload(c, true)
|
||||||
@@ -1406,58 +1399,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
empty := []byte{0}
|
|
||||||
punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) {
|
|
||||||
if !vpnPeer.IsValid() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(lhh.lh.punchy.GetDelay())
|
|
||||||
lhh.lh.metricHolepunchTx.Inc(1)
|
|
||||||
lhh.lh.punchConn.WriteTo(empty, vpnPeer)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
|
||||||
lhh.l.Debug("Punching",
|
|
||||||
"vpnPeer", vpnPeer,
|
|
||||||
"logVpnAddr", logVpnAddr,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
remoteAllowList := lhh.lh.GetRemoteAllowList()
|
||||||
for _, a := range n.Details.V4AddrPorts {
|
for _, a := range n.Details.V4AddrPorts {
|
||||||
b := protoV4AddrPortToNetAddrPort(a)
|
b := protoV4AddrPortToNetAddrPort(a)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
punch(b, detailsVpnAddr)
|
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, a := range n.Details.V6AddrPorts {
|
for _, a := range n.Details.V6AddrPorts {
|
||||||
b := protoV6AddrPortToNetAddrPort(a)
|
b := protoV6AddrPortToNetAddrPort(a)
|
||||||
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) {
|
||||||
punch(b, detailsVpnAddr)
|
lhh.lh.punchy.Schedule(b, detailsVpnAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This sends a nebula test packet to the host trying to contact us. In the case
|
// This sends a nebula test packet to the host trying to contact us. In the case
|
||||||
// of a double nat or other difficult scenario, this may help establish
|
// of a double nat or other difficult scenario, this may help establish
|
||||||
// a tunnel.
|
// a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled.
|
||||||
if lhh.lh.punchy.GetRespond() {
|
lhh.lh.punchy.ScheduleRespond(detailsVpnAddr)
|
||||||
go func() {
|
|
||||||
time.Sleep(lhh.lh.punchy.GetRespondDelay())
|
|
||||||
if lhh.l.Enabled(context.Background(), slog.LevelDebug) {
|
|
||||||
lhh.l.Debug("Sending a nebula test packet",
|
|
||||||
"vpnAddr", detailsVpnAddr,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
//NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine
|
|
||||||
// for each punchBack packet. We should move this into a timerwheel or a single goroutine
|
|
||||||
// managed by a channel.
|
|
||||||
w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func protoAddrToNetAddr(addr *Addr) netip.Addr {
|
func protoAddrToNetAddr(addr *Addr) netip.Addr {
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
|||||||
}
|
}
|
||||||
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes())
|
||||||
|
|
||||||
ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd"))
|
ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err)
|
||||||
}
|
}
|
||||||
@@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostMap := NewHostMapFromConfig(l, c)
|
hostMap := NewHostMapFromConfig(l, c)
|
||||||
punchy := NewPunchyFromConfig(l, c)
|
punchy := NewPunchyFromConfig(l, c, udpConns[0])
|
||||||
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
|
||||||
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -240,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
|
|||||||
|
|
||||||
handshakeManager.f = ifce
|
handshakeManager.f = ifce
|
||||||
go handshakeManager.Run(ctx)
|
go handshakeManager.Run(ctx)
|
||||||
|
|
||||||
|
punchy.Start(ctx, ifce, hostMap, lightHouse)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
|
stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest)
|
||||||
|
|||||||
191
punchy.go
191
punchy.go
@@ -1,24 +1,70 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/rcrowley/go-metrics"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
|
"github.com/slackhq/nebula/header"
|
||||||
|
"github.com/slackhq/nebula/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires.
|
||||||
|
const holepunchQueueSize = 64
|
||||||
|
|
||||||
|
// holepunchJob is one scheduled item delivered to the worker goroutine.
|
||||||
|
// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context.
|
||||||
|
// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback").
|
||||||
|
type holepunchJob struct {
|
||||||
|
target netip.AddrPort
|
||||||
|
vpnAddr netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// lighthouseChecker is the slice of LightHouse that Punchy actually needs.
|
||||||
|
// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse
|
||||||
|
// already holds a *Punchy, and the bidirectional pointer reference is awkward
|
||||||
|
// even within the same package). Tests can also substitute a fake.
|
||||||
|
type lighthouseChecker interface {
|
||||||
|
IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool
|
||||||
|
}
|
||||||
|
|
||||||
type Punchy struct {
|
type Punchy struct {
|
||||||
punch atomic.Bool
|
punch atomic.Bool
|
||||||
respond atomic.Bool
|
respond atomic.Bool
|
||||||
delay atomic.Int64
|
delay atomic.Int64
|
||||||
respondDelay atomic.Int64
|
respondDelay atomic.Int64
|
||||||
punchEverything atomic.Bool
|
punchEverything atomic.Bool
|
||||||
|
|
||||||
|
sched *Scheduler[holepunchJob]
|
||||||
|
punchConn udp.Conn
|
||||||
|
metricHolepunchTx metrics.Counter
|
||||||
|
metricPunchyTx metrics.Counter
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
ifce EncWriter
|
||||||
|
hm *HostMap
|
||||||
|
lh lighthouseChecker
|
||||||
|
|
||||||
l *slog.Logger
|
l *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy {
|
||||||
p := &Punchy{l: l}
|
p := &Punchy{
|
||||||
|
l: l,
|
||||||
|
punchConn: punchConn,
|
||||||
|
sched: NewScheduler[holepunchJob](holepunchQueueSize),
|
||||||
|
metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.GetBool("stats.lighthouse_metrics", false) {
|
||||||
|
p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil)
|
||||||
|
} else {
|
||||||
|
p.metricHolepunchTx = metrics.NilCounter{}
|
||||||
|
}
|
||||||
|
|
||||||
p.reload(c, true)
|
p.reload(c, true)
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
@@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) reload(c *config.C, initial bool) {
|
func (p *Punchy) reload(c *config.C, initial bool) {
|
||||||
if initial {
|
if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
|
||||||
var yes bool
|
var yes bool
|
||||||
if c.IsSet("punchy.punch") {
|
if c.IsSet("punchy.punch") {
|
||||||
yes = c.GetBool("punchy.punch", false)
|
yes = c.GetBool("punchy.punch", false)
|
||||||
@@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
yes = c.GetBool("punchy", false)
|
yes = c.GetBool("punchy", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.punch.Store(yes)
|
old := p.punch.Swap(yes)
|
||||||
if yes {
|
switch {
|
||||||
|
case initial && yes:
|
||||||
p.l.Info("punchy enabled")
|
p.l.Info("punchy enabled")
|
||||||
} else {
|
case initial:
|
||||||
p.l.Info("punchy disabled")
|
p.l.Info("punchy disabled")
|
||||||
|
case old != yes:
|
||||||
|
p.l.Info("punchy.punch changed", "punch", yes)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
|
|
||||||
//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
|
|
||||||
p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
|
if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") {
|
||||||
@@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) {
|
|||||||
yes = c.GetBool("punch_back", false)
|
yes = c.GetBool("punch_back", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.respond.Store(yes)
|
old := p.respond.Swap(yes)
|
||||||
|
if !initial && old != yes {
|
||||||
if !initial {
|
p.l.Info("punchy.respond changed", "respond", yes)
|
||||||
p.l.Info("punchy.respond changed", "respond", p.GetRespond())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//NOTE: this will not apply to any in progress operations, only the next one
|
//NOTE: this will not apply to any in progress operations, only the next one
|
||||||
if initial || c.HasChanged("punchy.delay") {
|
if initial || c.HasChanged("punchy.delay") {
|
||||||
p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second)))
|
newDelay := int64(c.GetDuration("punchy.delay", time.Second))
|
||||||
if !initial {
|
old := p.delay.Swap(newDelay)
|
||||||
p.l.Info("punchy.delay changed", "delay", p.GetDelay())
|
if !initial && old != newDelay {
|
||||||
|
p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.target_all_remotes") {
|
if initial || c.HasChanged("punchy.target_all_remotes") {
|
||||||
p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
|
yes := c.GetBool("punchy.target_all_remotes", false)
|
||||||
if !initial {
|
old := p.punchEverything.Swap(yes)
|
||||||
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything())
|
if !initial && old != yes {
|
||||||
|
p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initial || c.HasChanged("punchy.respond_delay") {
|
if initial || c.HasChanged("punchy.respond_delay") {
|
||||||
p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
|
newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second))
|
||||||
if !initial {
|
old := p.respondDelay.Swap(newDelay)
|
||||||
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay())
|
if !initial && old != newDelay {
|
||||||
|
p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) GetPunch() bool {
|
// Schedule queues a punch packet to target, to be sent after the configured delay.
|
||||||
return p.punch.Load()
|
// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires.
|
||||||
|
// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine.
|
||||||
|
func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) {
|
||||||
|
if !target.IsValid() || p.ctx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) GetRespond() bool {
|
// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay,
|
||||||
return p.respond.Load()
|
// gated on punchy.respond. No-op when respond is disabled or before Start has been called.
|
||||||
|
func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) {
|
||||||
|
if !p.respond.Load() || p.ctx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) GetDelay() time.Duration {
|
// scheduleJob delegates to the pooled Scheduler.
|
||||||
return (time.Duration)(p.delay.Load())
|
// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued.
|
||||||
|
func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) {
|
||||||
|
p.sched.Schedule(p.ctx, job, delay)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) GetRespondDelay() time.Duration {
|
// SendPunch sends an immediate keepalive punch for an idle hostinfo.
|
||||||
return (time.Duration)(p.respondDelay.Load())
|
// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule
|
||||||
|
// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm).
|
||||||
|
func (p *Punchy) SendPunch(hostinfo *HostInfo) {
|
||||||
|
if !p.punch.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.punchEverything.Load() {
|
||||||
|
p.sendPunchToAllRemotes(hostinfo)
|
||||||
|
} else if hostinfo.remote.IsValid() {
|
||||||
|
p.metricPunchyTx.Inc(1)
|
||||||
|
p.punchConn.WriteTo([]byte{1}, hostinfo.remote)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Punchy) GetTargetEverything() bool {
|
// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled.
|
||||||
return p.punchEverything.Load()
|
// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's
|
||||||
|
// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant
|
||||||
|
// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule.
|
||||||
|
func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) {
|
||||||
|
if !p.punchEverything.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !p.punch.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.sendPunchToAllRemotes(hostinfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) {
|
||||||
|
hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
|
||||||
|
p.metricPunchyTx.Inc(1)
|
||||||
|
p.punchConn.WriteTo([]byte{1}, addr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start wires the runtime dependencies and spawns the scheduler worker.
|
||||||
|
func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) {
|
||||||
|
p.ctx = ctx
|
||||||
|
p.ifce = ifce
|
||||||
|
p.hm = hm
|
||||||
|
p.lh = lh
|
||||||
|
|
||||||
|
nb := make([]byte, 12, 12)
|
||||||
|
out := make([]byte, mtu)
|
||||||
|
empty := []byte{0}
|
||||||
|
|
||||||
|
go p.sched.Run(ctx, func(job holepunchJob) {
|
||||||
|
switch {
|
||||||
|
case job.target.IsValid():
|
||||||
|
if p.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr)
|
||||||
|
}
|
||||||
|
p.metricHolepunchTx.Inc(1)
|
||||||
|
p.punchConn.WriteTo(empty, job.target)
|
||||||
|
case job.vpnAddr.IsValid():
|
||||||
|
// A nebula test packet to the host trying to contact us.
|
||||||
|
// In the case of a double nat or other difficult scenario, this may help establish a tunnel.
|
||||||
|
if p.l.Enabled(context.Background(), slog.LevelDebug) {
|
||||||
|
p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr)
|
||||||
|
}
|
||||||
|
p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) {
|
|||||||
c := config.NewC(l)
|
c := config.NewC(l)
|
||||||
|
|
||||||
// Test defaults
|
// Test defaults
|
||||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.False(t, p.GetPunch())
|
assert.False(t, p.punch.Load())
|
||||||
assert.False(t, p.GetRespond())
|
assert.False(t, p.respond.Load())
|
||||||
assert.Equal(t, time.Second, p.GetDelay())
|
assert.Equal(t, time.Second, time.Duration(p.delay.Load()))
|
||||||
assert.Equal(t, 5*time.Second, p.GetRespondDelay())
|
assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load()))
|
||||||
|
|
||||||
// punchy deprecation
|
// punchy deprecation
|
||||||
c.Settings["punchy"] = true
|
c.Settings["punchy"] = true
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.punch.Load())
|
||||||
|
|
||||||
// punchy.punch
|
// punchy.punch
|
||||||
c.Settings["punchy"] = map[string]any{"punch": true}
|
c.Settings["punchy"] = map[string]any{"punch": true}
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.True(t, p.GetPunch())
|
assert.True(t, p.punch.Load())
|
||||||
|
|
||||||
// punch_back deprecation
|
// punch_back deprecation
|
||||||
c.Settings["punch_back"] = true
|
c.Settings["punch_back"] = true
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.respond.Load())
|
||||||
|
|
||||||
// punchy.respond
|
// punchy.respond
|
||||||
c.Settings["punchy"] = map[string]any{"respond": true}
|
c.Settings["punchy"] = map[string]any{"respond": true}
|
||||||
c.Settings["punch_back"] = false
|
c.Settings["punch_back"] = false
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.True(t, p.GetRespond())
|
assert.True(t, p.respond.Load())
|
||||||
|
|
||||||
// punchy.delay
|
// punchy.delay
|
||||||
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"delay": "1m"}
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.Equal(t, time.Minute, p.GetDelay())
|
assert.Equal(t, time.Minute, time.Duration(p.delay.Load()))
|
||||||
|
|
||||||
// punchy.respond_delay
|
// punchy.respond_delay
|
||||||
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
c.Settings["punchy"] = map[string]any{"respond_delay": "1m"}
|
||||||
p = NewPunchyFromConfig(test.NewLogger(), c)
|
p = NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.Equal(t, time.Minute, p.GetRespondDelay())
|
assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPunchy_reload(t *testing.T) {
|
func TestPunchy_reload(t *testing.T) {
|
||||||
@@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) {
|
|||||||
delay, _ := time.ParseDuration("1m")
|
delay, _ := time.ParseDuration("1m")
|
||||||
require.NoError(t, c.LoadString(`
|
require.NoError(t, c.LoadString(`
|
||||||
punchy:
|
punchy:
|
||||||
|
punch: false
|
||||||
delay: 1m
|
delay: 1m
|
||||||
respond: false
|
respond: false
|
||||||
`))
|
`))
|
||||||
p := NewPunchyFromConfig(test.NewLogger(), c)
|
p := NewPunchyFromConfig(test.NewLogger(), c, nil)
|
||||||
assert.Equal(t, delay, p.GetDelay())
|
assert.False(t, p.punch.Load())
|
||||||
assert.False(t, p.GetRespond())
|
assert.Equal(t, delay, time.Duration(p.delay.Load()))
|
||||||
|
assert.False(t, p.respond.Load())
|
||||||
|
|
||||||
newDelay, _ := time.ParseDuration("10m")
|
newDelay, _ := time.ParseDuration("10m")
|
||||||
require.NoError(t, c.ReloadConfigString(`
|
require.NoError(t, c.ReloadConfigString(`
|
||||||
punchy:
|
punchy:
|
||||||
|
punch: true
|
||||||
delay: 10m
|
delay: 10m
|
||||||
respond: true
|
respond: true
|
||||||
`))
|
`))
|
||||||
p.reload(c, false)
|
p.reload(c, false)
|
||||||
assert.Equal(t, newDelay, p.GetDelay())
|
assert.True(t, p.punch.Load())
|
||||||
assert.True(t, p.GetRespond())
|
assert.Equal(t, newDelay, time.Duration(p.delay.Load()))
|
||||||
|
assert.True(t, p.respond.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// The tests below pin the shape of each log line Punchy produces so changes
|
// The tests below pin the shape of each log line Punchy produces so changes
|
||||||
// cannot silently break whatever operators are grepping for. The assertions
|
// cannot silently break whatever operators are grepping for. The assertions
|
||||||
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
// are on the structured message + attrs (e.g. "punchy.respond changed" with
|
||||||
// a respond=true field) rather than a formatted string.
|
// a respond=true field) rather than a formatted string. Tests filter by
|
||||||
//
|
// message rather than asserting total entry counts so unrelated info lines
|
||||||
// Punchy.reload also emits a spurious "Changing punchy.punch with reload is
|
// are tolerated without being locked into the format.
|
||||||
// not supported" warning whenever any key under punchy changes, because of
|
|
||||||
// the c.HasChanged("punchy") fallback kept for the deprecated top-level
|
|
||||||
// punchy form. The tests filter by message rather than asserting total
|
|
||||||
// entry counts so that warning is tolerated without being locked into
|
|
||||||
// the format.
|
|
||||||
|
|
||||||
type capturedEntry struct {
|
type capturedEntry struct {
|
||||||
Level slog.Level
|
Level slog.Level
|
||||||
@@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) {
|
|||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
require.NoError(t, c.LoadString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
|
|
||||||
entry := findEntry(t, hook.entries, "punchy enabled")
|
entry := findEntry(t, hook.entries, "punchy enabled")
|
||||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
@@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) {
|
|||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
|
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
|
|
||||||
entry := findEntry(t, hook.entries, "punchy disabled")
|
entry := findEntry(t, hook.entries, "punchy disabled")
|
||||||
assert.Equal(t, slog.LevelInfo, entry.Level)
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
assert.Empty(t, entry.Attrs)
|
assert.Empty(t, entry.Attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) {
|
func TestPunchy_LogFormat_ReloadPunch(t *testing.T) {
|
||||||
l, hook := newCapturingPunchyLogger(t)
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
require.NoError(t, c.LoadString(`punchy: {punch: false}`))
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
hook.entries = nil
|
hook.entries = nil
|
||||||
|
|
||||||
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`))
|
||||||
|
|
||||||
entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.")
|
entry := findEntry(t, hook.entries, "punchy.punch changed")
|
||||||
assert.Equal(t, slog.LevelWarn, entry.Level)
|
assert.Equal(t, slog.LevelInfo, entry.Level)
|
||||||
assert.Empty(t, entry.Attrs)
|
assert.Equal(t, map[string]any{"punch": true}, entry.Attrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
func TestPunchy_LogFormat_ReloadRespond(t *testing.T) {
|
||||||
l, hook := newCapturingPunchyLogger(t)
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
require.NoError(t, c.LoadString(`punchy: {respond: false}`))
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
hook.entries = nil
|
hook.entries = nil
|
||||||
|
|
||||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`))
|
||||||
@@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) {
|
|||||||
l, hook := newCapturingPunchyLogger(t)
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
require.NoError(t, c.LoadString(`punchy: {delay: 1s}`))
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
hook.entries = nil
|
hook.entries = nil
|
||||||
|
|
||||||
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`))
|
||||||
@@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) {
|
|||||||
l, hook := newCapturingPunchyLogger(t)
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`))
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
hook.entries = nil
|
hook.entries = nil
|
||||||
|
|
||||||
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`))
|
||||||
@@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) {
|
|||||||
l, hook := newCapturingPunchyLogger(t)
|
l, hook := newCapturingPunchyLogger(t)
|
||||||
c := config.NewC(test.NewLogger())
|
c := config.NewC(test.NewLogger())
|
||||||
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`))
|
||||||
NewPunchyFromConfig(l, c)
|
NewPunchyFromConfig(l, c, nil)
|
||||||
hook.entries = nil
|
hook.entries = nil
|
||||||
|
|
||||||
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`))
|
||||||
|
|||||||
84
scheduler.go
Normal file
84
scheduler.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scheduler is an allocation-conscious dispatch primitive for delayed work.
|
||||||
|
// Pending items are handed to time.AfterFunc, and ready items land on a worker
|
||||||
|
// channel for centralized dispatch in fire-time order.
|
||||||
|
//
|
||||||
|
// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling
|
||||||
|
// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before
|
||||||
|
// delivering to the worker, which is fine at sparse rates but adds up at line rate.
|
||||||
|
//
|
||||||
|
// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache,
|
||||||
|
// and bucket-batched dispatch are cheaper at scale.
|
||||||
|
// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines.
|
||||||
|
type Scheduler[T any] struct {
|
||||||
|
queue chan T
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
type schedItem[T any] struct {
|
||||||
|
val T
|
||||||
|
ctx context.Context
|
||||||
|
s *Scheduler[T]
|
||||||
|
timer *time.Timer
|
||||||
|
fire func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScheduler builds a Scheduler whose worker channel is sized to queueSize.
|
||||||
|
// The buffer absorbs bursts of timers firing close together without
|
||||||
|
// blocking the runtime's callback goroutines on the worker.
|
||||||
|
func NewScheduler[T any](queueSize int) *Scheduler[T] {
|
||||||
|
s := &Scheduler[T]{
|
||||||
|
queue: make(chan T, queueSize),
|
||||||
|
}
|
||||||
|
s.pool.New = func() any {
|
||||||
|
si := &schedItem[T]{s: s}
|
||||||
|
// fire is allocated exactly once per pool-resident item.
|
||||||
|
// The closure captures only `si`, which stays stable for the item's lifetime.
|
||||||
|
si.fire = func() {
|
||||||
|
select {
|
||||||
|
case si.s.queue <- si.val:
|
||||||
|
case <-si.ctx.Done():
|
||||||
|
}
|
||||||
|
var zero T
|
||||||
|
si.val = zero
|
||||||
|
si.ctx = nil
|
||||||
|
si.s.pool.Put(si)
|
||||||
|
}
|
||||||
|
return si
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schedule arranges item to be delivered to the worker after delay.
|
||||||
|
// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle.
|
||||||
|
// The callback observes ctx: if ctx is cancelled before the timer fires, the item is dropped instead of queued.
|
||||||
|
func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) {
|
||||||
|
si := s.pool.Get().(*schedItem[T])
|
||||||
|
si.val = item
|
||||||
|
si.ctx = ctx
|
||||||
|
if si.timer == nil {
|
||||||
|
si.timer = time.AfterFunc(delay, si.fire)
|
||||||
|
} else {
|
||||||
|
si.timer.Reset(delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled.
|
||||||
|
// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run.
|
||||||
|
func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case item := <-s.queue:
|
||||||
|
fn(item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
79
scheduler_test.go
Normal file
79
scheduler_test.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package nebula
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScheduler_PooledReuse(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s := NewScheduler[int](16)
|
||||||
|
delivered := make(chan int, 256)
|
||||||
|
go s.Run(ctx, func(item int) { delivered <- item })
|
||||||
|
|
||||||
|
const N = 100
|
||||||
|
for i := 0; i < N; i++ {
|
||||||
|
s.Schedule(ctx, i, time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.After(2 * time.Second)
|
||||||
|
got := 0
|
||||||
|
for got < N {
|
||||||
|
select {
|
||||||
|
case <-delivered:
|
||||||
|
got++
|
||||||
|
case <-deadline:
|
||||||
|
t.Fatalf("only %d/%d items delivered", got, N)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkScheduler_Schedule reports allocations per Schedule call.
|
||||||
|
// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up.
|
||||||
|
func BenchmarkScheduler_Schedule(b *testing.B) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
s := NewScheduler[int](b.N)
|
||||||
|
go s.Run(ctx, func(int) {})
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
s.Schedule(ctx, i, time.Microsecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkBareAfterFunc is the comparison baseline.
|
||||||
|
// What we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler.
|
||||||
|
// Allocates a *time.Timer plus a closure each call.
|
||||||
|
func BenchmarkBareAfterFunc(b *testing.B) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
queue := make(chan int, b.N)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-queue:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
i := i
|
||||||
|
time.AfterFunc(time.Microsecond, func() {
|
||||||
|
select {
|
||||||
|
case queue <- i:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,10 +32,12 @@ type SSHServer struct {
|
|||||||
cancel func()
|
cancel func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
|
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen.
|
||||||
func NewSSHServer(l *slog.Logger) (*SSHServer, error) {
|
// The ssh server's context is parented off the supplied ctx so cancelling it
|
||||||
|
// (e.g. on Control.Stop) tears down active sessions and closes the listener.
|
||||||
|
func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
s := &SSHServer{
|
s := &SSHServer{
|
||||||
trustedKeys: make(map[string]map[string]bool),
|
trustedKeys: make(map[string]map[string]bool),
|
||||||
l: l,
|
l: l,
|
||||||
@@ -153,6 +155,10 @@ func (s *SSHServer) RegisterCommand(c *Command) {
|
|||||||
|
|
||||||
// Run begins listening and accepting connections
|
// Run begins listening and accepting connections
|
||||||
func (s *SSHServer) Run(addr string) error {
|
func (s *SSHServer) Run(addr string) error {
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return s.ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
s.listener, err = net.Listen("tcp", addr)
|
s.listener, err = net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -161,8 +167,21 @@ func (s *SSHServer) Run(addr string) error {
|
|||||||
|
|
||||||
s.l.Info("SSH server is listening", "sshListener", addr)
|
s.l.Info("SSH server is listening", "sshListener", addr)
|
||||||
|
|
||||||
|
// Per-invocation watcher: cancellation of the parent context (e.g.
|
||||||
|
// Control.Stop) closes the listener so Accept unblocks and run returns.
|
||||||
|
// Closing `done` on exit keeps the watcher from outliving this Run call.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
s.Stop()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Run loops until there is an error
|
// Run loops until there is an error
|
||||||
s.run()
|
s.run()
|
||||||
|
close(done)
|
||||||
s.closeSessions()
|
s.closeSessions()
|
||||||
|
|
||||||
s.l.Info("SSH server stopped listening")
|
s.l.Info("SSH server stopped listening")
|
||||||
|
|||||||
17
timeout.go
17
timeout.go
@@ -8,6 +8,23 @@ import (
|
|||||||
// How many timer objects should be cached
|
// How many timer objects should be cached
|
||||||
const timerCacheMax = 50000
|
const timerCacheMax = 50000
|
||||||
|
|
||||||
|
// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen,
|
||||||
|
// with each slot a singly linked list of items due in that bucket.
|
||||||
|
// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems
|
||||||
|
// keeps steady-state inserts allocation-free.
|
||||||
|
//
|
||||||
|
// The TimerWheel does not handle concurrency or lifecycle on its own.
|
||||||
|
// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel),
|
||||||
|
// and decide whether to keep ticking when the wheel is empty.
|
||||||
|
//
|
||||||
|
// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts,
|
||||||
|
// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate.
|
||||||
|
// Items added in the same tick are dispatched together when that slot rotates current,
|
||||||
|
// which amortizes the cost of waking the worker.
|
||||||
|
//
|
||||||
|
// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven.
|
||||||
|
// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration;
|
||||||
|
// both are silent in this implementation.
|
||||||
type TimerWheel[T any] struct {
|
type TimerWheel[T any] struct {
|
||||||
// Current tick
|
// Current tick
|
||||||
current int
|
current int
|
||||||
|
|||||||
Reference in New Issue
Block a user