From 213dd46588d516f0151d0cc54e16a4cd042f9ba4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 16:21:16 -0500 Subject: [PATCH 01/27] Stop leaking goroutines past Control.Stop, consolidate punching in Punchy (#1708) --- connection_manager.go | 54 +++-------- connection_manager_test.go | 8 +- e2e/leak_test.go | 10 +- examples/config.yml | 4 + lighthouse.go | 52 ++-------- main.go | 6 +- punchy.go | 193 ++++++++++++++++++++++++++++++------- punchy_test.go | 81 ++++++++-------- scheduler.go | 84 ++++++++++++++++ scheduler_test.go | 79 +++++++++++++++ sshd/server.go | 25 ++++- timeout.go | 17 ++++ 12 files changed, 434 insertions(+), 179 deletions(-) create mode 100644 scheduler.go create mode 100644 scheduler_test.go diff --git a/connection_manager.go b/connection_manager.go index e7fc04cd..ee6d1eaf 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -11,7 +11,6 @@ import ( "sync/atomic" "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -45,19 +44,16 @@ type connectionManager struct { inactivityTimeout atomic.Int64 dropInactive atomic.Bool - metricsTxPunchy metrics.Counter - l *slog.Logger } func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ - hostMap: hm, - l: l, - punchy: p, - relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, - metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + hostMap: hm, + l: l, + punchy: p, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, } cm.reload(c, true) @@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if !outTraffic { // Send a punch packet to keep the NAT state alive - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) } return decision, hostinfo, primary @@ -400,17 +396,16 @@ func (cm *connectionManager) 
makeTrafficDecision(localIndex uint32, now time.Tim // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil } - if cm.punchy.GetTargetEverything() { - // This is similar to the old punchy behavior with a slight optimization. - // We aren't receiving traffic but we are sending it, punch on all known - // ips in case we need to re-prime NAT state - cm.sendPunch(hostinfo) - } + // We aren't receiving traffic but we are sending it. The outbound + // traffic itself refreshes the primary remote's NAT state; this + // fans out to non-primary remotes, but only if target_all_remotes + // is configured. + cm.punchy.SendPunchToAll(hostinfo) if cm.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(cm.l).Debug("Tunnel status", @@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI } } -func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { - if !cm.punchy.GetPunch() { - // Punching is disabled - return - } - - if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { - // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. - // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse - // would lose the ability to notify us and punchy.respond would become unreliable. 
- return - } - - if cm.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, addr) - }) - - } else if hostinfo.remote.IsValid() { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) - } -} - func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert diff --git a/connection_manager_test.go b/connection_manager_test.go index 7dc08a45..e167e5f2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := 
NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/e2e/leak_test.go b/e2e/leak_test.go index ffb024fe..576d67a8 100644 --- a/e2e/leak_test.go +++ b/e2e/leak_test.go @@ -18,14 +18,10 @@ import ( // retry mechanism gives the wg.Wait()-driven goroutines a moment to drain // before failing the assertion. // -// IgnoreCurrent is necessary in the parallelized suite: other tests can -// leave goroutines mid-shutdown when this one runs (Stop is async, the -// wg.Wait() drain is not blocking on test return). We're checking that -// *this* test's setup tears down cleanly, not that the whole suite is -// idle at this moment. Intentionally NOT t.Parallel()'d for the same -// reason — concurrent test goroutines would always show up. +// Intentionally NOT t.Parallel()'d: concurrent tests would have their own +// goroutines running and trip the assertion. func TestNoGoroutineLeaks(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleak.VerifyNone(t) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) diff --git a/examples/config.yml b/examples/config.yml index f5752ae4..ac4810e6 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -163,17 +163,21 @@ listen: punchy: # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings + # This setting is reloadable. punch: true # respond means that a node you are trying to reach will connect back out to you if your hole punching fails # this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT # Default is false + # This setting is reloadable. 
#respond: true # delays a punch response for misbehaving NATs, default is 1 second. + # This setting is reloadable. #delay: 1s # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. + # This setting is reloadable. #respond_delay: 5s # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes diff --git a/lighthouse.go b/lighthouse.go index 6034e68c..1a136a1b 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -15,7 +15,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -35,7 +34,6 @@ type LightHouse struct { myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Lite - punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses @@ -75,9 +73,8 @@ type LightHouse struct { calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote - metrics *MessageMetrics - metricHolepunchTx metrics.Counter - l *slog.Logger + metrics *MessageMetrics + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object @@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c myVpnNetworksTable: cs.myVpnNetworksTable, addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, - punchConn: pc, punchy: p, updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), @@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() - h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) - } else { - h.metricHolepunchTx = metrics.NilCounter{} } err := h.reload(c, true) @@ -1406,58 +1399,25 @@ 
func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn return } - empty := []byte{0} - punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) { - if !vpnPeer.IsValid() { - return - } - - go func() { - time.Sleep(lhh.lh.punchy.GetDelay()) - lhh.lh.metricHolepunchTx.Inc(1) - lhh.lh.punchConn.WriteTo(empty, vpnPeer) - }() - - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Punching", - "vpnPeer", vpnPeer, - "logVpnAddr", logVpnAddr, - ) - } - } - remoteAllowList := lhh.lh.GetRemoteAllowList() for _, a := range n.Details.V4AddrPorts { b := protoV4AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } for _, a := range n.Details.V6AddrPorts { b := protoV6AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish - // a tunnel. - if lhh.lh.punchy.GetRespond() { - go func() { - time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Sending a nebula test packet", - "vpnAddr", detailsVpnAddr, - ) - } - //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine - // for each punchBack packet. We should move this into a timerwheel or a single goroutine - // managed by a channel. - w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - }() - } + // a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled. 
+ lhh.lh.punchy.ScheduleRespond(detailsVpnAddr) } func protoAddrToNetAddr(addr *Addr) netip.Addr { diff --git a/main.go b/main.go index d5e5dcc8..37aa24d1 100644 --- a/main.go +++ b/main.go @@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev } l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev } hostMap := NewHostMapFromConfig(l, c) - punchy := NewPunchyFromConfig(l, c) + punchy := NewPunchyFromConfig(l, c, udpConns[0]) connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { @@ -240,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager.f = ifce go handshakeManager.Run(ctx) + + punchy.Start(ctx, ifce, hostMap, lightHouse) } stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) diff --git a/punchy.go b/punchy.go index 6ecf4f85..38a0e1ca 100644 --- a/punchy.go +++ b/punchy.go @@ -1,24 +1,70 @@ package nebula import ( + "context" "log/slog" + "net/netip" "sync/atomic" "time" + "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) +// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires. +const holepunchQueueSize = 64 + +// holepunchJob is one scheduled item delivered to the worker goroutine. +// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context. 
+// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback"). +type holepunchJob struct { + target netip.AddrPort + vpnAddr netip.Addr +} + +// lighthouseChecker is the slice of LightHouse that Punchy actually needs. +// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse +// already holds a *Punchy, and the bidirectional pointer reference is awkward +// even within the same package). Tests can also substitute a fake. +type lighthouseChecker interface { + IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool +} + type Punchy struct { punch atomic.Bool respond atomic.Bool delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *slog.Logger + + sched *Scheduler[holepunchJob] + punchConn udp.Conn + metricHolepunchTx metrics.Counter + metricPunchyTx metrics.Counter + + ctx context.Context + ifce EncWriter + hm *HostMap + lh lighthouseChecker + + l *slog.Logger } -func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { - p := &Punchy{l: l} +func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy { + p := &Punchy{ + l: l, + punchConn: punchConn, + sched: NewScheduler[holepunchJob](holepunchQueueSize), + metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + } + + if c.GetBool("stats.lighthouse_metrics", false) { + p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) + } else { + p.metricHolepunchTx = metrics.NilCounter{} + } p.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { @@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { } func (p *Punchy) reload(c *config.C, initial bool) { - if initial { + if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") { var yes bool if c.IsSet("punchy.punch") { yes = c.GetBool("punchy.punch", false) @@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punchy", false) } - p.punch.Store(yes) - if 
yes { + old := p.punch.Swap(yes) + switch { + case initial && yes: p.l.Info("punchy enabled") - } else { + case initial: p.l.Info("punchy disabled") + case old != yes: + p.l.Info("punchy.punch changed", "punch", yes) } - - } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { - //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here - p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") } if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { @@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punch_back", false) } - p.respond.Store(yes) - - if !initial { - p.l.Info("punchy.respond changed", "respond", p.GetRespond()) + old := p.respond.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.respond changed", "respond", yes) } } //NOTE: this will not apply to any in progress operations, only the next one if initial || c.HasChanged("punchy.delay") { - p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) - if !initial { - p.l.Info("punchy.delay changed", "delay", p.GetDelay()) + newDelay := int64(c.GetDuration("punchy.delay", time.Second)) + old := p.delay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay)) } } if initial || c.HasChanged("punchy.target_all_remotes") { - p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) - if !initial { - p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) + yes := c.GetBool("punchy.target_all_remotes", false) + old := p.punchEverything.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes) } } if initial || c.HasChanged("punchy.respond_delay") { - p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) - if !initial { - 
p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) + newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second)) + old := p.respondDelay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay)) } } } -func (p *Punchy) GetPunch() bool { - return p.punch.Load() +// Schedule queues a punch packet to target, to be sent after the configured delay. +// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires. +// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine. +func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) { + if !target.IsValid() || p.ctx == nil { + return + } + p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load())) } -func (p *Punchy) GetRespond() bool { - return p.respond.Load() +// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay, +// gated on punchy.respond. No-op when respond is disabled or before Start has been called. +func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) { + if !p.respond.Load() || p.ctx == nil { + return + } + p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load())) } -func (p *Punchy) GetDelay() time.Duration { - return (time.Duration)(p.delay.Load()) +// scheduleJob delegates to the pooled Scheduler. +// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued. +func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) { + p.sched.Schedule(p.ctx, job, delay) } -func (p *Punchy) GetRespondDelay() time.Duration { - return (time.Duration)(p.respondDelay.Load()) +// SendPunch sends an immediate keepalive punch for an idle hostinfo. +// The configured punchy.target_all_remotes mode picks the targets. 
Gated on punchy.punch and the lighthouse-skip rule +// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm). +func (p *Punchy) SendPunch(hostinfo *HostInfo) { + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + + if p.punchEverything.Load() { + p.sendPunchToAllRemotes(hostinfo) + } else if hostinfo.remote.IsValid() { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, hostinfo.remote) + } } -func (p *Punchy) GetTargetEverything() bool { - return p.punchEverything.Load() +// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled. +// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's +// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant +// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule. +func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) { + if !p.punchEverything.Load() { + return + } + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + p.sendPunchToAllRemotes(hostinfo) +} + +func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) { + hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, addr) + }) +} + +// Start wires the runtime dependencies and spawns the scheduler worker. 
+func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) { + p.ctx = ctx + p.ifce = ifce + p.hm = hm + p.lh = lh + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + empty := []byte{0} + + go p.sched.Run(ctx, func(job holepunchJob) { + switch { + case job.target.IsValid(): + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr) + } + p.metricHolepunchTx.Inc(1) + p.punchConn.WriteTo(empty, job.target) + case job.vpnAddr.IsValid(): + // A nebula test packet to the host trying to contact us. + // In the case of a double nat or other difficult scenario, this may help establish a tunnel. + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr) + } + p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out) + } + }) } diff --git a/punchy_test.go b/punchy_test.go index cbf9b17b..e56f3eff 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.False(t, p.GetPunch()) - assert.False(t, p.GetRespond()) - assert.Equal(t, time.Second, p.GetDelay()) - assert.Equal(t, 5*time.Second, p.GetRespondDelay()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.False(t, p.respond.Load()) + assert.Equal(t, time.Second, time.Duration(p.delay.Load())) + assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load())) // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p 
= NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.delay.Load())) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetRespondDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load())) } func TestPunchy_reload(t *testing.T) { @@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) { delay, _ := time.ParseDuration("1m") require.NoError(t, c.LoadString(` punchy: + punch: false delay: 1m respond: false `)) - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, delay, p.GetDelay()) - assert.False(t, p.GetRespond()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.Equal(t, delay, time.Duration(p.delay.Load())) + assert.False(t, p.respond.Load()) newDelay, _ := time.ParseDuration("10m") require.NoError(t, c.ReloadConfigString(` punchy: + punch: true delay: 10m respond: true `)) p.reload(c, false) - assert.Equal(t, newDelay, p.GetDelay()) - assert.True(t, p.GetRespond()) + assert.True(t, p.punch.Load()) + assert.Equal(t, newDelay, 
time.Duration(p.delay.Load())) + assert.True(t, p.respond.Load()) } // The tests below pin the shape of each log line Punchy produces so changes // cannot silently break whatever operators are grepping for. The assertions // are on the structured message + attrs (e.g. "punchy.respond changed" with -// a respond=true field) rather than a formatted string. -// -// Punchy.reload also emits a spurious "Changing punchy.punch with reload is -// not supported" warning whenever any key under punchy changes, because of -// the c.HasChanged("punchy") fallback kept for the deprecated top-level -// punchy form. The tests filter by message rather than asserting total -// entry counts so that warning is tolerated without being locked into -// the format. +// a respond=true field) rather than a formatted string. Tests filter by +// message rather than asserting total entry counts so unrelated info lines +// are tolerated without being locked into the format. type capturedEntry struct { Level slog.Level @@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: true}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy enabled") assert.Equal(t, slog.LevelInfo, entry.Level) @@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy disabled") assert.Equal(t, slog.LevelInfo, entry.Level) assert.Empty(t, entry.Attrs) } -func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { +func TestPunchy_LogFormat_ReloadPunch(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + 
NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) - entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") - assert.Equal(t, slog.LevelWarn, entry.Level) - assert.Empty(t, entry.Attrs) + entry := findEntry(t, hook.entries, "punchy.punch changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"punch": true}, entry.Attrs) } func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) @@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) @@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) @@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) diff --git a/scheduler.go b/scheduler.go new file mode 100644 index 
00000000..7733204a --- /dev/null +++ b/scheduler.go @@ -0,0 +1,84 @@ +package nebula + +import ( + "context" + "sync" + "time" +) + +// Scheduler is an allocation-conscious dispatch primitive for delayed work. +// Pending items are handed to time.AfterFunc, and ready items land on a worker +// channel for centralized dispatch in fire-time order. +// +// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling +// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before +// delivering to the worker, which is fine at sparse rates but adds up at line rate. +// +// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache, +// and bucket-batched dispatch are cheaper at scale. +// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines. +type Scheduler[T any] struct { + queue chan T + pool sync.Pool +} + +type schedItem[T any] struct { + val T + ctx context.Context + s *Scheduler[T] + timer *time.Timer + fire func() +} + +// NewScheduler builds a Scheduler whose worker channel is sized to queueSize. +// The buffer absorbs bursts of timers firing close together without +// blocking the runtime's callback goroutines on the worker. +func NewScheduler[T any](queueSize int) *Scheduler[T] { + s := &Scheduler[T]{ + queue: make(chan T, queueSize), + } + s.pool.New = func() any { + si := &schedItem[T]{s: s} + // fire is allocated exactly once per pool-resident item. + // The closure captures only `si`, which stays stable for the item's lifetime. + si.fire = func() { + select { + case si.s.queue <- si.val: + case <-si.ctx.Done(): + } + var zero T + si.val = zero + si.ctx = nil + si.s.pool.Put(si) + } + return si + } + return s +} + +// Schedule arranges item to be delivered to the worker after delay. 
+// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle. +// The callback observes ctx: if ctx is cancelled before the timer fires, the item is dropped instead of queued. +func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) { + si := s.pool.Get().(*schedItem[T]) + si.val = item + si.ctx = ctx + if si.timer == nil { + si.timer = time.AfterFunc(delay, si.fire) + } else { + si.timer.Reset(delay) + } +} + +// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled. +// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run. +func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) { + for { + select { + case <-ctx.Done(): + return + case item := <-s.queue: + fn(item) + } + } +} diff --git a/scheduler_test.go b/scheduler_test.go new file mode 100644 index 00000000..085d523c --- /dev/null +++ b/scheduler_test.go @@ -0,0 +1,79 @@ +package nebula + +import ( + "context" + "testing" + "time" +) + +func TestScheduler_PooledReuse(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewScheduler[int](16) + delivered := make(chan int, 256) + go s.Run(ctx, func(item int) { delivered <- item }) + + const N = 100 + for i := 0; i < N; i++ { + s.Schedule(ctx, i, time.Millisecond) + } + + deadline := time.After(2 * time.Second) + got := 0 + for got < N { + select { + case <-delivered: + got++ + case <-deadline: + t.Fatalf("only %d/%d items delivered", got, N) + } + } +} + +// BenchmarkScheduler_Schedule reports allocations per Schedule call. +// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up. 
+func BenchmarkScheduler_Schedule(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewScheduler[int](b.N) + go s.Run(ctx, func(int) {}) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Schedule(ctx, i, time.Microsecond) + } +} + +// BenchmarkBareAfterFunc is the comparison baseline. +// What we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler. +// Allocates a *time.Timer plus a closure each call. +func BenchmarkBareAfterFunc(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + queue := make(chan int, b.N) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-queue: + } + } + }() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + i := i + time.AfterFunc(time.Microsecond, func() { + select { + case queue <- i: + case <-ctx.Done(): + } + }) + } +} diff --git a/sshd/server.go b/sshd/server.go index 38886e53..ff954bf5 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -32,10 +32,12 @@ type SSHServer struct { cancel func() } -// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *slog.Logger) (*SSHServer, error) { +// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen. +// The ssh server's context is parented off the supplied ctx so cancelling it +// (e.g. on Control.Stop) tears down active sessions and closes the listener. 
+func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, @@ -153,6 +155,10 @@ func (s *SSHServer) RegisterCommand(c *Command) { // Run begins listening and accepting connections func (s *SSHServer) Run(addr string) error { + if s.ctx.Err() != nil { + return s.ctx.Err() + } + var err error s.listener, err = net.Listen("tcp", addr) if err != nil { @@ -161,8 +167,21 @@ func (s *SSHServer) Run(addr string) error { s.l.Info("SSH server is listening", "sshListener", addr) + // Per-invocation watcher: cancellation of the parent context (e.g. + // Control.Stop) closes the listener so Accept unblocks and run returns. + // Closing `done` on exit keeps the watcher from outliving this Run call. + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + s.Stop() + case <-done: + } + }() + // Run loops until there is an error s.run() + close(done) s.closeSessions() s.l.Info("SSH server stopped listening") diff --git a/timeout.go b/timeout.go index c1b4c398..96bf688b 100644 --- a/timeout.go +++ b/timeout.go @@ -8,6 +8,23 @@ import ( // How many timer objects should be cached const timerCacheMax = 50000 +// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen, +// with each slot a singly linked list of items due in that bucket. +// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems +// keeps steady-state inserts allocation-free. +// +// The TimerWheel does not handle concurrency or lifecycle on its own. +// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel), +// and decide whether to keep ticking when the wheel is empty. 
+// +// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts, +// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate. +// Items added in the same tick are dispatched together when that slot rotates current, +// which amortizes the cost of waking the worker. +// +// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven. +// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration; +// both are silent in this implementation. type TimerWheel[T any] struct { // Current tick current int From a82a8dc547dca7e0f4e30c4d6f6adaaa124babbc Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 6 May 2026 17:00:07 -0500 Subject: [PATCH 02/27] don't panic on bad ed25519 key lengths (#1601) * don't panic on bad ed25519 key lengths * don't allow mismatched curves * add test --- cert/ca_pool.go | 4 ++++ cert/ca_pool_test.go | 28 ++++++++++++++++++++++++++++ cert/cert_v1.go | 3 +++ cert/cert_v2.go | 3 +++ cert/errors.go | 1 + 5 files changed, 39 insertions(+) diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 792f8e66..966f78e3 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp return nil, err } + if signer.Certificate.Curve() != c.Curve() { + return nil, ErrCurveMismatch + } + if signer.Certificate.Expired(now) { return nil, ErrRootExpired } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index ab173228..c246e770 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } + +func TestCertificateV2_CurveMismatch(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := 
NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is inside the CA's networks, so the unmodified cert verifies + cIp1 := mustParsePrefixUnmapped("10.0.0.1/24") + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"}) + + fp, _ := c.Fingerprint() + _, err = caPool.verify(c, time.Now(), fp, c.Issuer()) + require.NoError(t, err) + // flip the cert's curve so it no longer matches the CA's + c2 := c.(*certificateV2) + c2.curve = Curve_CURVE25519 + fp, _ = c.Fingerprint() + _, err = caPool.verify(c, time.Now(), fp, c.Issuer()) + require.Error(t, err) +} diff --git a/cert/cert_v1.go b/cert/cert_v1.go index c32f409a..4df30032 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool { } switch c.details.curve { case Curve_CURVE25519: + if len(key) != ed25519.PublicKeySize { + return false //avoids a panic internal to ed25519 + } return ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..c2b43a69 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool { switch c.curve { case Curve_CURVE25519: + if len(key) != ed25519.PublicKeySize { + return false //avoids a panic internal to ed25519 + } return ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) diff --git a/cert/errors.go b/cert/errors.go index 8c480a14..596cfe19 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -22,6 +22,7 @@ var ( ErrCaNotFound = errors.New("could not find ca for the certificate") 
ErrUnknownVersion = errors.New("certificate version unrecognized") ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present") + ErrCurveMismatch = errors.New("certificate curve does not match CA") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") From eaf756ea6c90d97790f29503fb5e687a251ca8fb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:31:48 -0500 Subject: [PATCH 03/27] Bump Apple-Actions/import-codesign-certs from 6 to 7 (#1697) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a5e8d397..b911bd52 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v6 + uses: Apple-Actions/import-codesign-certs@v7 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} From 76e82a5256f55f47107327cd6710536169483e16 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:32:21 -0500 Subject: [PATCH 04/27] Bump golang.org/x/net (#1664) Bumps the golang-x-dependencies group with 1 update in the / directory: [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/net` from 0.52.0 to 0.53.0 - [Commits](https://github.com/golang/net/compare/v0.52.0...v0.53.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-version: 0.53.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 24d901c5..bfbc987f 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.52.0 + golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 golang.org/x/sys v0.43.0 golang.org/x/term v0.42.0 diff --git a/go.sum b/go.sum index aad164c7..10116c5b 100644 --- a/go.sum +++ b/go.sum @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From dd2ac5d6550a37745f9047d8a230482c1bc8ad18 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:32:45 -0500 Subject: [PATCH 05/27] Bump docker/login-action from 3 to 4 (#1628) Bumps 
[docker/login-action](https://github.com/docker/login-action) from 3 to 4. - [Release notes](https://github.com/docker/login-action/releases) - [Commits](https://github.com/docker/login-action/compare/v3...v4) --- updated-dependencies: - dependency-name: docker/login-action dependency-version: '4' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b911bd52..8d4b62bc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -135,7 +135,7 @@ jobs: - name: Login to Docker Hub if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: docker/login-action@v3 + uses: docker/login-action@v4 with: username: ${{ vars.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} From dd3a7ad03c488860e39060a728959817f162112a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:33:16 -0500 Subject: [PATCH 06/27] Bump docker/setup-buildx-action from 3 to 4 (#1627) Bumps [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) from 3 to 4. - [Release notes](https://github.com/docker/setup-buildx-action/releases) - [Commits](https://github.com/docker/setup-buildx-action/compare/v3...v4) --- updated-dependencies: - dependency-name: docker/setup-buildx-action dependency-version: '4' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8d4b62bc..e323a2ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -142,7 +142,7 @@ jobs: - name: Set up Docker Buildx if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@v4 - name: Build and push images if: ${{ env.HAS_DOCKER_CREDS == 'true' }} From 23c67bd8d820d48f16892a94f77f747bb5b358c7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:33:47 -0500 Subject: [PATCH 07/27] Bump actions/upload-artifact from 6 to 7 (#1618) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 6 to 7. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 6 +++--- .github/workflows/test.yml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e323a2ca..e934d436 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,7 +24,7 @@ jobs: mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: linux-latest path: release @@ -55,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: windows-latest path: build @@ -104,7 +104,7 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: darwin-latest path: ./release/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aeaea294..009c22a9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: - name: Build test mobile run: make build-test-mobile - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: e2e packet flow linux-latest path: e2e/mermaid/linux-latest @@ -125,7 +125,7 @@ jobs: - name: End 2 end run: make e2evv - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: e2e packet flow ${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }} From 83809a599a1414b57e715fe241d7204487eb9a9f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:34:06 -0500 Subject: [PATCH 08/27] Bump actions/download-artifact from 7 to 8 (#1617) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 7 to 8. 
- [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v7...v8) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '8' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e934d436..356ae363 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -128,7 +128,7 @@ jobs: - name: Download artifacts if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: linux-latest path: artifacts @@ -163,7 +163,7 @@ jobs: - uses: actions/checkout@v6 - name: Download artifacts - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: path: artifacts From cba9ea5b1fb10fd7a7a00ce6a2adb7cf2f14fbc2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:36:07 -0500 Subject: [PATCH 09/27] Bump github.com/gaissmai/bart from 0.26.0 to 0.26.1 (#1604) Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.26.0 to 0.26.1. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.26.0...v0.26.1) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-version: 0.26.1 dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bfbc987f..84728201 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.26.0 + github.com/gaissmai/bart v0.26.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 diff --git a/go.sum b/go.sum index 10116c5b..3b0b87df 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= -github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= +github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo= +github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 5f920fdd7d5af2510516ef3e6dbd9543de8019ae Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 17:37:03 -0500 Subject: [PATCH 10/27] Remove the global noiseEndianness var (#1707) --- connection_state.go | 9 +- handshake/machine.go | 2 + noise.go | 73 --------------- noiseutil/aesgcm.go | 53 +++++++++++ noiseutil/chachapoly.go 
| 52 +++++++++++ noiseutil/cipher_state.go | 40 ++++++++ noiseutil/cipher_state_test.go | 166 +++++++++++++++++++++++++++++++++ pki.go | 8 +- 8 files changed, 321 insertions(+), 82 deletions(-) delete mode 100644 noise.go create mode 100644 noiseutil/aesgcm.go create mode 100644 noiseutil/chachapoly.go create mode 100644 noiseutil/cipher_state.go create mode 100644 noiseutil/cipher_state_test.go diff --git a/connection_state.go b/connection_state.go index 47e23b5a..0ae2d9be 100644 --- a/connection_state.go +++ b/connection_state.go @@ -7,13 +7,14 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 type ConnectionState struct { - eKey *NebulaCipherState - dKey *NebulaCipherState + eKey noiseutil.CipherState + dKey noiseutil.CipherState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { myCert: r.MyCert, initiator: r.Initiator, peerCert: r.RemoteCert, - eKey: NewNebulaCipherState(r.EKey), - dKey: NewNebulaCipherState(r.DKey), + eKey: noiseutil.NewCipherState(r.EKey, r.Cipher), + dKey: noiseutil.NewCipherState(r.DKey, r.Cipher), window: NewBits(ReplayWindow), } ci.messageCounter.Add(r.MessageIndex) diff --git a/handshake/machine.go b/handshake/machine.go index 25ed3a5a..737358dc 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -31,6 +31,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) type Result struct { EKey *noise.CipherState DKey *noise.CipherState + Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in MyCert cert.Certificate RemoteCert *cert.CachedCertificate RemoteIndex uint32 @@ -105,6 +106,7 @@ func NewMachine( myVersion: version, result: &Result{ Initiator: initiator, + Cipher: cred.cipherSuite, }, }, nil } diff --git a/noise.go b/noise.go deleted 
file mode 100644 index 0491da17..00000000 --- a/noise.go +++ /dev/null @@ -1,73 +0,0 @@ -package nebula - -import ( - "crypto/cipher" - "encoding/binary" - "errors" - - "github.com/flynn/noise" -) - -type endianness interface { - PutUint64(b []byte, v uint64) -} - -var noiseEndianness endianness = binary.BigEndian - -type NebulaCipherState struct { - c cipher.AEAD -} - -func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - x := s.Cipher() - return &NebulaCipherState{c: x.(cipher.AEAD)} -} - -// EncryptDanger encrypts and authenticates a given payload. -// -// out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. -func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - // TODO: Is this okay now that we have made messageCounter atomic? 
- // Alternative may be to split the counter space into ranges - //if n <= s.n { - // return nil, errors.New("CRITICAL: a duplicate counter value was used") - //} - //s.n = n - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - out = s.c.Seal(out, nb, plaintext, ad) - //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) - return out, nil - } else { - return nil, errors.New("no cipher state available to encrypt") - } -} - -func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - return s.c.Open(out, nb, ciphertext, ad) - } else { - return []byte{}, nil - } -} - -func (s *NebulaCipherState) Overhead() int { - if s != nil { - return s.c.Overhead() - } - return 0 -} diff --git a/noiseutil/aesgcm.go b/noiseutil/aesgcm.go new file mode 100644 index 00000000..dcbd5693 --- /dev/null +++ b/noiseutil/aesgcm.go @@ -0,0 +1,53 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher. +// AES-GCM uses big-endian nonce encoding per the Noise spec. +type CipherStateAESGCM struct { + c cipher.AEAD +} + +// NewCipherStateAESGCM extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually AES-GCM, +// otherwise the type assertion still succeeds but the nonce endianness will be wrong on the wire. 
+func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM { + return &CipherStateAESGCM{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateAESGCM) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/chachapoly.go b/noiseutil/chachapoly.go new file mode 100644 index 00000000..31ab3bfe --- /dev/null +++ b/noiseutil/chachapoly.go @@ -0,0 +1,52 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher. +// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec. +type CipherStateChaChaPoly struct { + c cipher.AEAD +} + +// NewCipherStateChaChaPoly extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually ChaCha20-Poly1305. 
+func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly { + return &CipherStateChaChaPoly{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateChaChaPoly) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/cipher_state.go b/noiseutil/cipher_state.go new file mode 100644 index 00000000..bb316385 --- /dev/null +++ b/noiseutil/cipher_state.go @@ -0,0 +1,40 @@ +package noiseutil + +import ( + "fmt" + + "github.com/flynn/noise" +) + +// CipherState is the post-handshake AEAD cipher used for the data plane. +// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded, +// so the encrypt/decrypt fast path avoids interface dispatch on the byte order. +type CipherState interface { + // EncryptDanger encrypts and authenticates a given payload. + // + // out is a destination slice to hold the output of the EncryptDanger operation. + // - ad is additional data, which will be authenticated and appended to out, but not encrypted. + // - plaintext is encrypted, authenticated and appended to out. + // - n is a nonce value which must never be re-used with this key. + // - nb is a scratch buffer used to assemble the nonce. 
+ EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) + + // DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger. + DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) + + // Overhead returns the AEAD tag size, or 0 if the receiver is nil. + Overhead() int +} + +// NewCipherState wraps the post-handshake noise.CipherState in the per-cipher type that matches cipherFunc. +// cipherFunc must be the same cipher used to build the noise CipherSuite that produced s. +func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState { + switch cipherFunc.CipherName() { + case CipherAESGCM.CipherName(): + return NewCipherStateAESGCM(s) + case noise.CipherChaChaPoly.CipherName(): + return NewCipherStateChaChaPoly(s) + default: + panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName())) + } +} diff --git a/noiseutil/cipher_state_test.go b/noiseutil/cipher_state_test.go new file mode 100644 index 00000000..a4df01e9 --- /dev/null +++ b/noiseutil/cipher_state_test.go @@ -0,0 +1,166 @@ +package noiseutil + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCipherStateAESGCMRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, CipherAESGCM) + roundtrip(t, NewCipherStateAESGCM(enc), NewCipherStateAESGCM(dec)) +} + +func TestCipherStateChaChaPolyRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, noise.CipherChaChaPoly) + roundtrip(t, NewCipherStateChaChaPoly(enc), NewCipherStateChaChaPoly(dec)) +} + +func TestNewCipherStateDispatch(t *testing.T) { + encA, _ := buildCipherStates(t, CipherAESGCM) + encC, _ := buildCipherStates(t, noise.CipherChaChaPoly) + + assert.IsType(t, &CipherStateAESGCM{}, NewCipherState(encA, CipherAESGCM)) + assert.IsType(t, &CipherStateChaChaPoly{}, NewCipherState(encC, noise.CipherChaChaPoly)) +} + +func 
TestNewCipherStateUnsupportedPanics(t *testing.T) { + enc, _ := buildCipherStates(t, CipherAESGCM) + assert.Panics(t, func() { + NewCipherState(enc, fakeCipher{}) + }) +} + +type fakeCipher struct{} + +func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil } +func (fakeCipher) CipherName() string { return "Fake" } + +// buildCipherStates runs an in-memory NN handshake with the requested cipher +// to produce a pair of post-handshake CipherStates that share keys. +func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + t.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + + msg, _, _, err := hsI.WriteMessage(nil, nil) + require.NoError(t, err) + _, _, _, err = hsR.ReadMessage(nil, msg) + require.NoError(t, err) + + msg, dR, _, err := hsR.WriteMessage(nil, nil) + require.NoError(t, err) + _, eI, _, err := hsI.ReadMessage(nil, msg) + require.NoError(t, err) + require.NotNil(t, eI) + require.NotNil(t, dR) + + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + return eI, dR +} + +func roundtrip(t *testing.T, enc, dec CipherState) { + t.Helper() + plaintext := []byte("nebula cipher state roundtrip") + ad := []byte("aad") + nb := make([]byte, 12) + + ct, err := enc.EncryptDanger(nil, ad, plaintext, 1, nb) + require.NoError(t, err) + assert.NotEqual(t, plaintext, ct) + + pt, err := dec.DecryptDanger(nil, ad, ct, 1, nb) + require.NoError(t, err) + assert.Equal(t, plaintext, pt) + + // Wrong nonce must fail authentication. 
+ _, err = dec.DecryptDanger(nil, ad, ct, 2, nb) + require.Error(t, err) + + assert.Equal(t, enc.Overhead(), dec.Overhead()) + assert.Equal(t, 16, enc.Overhead()) +} + +func BenchmarkCipherStateEncryptAESGCM(b *testing.B) { + enc, _ := buildCipherStatesB(b, CipherAESGCM) + benchEncryptCipherState(b, NewCipherState(enc, CipherAESGCM)) +} + +func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) { + enc, _ := buildCipherStatesB(b, noise.CipherChaChaPoly) + benchEncryptCipherState(b, NewCipherState(enc, noise.CipherChaChaPoly)) +} + +func benchEncryptCipherState(b *testing.B, cs CipherState) { + plaintext := make([]byte, 1280) + ad := make([]byte, 16) + nb := make([]byte, 12) + out := make([]byte, 0, len(plaintext)+cs.Overhead()) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var err error + out, err = cs.EncryptDanger(out[:0], ad, plaintext, uint64(i+1), nb) + if err != nil { + b.Fatal(err) + } + } +} + +func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + b.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + msg, _, _, err := hsI.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil { + b.Fatal(err) + } + msg, dR, _, err := hsR.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + _, eI, _, err := hsI.ReadMessage(nil, msg) + if err != nil { + b.Fatal(err) + } + return eI, dR +} + +func TestCipherStateNilSafety(t *testing.T) { + var aes *CipherStateAESGCM + _, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + 
require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, aes.Overhead()) + + var cc *CipherStateChaChaPoly + _, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, cc.Overhead()) +} diff --git a/pki.go b/pki.go index acc80486..1bef5106 100644 --- a/pki.go +++ b/pki.go @@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { var currentState *CertState if initial { cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global switch cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian + case "aes", "chachapoly": + // Each post-handshake CipherState in noiseutil hardcodes its own + // nonce endianness now, so there's nothing to set up here. default: return util.NewContextualError( "unknown cipher", From 1ada3d4dd98659a425fb0196b2e34eead36f9914 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 10:30:29 -0500 Subject: [PATCH 11/27] Use DefinedNets fancy new netbsd10 vagrant box for smokes (#1711) --- .github/workflows/smoke-extra.yml | 48 ++++++++++++------- .../smoke/vagrant-netbsd-amd64/Vagrantfile | 2 +- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 3734db75..cca7678b 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -14,10 +14,18 @@ on: - 'go.sum' jobs: - smoke-extra: + smoke-extra-libvirt: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') - name: Run extra smoke tests + name: ${{ matrix.target }} runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + target: + - freebsd-amd64 + - openbsd-amd64 + - netbsd-amd64 + - 
linux-amd64-ipv6disable env: VAGRANT_DEFAULT_PROVIDER: libvirt steps: @@ -40,28 +48,36 @@ jobs: sudo chmod 666 /var/run/libvirt/libvirt-sock vagrant plugin install vagrant-libvirt - - name: freebsd-amd64 - run: make smoke-vagrant/freebsd-amd64 + - name: ${{ matrix.target }} + run: make smoke-vagrant/${{ matrix.target }} - - name: openbsd-amd64 - run: make smoke-vagrant/openbsd-amd64 + timeout-minutes: 30 - - name: netbsd-amd64 - run: make smoke-vagrant/netbsd-amd64 + # linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job. + smoke-extra-virtualbox: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: linux-386 + runs-on: ubuntu-latest + env: + VAGRANT_DEFAULT_PROVIDER: virtualbox + steps: - - name: linux-amd64-ipv6disable - run: make smoke-vagrant/linux-amd64-ipv6disable + - uses: actions/checkout@v6 - # linux-386 runs last because it requires disabling KVM to use VirtualBox, - # which prevents libvirt (used by the other tests) from working after this point. 
- - name: install virtualbox for i386 test + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + - name: add hashicorp source + run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + + - name: install vagrant and virtualbox run: | - sudo apt-get install -y virtualbox + sudo apt-get update && sudo apt-get install -y vagrant virtualbox sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true - name: linux-386 - env: - VAGRANT_DEFAULT_PROVIDER: virtualbox run: make smoke-vagrant/linux-386 timeout-minutes: 30 diff --git a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile index 14ba2ce1..a3fa7ec2 100644 --- a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "generic/netbsd9" + config.vm.box = "DefinedNet/netbsd10" config.vm.synced_folder "../build", "/nebula", type: "rsync" end From c82db210ef7a31940412044b4cad0e372ea23658 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 11:30:26 -0500 Subject: [PATCH 12/27] Change windows unsafe routes to link routes, fix sshd reload bug (#1709) --- e2e/sshd_test.go | 125 +++++++++++++++++++++++++++++++++++++++++ overlay/tun_windows.go | 16 ++++-- sshd/server.go | 56 +++++++++--------- 3 files changed, 162 insertions(+), 35 deletions(-) create mode 100644 e2e/sshd_test.go diff --git a/e2e/sshd_test.go b/e2e/sshd_test.go new file mode 100644 index 00000000..e91f1bd0 --- /dev/null +++ b/e2e/sshd_test.go @@ -0,0 +1,125 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + 
"crypto/ed25519" + "crypto/rand" + "encoding/pem" + "net" + "strings" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestSSHDLifecycle(t *testing.T) { + // TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop. + ca, _, caKey, _ := cert_test.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, + time.Now(), time.Now().Add(10*time.Minute), + nil, nil, []string{}, + ) + + hostKeyPEM := generateSSHHostKey(t) + clientSigner, clientAuthKey := generateSSHClientKey(t) + sshdAddr := allocLoopbackPort(t) + + overrides := m{ + "sshd": m{ + "enabled": true, + "listen": sshdAddr, + "host_key": hostKeyPEM, + "authorized_users": []m{{ + "user": "tester", + "keys": []string{clientAuthKey}, + }}, + }, + } + control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides) + control.Start() + t.Cleanup(func() { control.Stop() }) + + // sshd binds in a goroutine after Start returns; wait for it. 
+ require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd never started listening") + + for i := 1; i <= 3; i++ { + out := sshExecReload(t, sshdAddr, clientSigner) + assert.Contains(t, out, "Reloading config", "reload cycle %d", i) + require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd not listening after reload cycle %d", i) + } + + control.Stop() + require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd still listening after Control.Stop") +} + +func canDial(addr string) bool { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err != nil { + return false + } + _ = c.Close() + return true +} + +// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There +// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the +// port available long enough for the test to bind it. 
+func allocLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func generateSSHHostKey(t *testing.T) string { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + block, err := ssh.MarshalPrivateKey(priv, "nebula-e2e-host") + require.NoError(t, err) + return string(pem.EncodeToMemory(block)) +} + +func generateSSHClientKey(t *testing.T) (ssh.Signer, string) { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + auth := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + return signer, auth +} + +func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string { + t.Helper() + cfg := &ssh.ClientConfig{ + User: "tester", + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + client, err := ssh.Dial("tcp", addr, cfg) + require.NoError(t, err) + defer client.Close() + + sess, err := client.NewSession() + require.NoError(t, err) + defer sess.Close() + + // reload tears the channel down before sending exit-status, so Output returns an error on the + // channel close. The output buffer still has whatever the reload callback wrote before that. + out, _ := sess.Output("reload") + return string(out) +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 680dddb3..14c8d499 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -156,11 +156,8 @@ func (t *winTun) addRoutes(logErrors bool) error { continue } - // Add our unsafe route - // Windows does not support multipath routes natively, so we install only a single route. - // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. 
- // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. - err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) + // Add our unsafe route as an on-link route to the nebula tun device. + err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -206,7 +203,7 @@ func (t *winTun) removeRoutes(routes []Route) error { } // See comment on luid.AddRoute - err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) + err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr)) if err != nil { t.l.Error("Failed to remove route", "error", err, "route", r) } else { @@ -261,6 +258,13 @@ func (t *winTun) Close() error { return t.tun.Close() } +func unspecifiedNextHop(p netip.Prefix) netip.Addr { + if p.Addr().Is4() { + return netip.IPv4Unspecified() + } + return netip.IPv6Unspecified() +} + func generateGUIDByDeviceName(name string) (*windows.GUID, error) { // GUID is 128 bit hash := crypto.MD5.New() diff --git a/sshd/server.go b/sshd/server.go index ff954bf5..86c52961 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -27,23 +27,20 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Call the cancel() function to stop all active sessions - ctx context.Context - cancel func() + // ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even + // across reloads, since each Run derives a fresh child rather than reusing this one directly. + ctx context.Context } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen. // The ssh server's context is parented off the supplied ctx so cancelling it // (e.g. on Control.Stop) tears down active sessions and closes the listener. 
func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) { - - ctx, cancel := context.WithCancel(ctx) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), ctx: ctx, - cancel: cancel, } cc := ssh.CertChecker{ @@ -153,45 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) { s.commands.Insert(c.Name, c) } -// Run begins listening and accepting connections +// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context +// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean +// rather than carrying a permanently-cancelled context across runs. func (s *SSHServer) Run(addr string) error { if s.ctx.Err() != nil { return s.ctx.Err() } - var err error - s.listener, err = net.Listen("tcp", addr) + listener, err := net.Listen("tcp", addr) if err != nil { return err } + // s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what + // this run owns. They start equal but a fast reload may overwrite s.listener with the next run's + // listener before this run's watcher fires, so each run must close its own listener via the local + // reference. + s.listener = listener - s.l.Info("SSH server is listening", "sshListener", addr) + runCtx, cancel := context.WithCancel(s.ctx) + defer cancel() - // Per-invocation watcher: cancellation of the parent context (e.g. - // Control.Stop) closes the listener so Accept unblocks and run returns. - // Closing `done` on exit keeps the watcher from outliving this Run call. - done := make(chan struct{}) + // Close the listener when this run's context is cancelled. That can come from the parent + // (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling + // run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate + // close from Stop is benign. 
go func() { - select { - case <-s.ctx.Done(): - s.Stop() - case <-done: + <-runCtx.Done() + if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + s.l.Warn("Failed to close the sshd listener", "error", err) } }() + s.l.Info("SSH server is listening", "sshListener", addr) + // Run loops until there is an error - s.run() - close(done) - s.closeSessions() + s.run(runCtx, listener) s.l.Info("SSH server stopped listening") // We don't return an error because run logs for us return nil } -func (s *SSHServer) run() { +func (s *SSHServer) run(ctx context.Context, listener net.Listener) { for { - c, err := s.listener.Accept() + c, err := listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { s.l.Warn("Error in listener, shutting down", "error", err) @@ -203,7 +206,7 @@ func (s *SSHServer) run() { // Ensure that a bad client doesn't hurt us by checking for the parent context // cancellation before calling NewServerConn, and forcing the socket to close when // the context is cancelled. 
- sessionContext, sessionCancel := context.WithCancel(s.ctx) + sessionContext, sessionCancel := context.WithCancel(ctx) go func() { <-sessionContext.Done() c.Close() @@ -246,14 +249,9 @@ func (s *SSHServer) run() { } func (s *SSHServer) Stop() { - // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { s.l.Warn("Failed to close the sshd listener", "error", err) } } } - -func (s *SSHServer) closeSessions() { - s.cancel() -} From 696903d6d91be3751a576779916cb7d5701140f2 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 20:17:38 -0500 Subject: [PATCH 13/27] Add a way to set the network type on windows + tests (#1710) --- .github/workflows/smoke-extra.yml | 49 +++ .github/workflows/smoke/smoke-windows.ps1 | 272 ++++++++++++++++ examples/config.yml | 26 ++ overlay/network_category_windows.go | 358 ++++++++++++++++++++ overlay/network_category_windows_test.go | 109 +++++++ overlay/tun_bypass_windows.go | 23 ++ overlay/tun_bypass_windows_386.go | 11 + overlay/tun_windows.go | 54 +++- udp/udp_android.go | 3 +- udp/udp_bsd.go | 3 +- udp/udp_bypass_windows.go | 57 ++++ udp/udp_bypass_windows_386.go | 11 + udp/udp_netbsd.go | 3 +- udp/udp_windows.go | 13 +- wfp/wfp_windows.go | 377 ++++++++++++++++++++++ 15 files changed, 1349 insertions(+), 20 deletions(-) create mode 100644 .github/workflows/smoke/smoke-windows.ps1 create mode 100644 overlay/network_category_windows.go create mode 100644 overlay/network_category_windows_test.go create mode 100644 overlay/tun_bypass_windows.go create mode 100644 overlay/tun_bypass_windows_386.go create mode 100644 udp/udp_bypass_windows.go create mode 100644 udp/udp_bypass_windows_386.go create mode 100644 wfp/wfp_windows.go diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index cca7678b..e0428e9c 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -81,3 
+81,52 @@ jobs: run: make smoke-vagrant/linux-386 timeout-minutes: 30 + + smoke-windows: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: Run windows smoke test + runs-on: windows-latest + steps: + + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + # WSL2 + Ubuntu so the smoke can run a real linux peer with its own + # netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no + # real kernel and would lack /dev/net/tun, so we have to force WSL2. + - uses: Vampire/setup-wsl@v3 + with: + distribution: Ubuntu-24.04 + additional-packages: iputils-ping iproute2 + + # Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present. + # Convert the distro to WSL2 explicitly before we try to use /dev/net/tun. + - name: convert distro to WSL2 + shell: pwsh + run: | + wsl --set-version Ubuntu-24.04 2 + wsl --shutdown + wsl --list --verbose + + - name: build windows nebula + run: make bin-windows + + - name: build linux nebula for WSL + shell: bash + env: + GOOS: linux + GOARCH: amd64 + run: | + mkdir -p build/linux-amd64 + go build -o build/linux-amd64/nebula ./cmd/nebula + + - name: run smoke-windows + shell: pwsh + working-directory: ./.github/workflows/smoke + run: ./smoke-windows.ps1 + + timeout-minutes: 15 diff --git a/.github/workflows/smoke/smoke-windows.ps1 b/.github/workflows/smoke/smoke-windows.ps1 new file mode 100644 index 00000000..0436598d --- /dev/null +++ b/.github/workflows/smoke/smoke-windows.ps1 @@ -0,0 +1,272 @@ +#!/usr/bin/env pwsh +# Windows smoke test for the nebula tun + UDP + NLM code paths. 
+# +# Topology: +# - lighthouse runs natively on the Windows host (wintun + windows UDP) +# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun) +# +# WSL2 gives us a real netns boundary so the loopback fast-path on Windows +# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP, +# Linux has no idea that IP is local to the Windows host, so the packet is +# forced through nebula. Same in reverse. + +$ErrorActionPreference = 'Stop' + +# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling +# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead. +$env:WSL_UTF8 = '1' + +$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.." +$Nebula = Join-Path $RepoRoot 'nebula.exe' +$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe' +$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula' + +if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" } +if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" } +if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" } + +# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml. +$Distro = 'Ubuntu-24.04' +$listed = (wsl --list --quiet 2>$null) -join "`n" +if ($listed -notmatch [regex]::Escape($Distro)) { + throw "WSL distro $Distro not registered. Got: $listed" +} +Write-Host "Using WSL distro: $Distro" + +# Windows host as seen from inside WSL: WSL's default-route gateway. We extract +# it with a regex rather than awk fields so PowerShell does not eat any '$N' +# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut. 
+$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1' +$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim() +if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" } +Write-Host "Windows host IP from WSL: $WindowsIp" + +$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows' +if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir } +New-Item -ItemType Directory -Path $WorkDir | Out-Null + +$WslDir = '/tmp/nebula-smoke' +wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null + +$DevName = 'nebula-smoke' +$Ip1 = '192.168.241.1' +$Ip2 = '192.168.241.2' +$Port = 4242 + +& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" } + +& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" } + +& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" } + +# Windows lighthouse config. 
+@" +pki: + ca: $WorkDir\ca.crt + cert: $WorkDir\lighthouse.crt + key: $WorkDir\lighthouse.key +static_host_map: {} +lighthouse: + am_lighthouse: true + interval: 60 + hosts: [] +listen: + host: 0.0.0.0 + port: $Port +tun: + disabled: false + dev: $DevName + drop_local_broadcast: false + drop_multicast: false + tx_queue: 500 + mtu: 1300 + network_category: private +logging: + level: info + format: text +firewall: + outbound_action: drop + inbound_action: drop + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8 + +# WSL peer config (paths are POSIX, deliberately). +@" +pki: + ca: $WslDir/ca.crt + cert: $WslDir/peer.crt + key: $WslDir/peer.key +static_host_map: + "${Ip1}": ["${WindowsIp}:$Port"] +lighthouse: + am_lighthouse: false + interval: 60 + hosts: + - "${Ip1}" +listen: + host: 0.0.0.0 + port: 0 +tun: + disabled: false + dev: nebula1 + drop_local_broadcast: false + drop_multicast: false + tx_queue: 500 + mtu: 1300 +logging: + level: info + format: text +firewall: + outbound_action: drop + inbound_action: drop + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8 + +# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than +# calling `wslpath`, because PowerShell's argument-passing to external EXEs +# strips backslashes from path arguments in ways that are hard to escape around. 
+function ConvertTo-WslPath { + param([string]$WindowsPath) + if ($WindowsPath -notmatch '^([A-Za-z]):\\(.*)$') { + throw "cannot convert path to WSL: $WindowsPath" + } + return "/mnt/$($matches[1].ToLower())/$($matches[2].Replace('\','/'))" +} + +$WslWorkDir = ConvertTo-WslPath $WorkDir +$WslNebulaPath = ConvertTo-WslPath $NebulaLinux +wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula" + +# Make sure WSL has tun support and /dev/net/tun is usable before starting +# nebula. Diagnostics first so a fail here points at the real problem (e.g. +# WSL1 distros do not have a real kernel and will not have tun). +Write-Host '=== WSL diagnostic ===' +wsl --version 2>&1 | Out-Host +wsl --list --verbose 2>&1 | Out-Host +wsl -d $Distro -u root -- uname -a | Out-Host +wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun" +if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" } + +# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf +# feature is supposed to install WFP permit filters that let inbound traffic +# through Windows Defender Firewall on its own. If this smoke regresses, that +# feature regressed. + +$lhOut = Join-Path $WorkDir 'lighthouse.out.log' +$lhErr = Join-Path $WorkDir 'lighthouse.err.log' +$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") ` + -PassThru -NoNewWindow ` + -RedirectStandardOutput $lhOut ` + -RedirectStandardError $lhErr + +# Run nebula in WSL as root with no sudo + no shell wrapper. 
PowerShell's +# Start-Process arg quoting mangles `bash -c "..."` strings that contain +# spaces/redirections, so we skip bash entirely and let Start-Process do the +# stdout/stderr capture itself. +$peerOut = Join-Path $WorkDir 'peer.out.log' +$peerErr = Join-Path $WorkDir 'peer.err.log' +$peerProc = Start-Process -FilePath 'wsl' ` + -ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") ` + -PassThru -NoNewWindow ` + -RedirectStandardOutput $peerOut ` + -RedirectStandardError $peerErr + +function Wait-Until { + param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What) + $deadline = (Get-Date).AddSeconds($TimeoutSec) + while ((Get-Date) -lt $deadline) { + if (& $Predicate) { return } + Start-Sleep -Milliseconds 500 + } + throw "timed out waiting for: $What" +} + +try { + Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate { + if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" } + $p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue + $p -and ("$($p.NetworkCategory)" -ieq 'Private') + } + Write-Host "OK: $DevName NetworkCategory=Private" + + Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate { + if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" } + $r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes" + ("$r").Trim() -eq 'yes' + } + Write-Host "OK: WSL nebula1 has $Ip2" + + Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate { + if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" } + $r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK" + ("$r").Trim() -eq 'OK' + } + Write-Host "OK: WSL peer -> windows lighthouse" + + Wait-Until -TimeoutSec 30 -What 
"ping from windows lighthouse to WSL peer ($Ip2)" -Predicate { + $null = & ping.exe -n 1 -w 1000 $Ip2 + $LASTEXITCODE -eq 0 + } + Write-Host "OK: windows lighthouse -> WSL peer" + + Write-Host '' + Write-Host 'All smoke checks passed.' +} +catch { + Write-Host '' + Write-Host '=== lighthouse stdout ===' + Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== lighthouse stderr ===' + Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== peer stdout ===' + Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== peer stderr ===' + Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== nebula WFP filters ===' + # Dump nebula-installed filters so we can verify they got registered with + # the conditions we expect. + $wfpDump = Join-Path $WorkDir 'wfp.xml' + netsh wfp show filters file=$wfpDump 2>&1 | Out-Null + if (Test-Path $wfpDump) { + Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host + } + throw +} +finally { + if (-not $lhProc.HasExited) { + Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue + $lhProc.WaitForExit(5000) | Out-Null + } + wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null + # pkill returns 1 when no match and wsl propagates that; the smoke is done + # so we don't want it to leak into the script's exit code. + $global:LASTEXITCODE = 0 + if ($peerProc -and -not $peerProc.HasExited) { + Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue + } +} diff --git a/examples/config.yml b/examples/config.yml index ac4810e6..6c7fb489 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -138,6 +138,14 @@ listen: # max, net.core.rmem_max and net.core.wmem_max #read_buffer: 10485760 #write_buffer: 10485760 + + # On Windows only + # When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port. 
+ # WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless + # of WDF's inbound rules. + # Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable. + #windows_bypass_wdf: true + # By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection # in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running # on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes. @@ -286,6 +294,24 @@ tun: # metric: 100 # install: true + # On Windows only, sets the network category of the nebula interface. Without this, Windows often + # leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more + # restrictive than you usually want for an overlay between trusted peers. Valid values: + # private - treat the nebula network as a private/trusted network (default) + # public - treat it as a public/untrusted network + # domain - treat it as a domain-authenticated network + # unset - leave whatever Windows decided alone + # Not reloadable. + #network_category: private + + # On Windows only + # When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID. + # WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules. + # Filters are auto-removed when the adapter goes away. + # See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener. + # Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable. 
+ #windows_bypass_wdf: true + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false diff --git a/overlay/network_category_windows.go b/overlay/network_category_windows.go new file mode 100644 index 00000000..cbf87f00 --- /dev/null +++ b/overlay/network_category_windows.go @@ -0,0 +1,358 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "errors" + "fmt" + "log/slog" + "runtime" + "strings" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h. +type networkCategory int32 + +const ( + networkCategoryPublic networkCategory = 0 + networkCategoryPrivate networkCategory = 1 + networkCategoryDomainAuthenticated networkCategory = 2 +) + +func (c networkCategory) String() string { + switch c { + case networkCategoryPublic: + return "public" + case networkCategoryPrivate: + return "private" + case networkCategoryDomainAuthenticated: + return "domain" + } + return fmt.Sprintf("unknown(%d)", c) +} + +// parseNetworkCategory accepts the user-supplied tun.network_category. A +// second return of false means "leave the category alone". 
+func parseNetworkCategory(s string) (networkCategory, bool, error) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "", "unset": + return 0, false, nil + case "public": + return networkCategoryPublic, true, nil + case "private": + return networkCategoryPrivate, true, nil + case "domain", "domainauthenticated": + return networkCategoryDomainAuthenticated, true, nil + } + return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s) +} + +// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B} +var clsidNetworkListManager = windows.GUID{ + Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B} +var iidINetworkListManager = windows.GUID{ + Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves. +var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance") + +const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER | + windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER + +const ( + hrSFALSE = 0x00000001 + hrRPCEChangedMode = 0x80010106 +) + +type hresult uint32 + +func (h hresult) failed() bool { return int32(h) < 0 } +func (h hresult) String() string { + return fmt.Sprintf("HRESULT 0x%08x", uint32(h)) +} + +var errAdapterNotFound = errors.New("adapter not present in network connections enumeration") + +// Vtable layouts. Slot order must match the declaration order in netlistmgr.h. +// All NLM interfaces here derive from IDispatch, which derives from IUnknown. 

// iUnknownVtbl holds the three IUnknown slots that open every COM vtable.
// Each field is a raw function address that is invoked via syscall.SyscallN.
type iUnknownVtbl struct {
	QueryInterface uintptr
	AddRef         uintptr
	Release        uintptr
}

// iDispatchVtbl is IUnknown followed by the four IDispatch slots. None of the
// IDispatch slots are called in this file; they are declared so the
// interface-specific slots that follow land at the right offsets.
type iDispatchVtbl struct {
	iUnknownVtbl
	GetTypeInfoCount uintptr
	GetTypeInfo      uintptr
	GetIDsOfNames    uintptr
	Invoke           uintptr
}

// iNetworkListManagerVtbl is the INetworkListManager vtable. Only
// GetNetworkConnections (and the inherited Release) are called below.
type iNetworkListManagerVtbl struct {
	iDispatchVtbl
	GetNetworks           uintptr
	GetNetwork            uintptr
	GetNetworkConnections uintptr
	GetNetworkConnection  uintptr
	IsConnectedToInternet uintptr
	IsConnected           uintptr
	GetConnectivity       uintptr
}

// iNetworkListManager overlays a COM object: its only field is the vtable
// pointer, and the address of the struct itself is passed as the first
// argument of every method call.
type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl }

// Release invokes the IUnknown Release slot. The call's return value is
// discarded.
func (n *iNetworkListManager) Release() {
	syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n)))
}

// GetNetworkConnections calls the GetNetworkConnections slot and returns the
// enumerator written to the out-parameter. A failing HRESULT is returned as
// an error and no enumerator is returned.
func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) {
	var enum *iEnumNetworkConnections
	// The pointer-to-uintptr conversions stay inside the call expression so
	// the referenced objects are kept valid for the duration of the syscall.
	r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections,
		uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)),
	)
	if hr := hresult(r1); hr.failed() {
		return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr)
	}
	return enum, nil
}

// iEnumNetworkConnectionsVtbl is the IEnumNetworkConnections vtable. Only
// Next (and the inherited Release) are called in this file.
type iEnumNetworkConnectionsVtbl struct {
	iDispatchVtbl
	NewEnum uintptr
	Next    uintptr
	Skip    uintptr
	Reset   uintptr
	Clone   uintptr
}

// iEnumNetworkConnections overlays the enumerator COM object (vtable pointer
// only).
type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl }

// Release invokes the IUnknown Release slot on the enumerator. The call's
// return value is discarded.
func (e *iEnumNetworkConnections) Release() {
	syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e)))
}

// Next returns the next connection, or (nil, nil) at the end of the enumeration.
+func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) { + var conn *iNetworkConnection + var fetched uint32 + r1, _, _ := syscall.SyscallN(e.Vtbl.Next, + uintptr(unsafe.Pointer(e)), 1, + uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr) + } + if fetched == 0 { + return nil, nil + } + return conn, nil +} + +type iNetworkConnectionVtbl struct { + iDispatchVtbl + GetNetwork uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetConnectionId uintptr + GetAdapterId uintptr + GetDomainType uintptr +} + +type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl } + +func (c *iNetworkConnection) Release() { + syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c))) +} + +func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) { + var g windows.GUID + r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)), + ) + if hr := hresult(r1); hr.failed() { + return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr) + } + return g, nil +} + +func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) { + var net *iNetwork + r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr) + } + return net, nil +} + +type iNetworkVtbl struct { + iDispatchVtbl + GetName uintptr + SetName uintptr + GetDescription uintptr + SetDescription uintptr + GetNetworkId uintptr + GetDomainType uintptr + GetNetworkConnections uintptr + GetTimeCreatedAndConnected uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetCategory uintptr + SetCategory uintptr +} + +type iNetwork struct{ Vtbl *iNetworkVtbl } + +func (n *iNetwork) Release() { + 
syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n))) +} + +func (n *iNetwork) GetCategory() (networkCategory, error) { + var c networkCategory + r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory, + uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)), + ) + if hr := hresult(r1); hr.failed() { + return 0, fmt.Errorf("INetwork.GetCategory: %s", hr) + } + return c, nil +} + +func (n *iNetwork) SetCategory(c networkCategory) error { + r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory, + uintptr(unsafe.Pointer(n)), uintptr(int32(c)), + ) + if hr := hresult(r1); hr.failed() { + return fmt.Errorf("INetwork.SetCategory: %s", hr) + } + return nil +} + +// coInit initializes COM for the current OS thread. The returned function must +// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is +// already initialized in a different mode on this thread, which is still fine +// for our calls but we must not Uninitialize in that case. +func coInit() (func(), error) { + err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED) + if err == nil { + return windows.CoUninitialize, nil + } + if e, ok := err.(syscall.Errno); ok { + switch uint32(e) { + case hrSFALSE: + return windows.CoUninitialize, nil + case hrRPCEChangedMode: + return func() {}, nil + } + } + return nil, fmt.Errorf("CoInitializeEx: %w", err) +} + +func createNetworkListManager() (*iNetworkListManager, error) { + var nlm *iNetworkListManager + r1, _, _ := procCoCreateInstance.Call( + uintptr(unsafe.Pointer(&clsidNetworkListManager)), + 0, + uintptr(clsCtxAll), + uintptr(unsafe.Pointer(&iidINetworkListManager)), + uintptr(unsafe.Pointer(&nlm)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr) + } + return nlm, nil +} + +// setNetworkCategory locates the network connection bound to adapterGUID and +// sets the category of its parent network. 
Returns errAdapterNotFound if the +// adapter is not yet visible in the NLM enumeration. +func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error { + deinit, err := coInit() + if err != nil { + return err + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + return err + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + return err + } + defer enum.Release() + + for { + conn, err := enum.Next() + if err != nil { + return err + } + if conn == nil { + return errAdapterNotFound + } + + guid, err := conn.GetAdapterId() + if err != nil || guid != adapterGUID { + conn.Release() + continue + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + return err + } + err = net.SetCategory(cat) + net.Release() + return err + } +} + +// applyNetworkCategory polls until the wintun adapter shows up in the NLM +// enumeration, then sets the category. Intended to run in its own goroutine. +func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) { + // COM Init/Uninit must be paired on the same OS thread. 
+ runtime.LockOSThread() + defer runtime.UnlockOSThread() + + const ( + attempts = 30 + interval = 500 * time.Millisecond + ) + for i := 0; i < attempts; i++ { + err := setNetworkCategory(adapterGUID, cat) + if err == nil { + l.Info("Set Windows network category", "category", cat.String()) + return + } + if !errors.Is(err, errAdapterNotFound) { + l.Warn("Failed to set Windows network category", "error", err, "category", cat.String()) + return + } + time.Sleep(interval) + } + l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set", + "category", cat.String(), + "waited", time.Duration(attempts)*interval, + ) +} diff --git a/overlay/network_category_windows_test.go b/overlay/network_category_windows_test.go new file mode 100644 index 00000000..c679f8c4 --- /dev/null +++ b/overlay/network_category_windows_test.go @@ -0,0 +1,109 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "testing" +) + +func Test_parseNetworkCategory(t *testing.T) { + cases := []struct { + in string + wantCat networkCategory + wantApply bool + wantErr bool + }{ + {"", 0, false, false}, + {"unset", 0, false, false}, + {" UNSET ", 0, false, false}, + {"private", networkCategoryPrivate, true, false}, + {"Private", networkCategoryPrivate, true, false}, + {" PRIVATE ", networkCategoryPrivate, true, false}, + {"public", networkCategoryPublic, true, false}, + {"PUBLIC", networkCategoryPublic, true, false}, + {"domain", networkCategoryDomainAuthenticated, true, false}, + {"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false}, + {"garbage", 0, false, true}, + {"privates", 0, false, true}, + } + for _, tc := range cases { + cat, apply, err := parseNetworkCategory(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr) + continue + } + if cat != tc.wantCat || apply != tc.wantApply { + t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, 
cat, apply, tc.wantCat, tc.wantApply) + } + } +} + +// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory +// without mutating the host's network state. It validates the CLSID/IID +// constants and every vtable index by enumerating connections, fetching the +// adapter id and parent network, reading the current category, and writing it +// back unchanged. +// +// Requires Windows but does not require admin or the wintun driver. Skips if +// no network connections are available (unlikely outside of an isolated +// container). +func Test_NLM_round_trip(t *testing.T) { + deinit, err := coInit() + if err != nil { + t.Fatalf("coInit: %v", err) + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + t.Fatalf("createNetworkListManager: %v", err) + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + t.Fatalf("GetNetworkConnections: %v", err) + } + defer enum.Release() + + saw := 0 + for { + conn, err := enum.Next() + if err != nil { + t.Fatalf("EnumNetworkConnections.Next: %v", err) + } + if conn == nil { + break + } + saw++ + + if _, err := conn.GetAdapterId(); err != nil { + conn.Release() + t.Fatalf("INetworkConnection.GetAdapterId: %v", err) + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + t.Fatalf("INetworkConnection.GetNetwork: %v", err) + } + + cat, err := net.GetCategory() + if err != nil { + net.Release() + t.Fatalf("INetwork.GetCategory: %v", err) + } + // Set to the current value so the host's NLM state is unchanged but + // SetCategory's vtable slot is still validated end-to-end. 
+ if err := net.SetCategory(cat); err != nil { + net.Release() + t.Fatalf("INetwork.SetCategory(%v): %v", cat, err) + } + net.Release() + } + + if saw == 0 { + t.Skip("no NLM network connections available; skipping round-trip") + } +} diff --git a/overlay/tun_bypass_windows.go b/overlay/tun_bypass_windows.go new file mode 100644 index 00000000..1f62373c --- /dev/null +++ b/overlay/tun_bypass_windows.go @@ -0,0 +1,23 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +package overlay + +import ( + "log/slog" + + "github.com/slackhq/nebula/wfp" +) + +// installInterfaceBypass installs a WFP PERMIT filter scoped to the wintun interface LUID so inbound traffic on the +// nebula adapter bypasses Windows Defender Firewall. +func installInterfaceBypass(l *slog.Logger, luid uint64) closer { + s, err := wfp.PermitInterface(luid) + if err != nil { + l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err) + return nil + } + l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface") + return s +} diff --git a/overlay/tun_bypass_windows_386.go b/overlay/tun_bypass_windows_386.go new file mode 100644 index 00000000..366430b0 --- /dev/null +++ b/overlay/tun_bypass_windows_386.go @@ -0,0 +1,11 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import "log/slog" + +// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it. 
+func installInterfaceBypass(_ *slog.Logger, _ uint64) closer { + return nil +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 14c8d499..cf01615f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -25,15 +25,24 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) +type closer interface { + Close() +} + const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *slog.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + guid windows.GUID + networkCategory networkCategory + setCategory bool + bypassWDF bool + wdfBypass closer + l *slog.Logger tun *wintun.NativeTun } @@ -54,11 +63,20 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w return nil, fmt.Errorf("generate GUID failed: %w", err) } + cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private")) + if err != nil { + return nil, err + } + t := &winTun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + guid: *guid, + networkCategory: cat, + setCategory: setCat, + bypassWDF: c.GetBool("tun.windows_bypass_wdf", true), + l: l, } err = t.reload(c, true) @@ -142,6 +160,17 @@ func (t *winTun) Activate() error { return err } + if t.setCategory { + // The wintun adapter takes a moment to register with the Network List + // Manager, so we apply the category in the background and retry until + // it shows up. 
+ go applyNetworkCategory(t.l, t.guid, t.networkCategory) + } + + if t.bypassWDF { + t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID())) + } + return nil } @@ -255,6 +284,11 @@ func (t *winTun) Close() error { _ = luid.FlushDNS(windows.AF_INET) _ = luid.FlushDNS(windows.AF_INET6) + if t.wdfBypass != nil { + t.wdfBypass.Close() + t.wdfBypass = nil + } + return t.tun.Close() } diff --git a/udp/udp_android.go b/udp/udp_android.go index 3fc68003..213ab422 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -5,12 +5,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index c42a3c18..31ae9c5a 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -8,12 +8,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_bypass_windows.go b/udp/udp_bypass_windows.go new file mode 100644 index 00000000..b8b06b1e --- /dev/null +++ b/udp/udp_bypass_windows.go @@ -0,0 +1,57 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +package udp + +import ( + "log/slog" + "sync" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/wfp" +) + +// wrapWithWDFBypass wraps a Conn so that the first ReloadConfig consults listen.windows_bypass_wdf +// and installs a WFP PERMIT filter for the listener's bound UDP port. The session is released when Close runs. 
+func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn { + return &bypassConn{Conn: conn, l: l} +} + +type bypassConn struct { + Conn + + l *slog.Logger + installOnce sync.Once + session *wfp.Session +} + +func (b *bypassConn) ReloadConfig(c *config.C) { + b.installOnce.Do(func() { + if !c.GetBool("listen.windows_bypass_wdf", true) { + return + } + addr, err := b.Conn.LocalAddr() + if err != nil { + b.l.Warn("Failed to query listener port for WFP bypass", "error", err) + return + } + s, err := wfp.PermitUDPPort(addr.Port()) + if err != nil { + b.l.Warn("Failed to install WFP bypass filters for listener", "error", err) + return + } + b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port", + "port", addr.Port()) + b.session = s + }) + b.Conn.ReloadConfig(c) +} + +func (b *bypassConn) Close() error { + if b.session != nil { + b.session.Close() + b.session = nil + } + return b.Conn.Close() +} diff --git a/udp/udp_bypass_windows_386.go b/udp/udp_bypass_windows_386.go new file mode 100644 index 00000000..fa5a6eec --- /dev/null +++ b/udp/udp_bypass_windows_386.go @@ -0,0 +1,11 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package udp + +import "log/slog" + +// wrapWithWDFBypass is a no-op on windows-386 since we don't currently build for it. 
+func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn { + return conn +} diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 4b2de75a..b0c81393 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -7,12 +7,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 7969f7e8..1f34f0bc 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) return nil, fmt.Errorf("multiple udp listeners not supported on windows") } + var conn Conn rc, err := NewRIOListener(l, ip, port) if err == nil { - return rc, nil + conn = rc + } else { + l.Error("Falling back to standard udp sockets", "error", err) + conn, err = NewGenericListener(l, ip, port, multi, batch) + if err != nil { + return nil, err + } } - - l.Error("Falling back to standard udp sockets", "error", err) - return NewGenericListener(l, ip, port, multi, batch) + return wrapWithWDFBypass(l, conn), nil } func NewListenConfig(multi bool) net.ListenConfig { diff --git a/wfp/wfp_windows.go b/wfp/wfp_windows.go new file mode 100644 index 00000000..22aa0565 --- /dev/null +++ b/wfp/wfp_windows.go @@ -0,0 +1,377 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer. +// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets +// the matching inbound traffic through regardless of WDF rules. +// +// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session +// is auto-deleted by Windows, so there are no orphaned filters. +// +// Type definitions and constants are derived from the wireguard-windows firewall package (MIT). 
+// Only the subset we exercise is reproduced. +package wfp + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +// FWPM layer GUIDs (fwpmu.h). +// +// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650 +// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f +var ( + fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{ + Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273, + Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50}, + } + fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{ + Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672, + Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f}, + } +) + +// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4 +var fwpmConditionIPLocalInterface = windows.GUID{ + Data1: 0x4cd62a49, Data2: 0x59c3, Data3: 0x4969, + Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4}, +} + +// FWPM_CONDITION_IP_PROTOCOL = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7 +var fwpmConditionIPProtocol = windows.GUID{ + Data1: 0x3971ef2b, Data2: 0x623e, Data3: 0x4f9a, + Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7}, +} + +// FWPM_CONDITION_IP_LOCAL_PORT = 0c1ba1af-5765-453f-af22-a8f791ac775b +var fwpmConditionIPLocalPort = windows.GUID{ + Data1: 0x0c1ba1af, Data2: 0x5765, Data3: 0x453f, + Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b}, +} + +// IPPROTO_UDP from in.h. +const ipprotoUDP uint8 = 17 + +// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating. +const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000) + +// FWP_DATA_TYPE values we use. +const ( + fwpEmpty uint32 = 0 + fwpUint8 uint32 = 1 + fwpUint16 uint32 = 2 + fwpUint64 uint32 = 4 +) + +// FWP_MATCH_TYPE values. +const fwpMatchEqual uint32 = 0 + +// FWPM_SESSION flags. 
const fwpmSessionFlagDynamic uint32 = 0x1

// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers -- notably Windows Defender
// Firewall's MPSSVC_WF sublayer, which shares our 0xFFFF weight -- from overriding this PERMIT.
// Without it, a default WDF block at the same sublayer weight can still win arbitration.
const fwpmFilterFlagClearActionRight uint32 = 0x8

// RPC authentication.
// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context, whereas
// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box.
const rpcCAuthnWinNT uint32 = 10

// fwpByteBlob (FWP_BYTE_BLOB). 16 bytes on 64-bit.
type fwpByteBlob struct {
	size uint32
	_    uint32 // padding
	data *uint8
}

// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit.
// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes
// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the
// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass
// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline.
//
// NOTE(review): `value` is a uintptr, so the Go runtime does not treat a stored address as a reference. Whatever it
// points at must stay alive and at the same address until the WFP call returns -- confirm that each caller
// populating a by-pointer value satisfies the unsafe.Pointer conversion rules.
type fwpValue0 struct {
	type_ uint32
	_     uint32 // padding before union to 8-byte alignment
	value uintptr
}

// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit.
// Both fields are pointers to UTF-16 strings.
type fwpmDisplayData0 struct {
	name        *uint16
	description *uint16
}

// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType
// is uint32 and GUID's first field is uint32.
type fwpmAction0 struct {
	actionType uint32
	filterType windows.GUID
}

// fwpmFilterCondition0. 40 bytes on 64-bit.
// A single match clause: the field to test, how to compare, and the value to compare against.
type fwpmFilterCondition0 struct {
	fieldKey       windows.GUID // 16
	matchType      uint32       // 4
	_              uint32       // 4 padding
	conditionValue fwpValue0    // 16
}

// fwpmFilter0. 200 bytes on 64-bit.
type fwpmFilter0 struct {
	filterKey           windows.GUID
	displayData         fwpmDisplayData0
	flags               uint32
	_                   uint32 // padding before *GUID
	providerKey         *windows.GUID
	providerData        fwpByteBlob
	layerKey            windows.GUID
	subLayerKey         windows.GUID
	weight              fwpValue0
	numFilterConditions uint32
	_                   uint32 // padding before pointer
	filterCondition     *fwpmFilterCondition0 // first element of an array of numFilterConditions entries
	action              fwpmAction0
	// action ends at offset 148; these 4 bytes push providerContextKey to offset 152.
	// NOTE(review): presumably mirrors 8-byte alignment of the C union at this position -- confirm against fwpmtypes.h.
	_                  [4]byte // layout correction
	providerContextKey windows.GUID
	reserved           *windows.GUID
	filterID           uint64
	effectiveWeight    fwpValue0
}

// fwpmSublayer0. 72 bytes on 64-bit.
type fwpmSublayer0 struct {
	subLayerKey  windows.GUID
	displayData  fwpmDisplayData0
	flags        uint32
	_            uint32 // padding before *GUID
	providerKey  *windows.GUID
	providerData fwpByteBlob
	weight       uint16
	_            [6]byte // padding to 72 bytes
}

// fwpmSession0. 72 bytes on 64-bit.
// Only `flags` is populated by the code in this file; every other field is left zero.
type fwpmSession0 struct {
	sessionKey           windows.GUID
	displayData          fwpmDisplayData0
	flags                uint32
	txnWaitTimeoutInMSec uint32
	processId            uint32
	_                    uint32 // padding before *SID
	sid                  *windows.SID
	username             *uint16
	kernelMode           uint8
	_                    [7]byte // tail padding
}

// fwpuclnt.dll bindings. Only the calls we use.
// The DLL and procs are resolved lazily, on first call.
var (
	modFwpuclnt          = windows.NewLazySystemDLL("fwpuclnt.dll")
	procFwpmEngineOpen0  = modFwpuclnt.NewProc("FwpmEngineOpen0")
	procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0")
	procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0")
	procFwpmFilterAdd0   = modFwpuclnt.NewProc("FwpmFilterAdd0")
)

// Session holds the WFP engine handle for a single bypass operation. The handle owns a dynamic session:
// when it is closed, every WFP object added during the session (sublayer + filters) is automatically deleted by
// Windows. That gives us correct cleanup even if the host process is killed hard between Permit* and Close.
type Session struct {
	engine uintptr // handle written by FwpmEngineOpen0; 0 once closed
}

// Close releases the engine handle. Windows deletes every dynamic object (sublayer + filters) the session installed.
+// Safe to call on a nil receiver. +func (s *Session) Close() { + if s == nil || s.engine == 0 { + return + } + procFwpmEngineClose0.Call(s.engine) + s.engine = 0 +} + +// PermitInterface installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to the given network +// interface LUID. Inbound traffic on that interface bypasses Windows Defender Firewall. +func PermitInterface(luid uint64) (*Session, error) { + s, sublayerKey, err := newSession() + if err != nil { + return nil, err + } + + if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, luid); err != nil { + s.Close() + return nil, fmt.Errorf("add v4 filter: %w", err) + } + if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, luid); err != nil { + s.Close() + return nil, fmt.Errorf("add v6 filter: %w", err) + } + return s, nil +} + +// PermitUDPPort installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to UDP traffic with the +// given local port. Inbound UDP to that port on any interface bypasses Windows Defender Firewall. 
+func PermitUDPPort(port uint16) (*Session, error) { + s, sublayerKey, err := newSession() + if err != nil { + return nil, err + } + + if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, port); err != nil { + s.Close() + return nil, fmt.Errorf("add v4 filter: %w", err) + } + if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, port); err != nil { + s.Close() + return nil, fmt.Errorf("add v6 filter: %w", err) + } + return s, nil +} + +func newSession() (*Session, windows.GUID, error) { + engine, err := openDynamicEngine() + if err != nil { + return nil, windows.GUID{}, err + } + sublayerKey, err := registerSublayer(engine) + if err != nil { + procFwpmEngineClose0.Call(engine) + return nil, windows.GUID{}, err + } + return &Session{engine: engine}, sublayerKey, nil +} + +func openDynamicEngine() (uintptr, error) { + session := fwpmSession0{flags: fwpmSessionFlagDynamic} + var engine uintptr + r1, _, _ := procFwpmEngineOpen0.Call( + 0, // serverName == NULL (local) + uintptr(rpcCAuthnWinNT), + 0, // authIdentity == NULL + uintptr(unsafe.Pointer(&session)), + uintptr(unsafe.Pointer(&engine)), + ) + if r1 != 0 { + return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", r1) + } + return engine, nil +} + +// registerSublayer adds a session-scoped sublayer with a freshly generated GUID, weight 0xFFFF so its filters arbitrate +// above WDF's default sublayer. The sublayer is dynamic (no PERSISTENT flag) and goes away when the engine handle closes. 
+func registerSublayer(engine uintptr) (windows.GUID, error) { + key, err := windows.GenerateGUID() + if err != nil { + return windows.GUID{}, fmt.Errorf("GenerateGUID for sublayer: %w", err) + } + + name, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer") + desc, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall") + sl := fwpmSublayer0{ + subLayerKey: key, + displayData: fwpmDisplayData0{name: name, description: desc}, + weight: 0xFFFF, + } + r1, _, _ := procFwpmSubLayerAdd0.Call( + engine, + uintptr(unsafe.Pointer(&sl)), + 0, // sd == NULL + ) + if r1 != 0 { + return windows.GUID{}, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", r1) + } + return key, nil +} + +func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error { + name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound") + desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface") + + // luid must remain addressable through the syscall -- FWP_UINT64 is stored + // by pointer in the FWP_VALUE0 union. + cond := fwpmFilterCondition0{ + fieldKey: fwpmConditionIPLocalInterface, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint64, + value: uintptr(unsafe.Pointer(&luid)), + }, + } + + filter := fwpmFilter0{ + // filterKey left zero: WFP assigns one when the filter is added. 
+ displayData: fwpmDisplayData0{name: name, description: desc}, + flags: fwpmFilterFlagClearActionRight, + layerKey: layer, + subLayerKey: sublayerKey, + weight: fwpValue0{type_: fwpUint8, value: uintptr(15)}, + numFilterConditions: 1, + filterCondition: &cond, + action: fwpmAction0{actionType: fwpActionPermit}, + } + + r1, _, _ := procFwpmFilterAdd0.Call( + engine, + uintptr(unsafe.Pointer(&filter)), + 0, // sd == NULL + 0, // id == NULL + ) + if r1 != 0 { + return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1) + } + return nil +} + +// addUDPPortFilter installs a PERMIT filter that matches (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port). +// FWP_UINT8 and FWP_UINT16 are <= 32 bits so they live inline in the FWP_VALUE0 union. +func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error { + name, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound") + desc, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port") + + conds := [2]fwpmFilterCondition0{ + { + fieldKey: fwpmConditionIPProtocol, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint8, + value: uintptr(ipprotoUDP), + }, + }, + { + fieldKey: fwpmConditionIPLocalPort, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint16, + value: uintptr(port), + }, + }, + } + + filter := fwpmFilter0{ + displayData: fwpmDisplayData0{name: name, description: desc}, + flags: fwpmFilterFlagClearActionRight, + layerKey: layer, + subLayerKey: sublayerKey, + weight: fwpValue0{type_: fwpUint8, value: uintptr(15)}, + numFilterConditions: 2, + filterCondition: &conds[0], + action: fwpmAction0{actionType: fwpActionPermit}, + } + + r1, _, _ := procFwpmFilterAdd0.Call( + engine, + uintptr(unsafe.Pointer(&filter)), + 0, + 0, + ) + if r1 != 0 { + return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1) + } + return nil +} From 398d67e2da34573801545c9403d86d8460a2c8a5 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 8 May 2026 14:43:19 -0500 
Subject: [PATCH 14/27] Windows code signing (#1718) --- .github/actions/code-sign/action.yml | 113 +++++++++++++++++++++++++++ .github/workflows/release.yml | 10 +++ 2 files changed, 123 insertions(+) create mode 100644 .github/actions/code-sign/action.yml diff --git a/.github/actions/code-sign/action.yml b/.github/actions/code-sign/action.yml new file mode 100644 index 00000000..bfa1a9ec --- /dev/null +++ b/.github/actions/code-sign/action.yml @@ -0,0 +1,113 @@ +name: Code-sign Windows binaries +description: > + Sign every .exe under a given path in place via the DefinedNet code-signer + Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so + forks and dev branches without AWS access still produce usable builds. + +inputs: + path: + description: "Directory whose .exe files should be signed in place" + required: true + role: + description: "IAM role ARN to assume via OIDC; empty disables signing" + required: false + default: "" + bucket: + description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing" + required: false + default: "" + region: + description: "AWS region for the role and Lambda" + required: false + default: "us-east-2" + function-name: + description: "Code-signer Lambda function name" + required: false + default: "code-signer" + key-prefix: + description: "S3 key prefix the caller is authorized to write under" + required: false + default: "code-signing/slackhq/nebula" + +runs: + using: composite + steps: + - name: Skip notice + if: inputs.role == '' || inputs.bucket == '' + shell: sh + run: echo "::notice::code-signer role or bucket not set; skipping code signing." 
+ + - name: Configure AWS credentials + if: inputs.role != '' && inputs.bucket != '' + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ inputs.role }} + aws-region: ${{ inputs.region }} + # Default is 12 retries to ride out IAM trust-policy propagation; once + # the role is stable we want a real misconfiguration to fail fast. + retry-max-attempts: 5 + + - name: Sign .exe files + if: inputs.role != '' && inputs.bucket != '' + shell: sh + env: + SIGN_PATH: ${{ inputs.path }} + BUCKET: ${{ inputs.bucket }} + FUNCTION_NAME: ${{ inputs.function-name }} + KEY_PREFIX: ${{ inputs.key-prefix }} + run: | + set -eu + RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}" + + find "$SIGN_PATH" -name '*.exe' -print | while read -r path + do + rel=${path#"$SIGN_PATH"/} + file=$(basename "$path") + name=${file%.exe} + prefix="${KEY_PREFIX}/${RUN}" + src="${prefix}/unsigned/${rel}" + dst="${prefix}/signed/${rel}" + + echo "::group::Sign ${rel}" + echo "Uploading unsigned to s3://${BUCKET}/${src}" + aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null + + echo "Invoking ${FUNCTION_NAME} Lambda" + payload=$(jq -nc \ + --arg s "$src" \ + --arg d "$dst" \ + --arg p "$name" \ + '{source_key: $s, dest_key: $d, program_name: $p}') + meta=$(aws lambda invoke \ + --function-name "$FUNCTION_NAME" \ + --cli-binary-format raw-in-base64-out \ + --payload "$payload" \ + --output json \ + /tmp/sign-resp.json) + if echo "$meta" | jq -e '.FunctionError != null' >/dev/null + then + echo "::endgroup::" + echo "::error::code-signer Lambda failed for ${rel}" + cat /tmp/sign-resp.json >&2 + exit 1 + fi + + echo "Downloading signed back to ${path}" + aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null + + aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true + aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true + + # Sanity-check the bytes we got back actually carry an Authenticode + # signature that this machine can validate end to end. 
+ status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r') + if [ "$status" != "Valid" ] + then + echo "::endgroup::" + echo "::error::${rel} signature status: ${status} (expected Valid)" + exit 1 + fi + + echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})" + echo "::endgroup::" + done diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 356ae363..e4ca2933 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -32,6 +32,9 @@ jobs: build-windows: name: Build Windows runs-on: windows-latest + permissions: + id-token: write + contents: read steps: - uses: actions/checkout@v6 @@ -54,6 +57,13 @@ jobs: mkdir build\dist\windows mv dist\windows\wintun build\dist\windows\ + - name: Code-sign + uses: ./.github/actions/code-sign + with: + path: build + role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }} + bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }} + - name: Upload artifacts uses: actions/upload-artifact@v7 with: From 110ea8f45c11d50ef2954db6f4fa9cadf1331f69 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 14:14:32 -0400 Subject: [PATCH 15/27] Bump the golang-x-dependencies group with 4 updates (#1721) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). 
Updates `golang.org/x/crypto` from 0.50.0 to 0.51.0 - [Commits](https://github.com/golang/crypto/compare/v0.50.0...v0.51.0) Updates `golang.org/x/net` from 0.53.0 to 0.54.0 - [Commits](https://github.com/golang/net/compare/v0.53.0...v0.54.0) Updates `golang.org/x/sys` from 0.43.0 to 0.44.0 - [Commits](https://github.com/golang/sys/compare/v0.43.0...v0.44.0) Updates `golang.org/x/term` from 0.42.0 to 0.43.0 - [Commits](https://github.com/golang/term/compare/v0.42.0...v0.43.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-version: 0.51.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-version: 0.54.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-version: 0.44.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-version: 0.43.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 84728201..ee51151f 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/vishvananda/netlink v1.3.1 go.uber.org/goleak v1.3.0 go.yaml.in/yaml/v3 v3.0.4 - golang.org/x/crypto v0.50.0 + golang.org/x/crypto v0.51.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.53.0 + golang.org/x/net v0.54.0 golang.org/x/sync v0.20.0 - golang.org/x/sys v0.43.0 - golang.org/x/term v0.42.0 + golang.org/x/sys v0.44.0 + golang.org/x/term v0.43.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.6.1 diff --git a/go.sum b/go.sum index 3b0b87df..5640bd46 100644 --- a/go.sum +++ b/go.sum @@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod 
h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= -golang.org/x/term 
v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 6c7ebb08759ddfea0c628bcc7c3069d379edee1b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:35:49 -0500 Subject: [PATCH 16/27] Reset static host list addresses on change (#1713) --- lighthouse.go | 25 +++++++-- lighthouse_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ remote_list.go | 25 +++++++++ remote_list_test.go | 89 +++++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 4 deletions(-) diff --git a/lighthouse.go b/lighthouse.go index 1a136a1b..d23e84b8 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -272,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { // Clean up. Entries still in the static_host_map will be re-built. - // Entries no longer present must have their (possible) background DNS goroutines stopped. 
- if existingStaticList := lh.staticList.Load(); existingStaticList != nil { + ourselves := lh.myVpnNetworks[0].Addr() + oldStaticList := lh.staticList.Load() + if oldStaticList != nil { lh.RLock() - for staticVpnAddr := range *existingStaticList { + for staticVpnAddr := range *oldStaticList { if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { - am.hr.Cancel() + am.ResetForOwner(ourselves) } } lh.RUnlock() } + // Build a new list based on current config. staticList := make(map[netip.Addr]struct{}) err := lh.loadStaticMap(c, staticList) @@ -289,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return err } + // For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs. + // All addrs must come from the lighthouses now that it's no longer a static host. + if oldStaticList != nil { + lh.RLock() + for staticVpnAddr := range *oldStaticList { + if _, stillStatic := staticList[staticVpnAddr]; stillStatic { + continue + } + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { + am.ClearHostnameResults() + } + } + lh.RUnlock() + } + lh.staticList.Store(&staticList) if !initial { if c.HasChanged("static_host_map") { diff --git a/lighthouse_test.go b/lighthouse_test.go index c57c44ec..81c883ff 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) { require.NoError(t, err) } +// TestLighthouse_reloadStaticHostMap verifies that reloading static_host_map applies the new +// config rather than appending to it. See issue #718. 
+func TestLighthouse_reloadStaticHostMap(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} + c.Settings["static_host_map"] = map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, + } + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Lite) + nt.Insert(myVpnNet) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) + require.NoError(t, err) + + staticHost := netip.MustParseAddr("10.128.0.2") + otherHost := netip.MustParseAddr("10.128.0.3") + + // Capture the RemoteList pointer up front; an in-flight handshake would hold the same one + // on hostinfo.remotes, so it must reflect every reload below. + pinned := lh.Query(staticHost) + require.NotNil(t, pinned) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, pinned.CopyAddrs([]netip.Prefix{})) + + // Replace the remote address. The new address should be the only entry. + nc := map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + }, + } + rc, err := yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl := lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl, "RemoteList pointer must stay stable so in-flight handshakes pick up the change") + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where + // the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1]. 
+ nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Reload with the same config. An unchanged entry must not duplicate. + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Switch back to 2.2.2.2 so the rest of the test continues against a known address. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + // Add a second host alongside the first. Both should be present, neither duplicated. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + "10.128.0.3": []any{"3.3.3.3:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl, "adding a sibling entry must not displace the existing RemoteList") + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + rl = lh.Query(otherHost) + require.NotNil(t, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Drop the first host entirely. 
The vpnAddr is no longer marked static, our owner + // contribution is cleared, but the addrMap entry stays in place so non-static cache + // data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes + // that already had the pointer see an empty address list rather than retrying stale ones. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.3": []any{"3.3.3.3:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + _, isStatic := lh.GetStaticHostList()[staticHost] + assert.False(t, isStatic) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Empty(t, rl.CopyAddrs([]netip.Prefix{})) + + rl = lh.Query(otherHost) + require.NotNil(t, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{})) +} + func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, diff --git a/remote_list.go b/remote_list.go index 7b95de87..ef6eb794 100644 --- a/remote_list.go +++ b/remote_list.go @@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { r.hr = hr } +// ResetForOwner zeros the reported address slices for the given owner and marks the addrs list dirty. +// Any pending hostname resolution will be canceled. +func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) { + r.Lock() + defer r.Unlock() + r.hr.Cancel() + if c, ok := r.cache[ownerVpnAddr]; ok { + if c.v4 != nil { + c.v4.reported = c.v4.reported[:0] + } + if c.v6 != nil { + c.v6.reported = c.v6.reported[:0] + } + } + r.shouldRebuild = true +} + +// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache. 
+func (r *RemoteList) ClearHostnameResults() { + r.Lock() + defer r.Unlock() + r.unlockedSetHostnamesResults(nil) + r.shouldRebuild = true +} + // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { diff --git a/remote_list_test.go b/remote_list_test.go index 0caf86a4..0b5b7d5d 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -6,8 +6,22 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// trackedHostnameResults builds a *hostnamesResults with a known cancel function and a +// pre-populated ips map so tests can assert cancellation and verify previously-resolved +// IPs survive a cancel without spinning up a real DNS resolver. +func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults { + hr := &hostnamesResults{cancelFn: cancelFn} + ips := map[netip.AddrPort]struct{}{} + for _, a := range addrs { + ips[netip.MustParseAddrPort(a)] = struct{}{} + } + hr.ips.Store(&ips) + return hr +} + func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( @@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) } +func TestRemoteList_ResetForOwner(t *testing.T) { + ourselves := netip.MustParseAddr("10.0.0.1") + otherOwner := netip.MustParseAddr("10.0.0.2") + vpnAddr := netip.MustParseAddr("10.0.0.99") + + rl := NewRemoteList([]netip.Addr{vpnAddr}, nil) + rl.unlockedSetV4(ourselves, vpnAddr, + []*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")}, + func(netip.Addr, *V4AddrPort) bool { return true }, + ) + rl.unlockedSetV6(ourselves, vpnAddr, + []*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")}, + func(netip.Addr, *V6AddrPort) bool { return true }, + ) + rl.unlockedSetV4(otherOwner, vpnAddr, + 
[]*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")}, + func(netip.Addr, *V4AddrPort) bool { return true }, + ) + + canceled := 0 + hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242") + rl.Lock() + rl.unlockedSetHostnamesResults(hr) + rl.Unlock() + + rl.ResetForOwner(ourselves) + + rl.RLock() + defer rl.RUnlock() + assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared") + assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared") + assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved") + assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String()) + assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled") + assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced") + assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel") + assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs") +} + +func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) { + // An owner with no cache entry must not panic; shouldRebuild is still set and any + // existing hostnamesResults is canceled. 
+ rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil) + canceled := 0 + rl.Lock() + rl.unlockedSetHostnamesResults(trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")) + rl.Unlock() + + rl.ResetForOwner(netip.MustParseAddr("10.0.0.1")) + + rl.RLock() + defer rl.RUnlock() + assert.Equal(t, 1, canceled) + assert.True(t, rl.shouldRebuild) +} + +func TestRemoteList_ClearHostnameResults(t *testing.T) { + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil) + + canceled := 0 + hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242") + rl.Lock() + rl.unlockedSetHostnamesResults(hr) + rl.Unlock() + require.NotEmpty(t, hr.GetAddrs(), "hostnamesResults should have its fastrack IPs populated") + + rl.ClearHostnameResults() + + rl.RLock() + defer rl.RUnlock() + assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled") + assert.Nil(t, rl.hr, "hostnamesResults should be dropped") + assert.True(t, rl.shouldRebuild) +} + func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( From 3c121e7ab1b9f0369d72a78c60edacd8a7cf6b2f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:08 -0500 Subject: [PATCH 17/27] Allow for `-` to stand in for stdin/out (#1714) --- cmd/nebula-cert/ca.go | 39 ++++++-- cmd/nebula-cert/ca_test.go | 79 ++++++++++++++-- cmd/nebula-cert/keygen.go | 15 ++- cmd/nebula-cert/keygen_test.go | 41 ++++++++ cmd/nebula-cert/passwords.go | 4 +- cmd/nebula-cert/print.go | 31 ++++-- cmd/nebula-cert/print_test.go | 39 ++++++++ cmd/nebula-cert/sign.go | 62 ++++++++---- cmd/nebula-cert/sign_test.go | 120 +++++++++++++++++++++-- cmd/nebula-cert/stdio.go | 117 +++++++++++++++++++++++ cmd/nebula-cert/stdio_test.go | 167 +++++++++++++++++++++++++++++++++ cmd/nebula-cert/verify.go | 17 +++- cmd/nebula-cert/verify_test.go | 44 +++++++++ 13 files changed, 718 insertions(+), 57 deletions(-) create mode 100644 
cmd/nebula-cert/stdio.go create mode 100644 cmd/nebula-cert/stdio_test.go diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index cd9b82f9..3145f445 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } + } else { + // out-key is meaningless under PKCS#11 because the private key never + // leaves the HSM; reject it so we never silently accept or claim a + // stdout slot for it. + outKeySet := false + cf.set.Visit(func(f *flag.Flag) { + if f.Name == "out-key" { + outKeySet = true + } + }) + if outKeySet { + return newHelpErrorf("cannot set -out-key with -pkcs11") + } } if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err @@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } + var claims ioClaims + if err := reserveOutputs(&claims, + "out-key", *cf.outKeyPath, + "out-crt", *cf.outCertPath, + "out-qr", *cf.outQRPath, + ); err != nil { + return err + } + var passphrase []byte if !isP11 && *cf.encryption { passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) + errOut.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if err == ErrNoTerminal { @@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error Curve: curve, } - if !isP11 { + if !isP11 && !isStdio(*cf.outKeyPath) { if _, err := os.Stat(*cf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) } } - if _, err := os.Stat(*cf.outCertPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + if !isStdio(*cf.outCertPath) { + if _, err := os.Stat(*cf.outCertPath); err == nil { + return 
fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + } } var c cert.Certificate @@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) } - err = os.WriteFile(*cf.outKeyPath, b, 0600) + err = writeOutput(*cf.outKeyPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while marshalling certificate: %s", err) } - err = os.WriteFile(*cf.outCertPath, b, 0600) + err = writeOutput(*cf.outCertPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*cf.outQRPath, b, 0600) + err = writeOutput(*cf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -332,6 +356,7 @@ func caSummary() string { func caHelp(out io.Writer) { cf := newCaFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n")) + out.Write([]byte(stdioHelpText)) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 779d3a2d..ce0113b6 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -argon-iterations uint\n"+ " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ " -argon-memory uint\n"+ @@ -84,7 +85,7 @@ func Test_ca(t *testing.T) { err: nil, } - pwPromptOb := "Enter passphrase: " + 
pwPromptEB := "Enter passphrase: " // required args assertHelpError(t, ca( @@ -168,8 +169,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, testpw)) - assert.Equal(t, pwPromptOb, ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, pwPromptEB, eb.String()) // test encrypted key with passphrase environment variable os.Remove(keyF.Name()) @@ -207,8 +208,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.Error(t, ca(args, ob, eb, errpw)) - assert.Equal(t, pwPromptOb, ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, pwPromptEB, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) @@ -217,8 +218,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") - assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up // create valid cert/key for overwrite tests os.Remove(keyF.Name()) @@ -247,3 +248,67 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) } + +func Test_ca_stdio(t *testing.T) { + nopw := &StubPasswordReader{} + + keyF, err := os.CreateTemp("", "ca.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + crtF, err := 
os.CreateTemp("", "ca.crt") + require.NoError(t, err) + os.Remove(crtF.Name()) + defer os.Remove(crtF.Name()) + + // out-crt on stdout, out-key on disk + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw)) + assert.Empty(t, eb.String()) + c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.True(t, c.IsCA()) + assert.Equal(t, "test-ca", c.Name()) + + // out-key on stdout, out-crt on disk + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw)) + assert.Empty(t, eb.String()) + _, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + + // dual stdout is rejected up front + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + require.EqualError(t, + ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + + // an output conflict combined with -encrypt must error BEFORE prompting + // for a passphrase; pr would record any read attempt + tracker := &trackingPasswordReader{} + ob.Reset() + eb.Reset() + require.EqualError(t, + ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + assert.Zero(t, tracker.calls, "passphrase prompt should not have been called") +} + +type trackingPasswordReader struct { + calls int +} + +func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) { + pr.calls++ + return []byte(""), nil +} diff --git 
a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 496f84c2..dea6c4af 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } + } else if *cf.outKeyPath != "" { + return newHelpErrorf("cannot set -out-key with -pkcs11") } if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err @@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { } } + var claims ioClaims + if err := reserveOutputs(&claims, + "out-key", *cf.outKeyPath, + "out-pub", *cf.outPubPath, + ); err != nil { + return err + } + if isP11 { p11Client, err := pkclient.FromUrl(*cf.p11url) if err != nil { @@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while getting public key: %w", err) } } else { - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) + err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } @@ -102,6 +112,7 @@ func keygenSummary() string { func keygenHelp(out io.Writer) { cf := newKeygenFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + _, _ = out.Write([]byte(stdioHelpText)) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 95d9893e..98c4c456 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) { assert.Equal( t, 
"Usage of "+os.Args[0]+" keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -curve string\n"+ " \tECDH Curve (25519, P256) (default \"25519\")\n"+ " -out-key string\n"+ @@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) { require.NoError(t, err) assert.Len(t, lPub, 32) } + +func Test_keygen_stdio(t *testing.T) { + keyF, err := os.CreateTemp("", "test.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + pubF, err := os.CreateTemp("", "test.pub") + require.NoError(t, err) + os.Remove(pubF.Name()) + defer os.Remove(pubF.Name()) + + // out-pub on stdout, out-key on disk + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb)) + assert.Empty(t, eb.String()) + lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Len(t, lPub, 32) + + // out-key on stdout, out-pub on disk + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb)) + assert.Empty(t, eb.String()) + lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Len(t, lKey, 32) + + // both on stdout is a conflict caught up front + ob.Reset() + eb.Reset() + require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb), + `-out-key and -out-pub both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) +} diff --git a/cmd/nebula-cert/passwords.go b/cmd/nebula-cert/passwords.go index 8129560e..0aa2115d 100644 --- a/cmd/nebula-cert/passwords.go +++ b/cmd/nebula-cert/passwords.go @@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { } password, 
err := term.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() + // Terminal echo is off while reading, so the user's Enter key does not + // produce a visible newline. Emit one on stderr to match the prompt. + fmt.Fprintln(os.Stderr) return password, err } diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 30e0965b..3ba0571e 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCert, err := os.ReadFile(*pf.path) + var claims ioClaims + if err := reserveInputs(&claims, "path", *pf.path); err != nil { + return err + } + if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil { + return err + } + + rawCert, err := readInput("path", *pf.path, &claims) if err != nil { return fmt.Errorf("unable to read cert; %s", err) } + // When the QR is going to stdout, suppress the human-readable text/json + // output so the binary stream is not contaminated. 
+ qrToStdout := isStdio(*pf.outQRPath) + var c cert.Certificate var qrBytes []byte part := 0 @@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while unmarshaling cert: %s", err) } - if *pf.json { - jsonCerts = append(jsonCerts, c) - } else { - _, _ = out.Write([]byte(c.String())) - _, _ = out.Write([]byte("\n")) + if !qrToStdout { + if *pf.json { + jsonCerts = append(jsonCerts, c) + } else { + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) + } } if *pf.outQRPath != "" { @@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } - if *pf.json { + if *pf.json && !qrToStdout { b, _ := json.Marshal(jsonCerts) _, _ = out.Write(b) _, _ = out.Write([]byte("\n")) @@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*pf.outQRPath, b, 0600) + err = writeOutput(*pf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -107,6 +121,7 @@ func printSummary() string { func printHelp(out io.Writer) { pf := newPrintFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")) + out.Write([]byte(stdioHelpText)) pf.set.SetOutput(out) pf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 221ab778..8d5d31be 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" print : prints details about a certificate\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -json\n"+ " \tOptional: outputs certificates in json format\n"+ " -out-qr string\n"+ @@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) { ob.String(), ) assert.Empty(t, eb.String()) + + // read cert from stdin + ob.Reset() 
+ eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-json", "-path", "-"}, ob, eb) + require.NoError(t, err) + assert.Equal( + t, + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, + ob.String(), + ) + assert.Empty(t, eb.String()) + + // -out-qr - sends only the PNG to stdout, suppressing the cert dump + ob.Reset() + eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb) + require.NoError(t, err) + assert.Empty(t, eb.String()) + stdout := ob.Bytes() + require.NotEmpty(t, stdout) + // PNG magic, no PEM/JSON noise prepended + assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8]) + assert.NotContains(t, string(stdout), "NebulaCertificate") + assert.NotContains(t, string(stdout), `"details"`) + + // json + out-qr - still suppresses json + ob.Reset() + eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb) + require.NoError(t, err) + assert.Empty(t, eb.String()) + assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4]) + assert.NotContains(t, ob.String(), `"details"`) } // NewTestCaCert will generate a CA cert diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 561138ca..9b57c4fe 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } + if isP11 && *sf.outKeyPath != "" { + return newHelpErrorf("cannot set -out-key with -pkcs11") + } var v4Networks []netip.Prefix var v6Networks []netip.Prefix @@ 
-102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } + if *sf.outKeyPath == "" { + *sf.outKeyPath = *sf.name + ".key" + } + if *sf.outCertPath == "" { + *sf.outCertPath = *sf.name + ".crt" + } + + var claims ioClaims + if err := reserveInputs(&claims, + "ca-key", *sf.caKeyPath, + "ca-crt", *sf.caCertPath, + "in-pub", *sf.inPubPath, + ); err != nil { + return err + } + if err := reserveOutputs(&claims, + "out-key", *sf.outKeyPath, + "out-crt", *sf.outCertPath, + "out-qr", *sf.outQRPath, + ); err != nil { + return err + } + var curve cert.Curve var caKey []byte if !isP11 { var rawCAKey []byte - rawCAKey, err := os.ReadFile(*sf.caKeyPath) - + rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims) if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } @@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if len(passphrase) == 0 { // ask for a passphrase until we get one for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) + errOut.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if errors.Is(err, ErrNoTerminal) { @@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - rawCACert, err := os.ReadFile(*sf.caCertPath) + rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims) if err != nil { return fmt.Errorf("error while reading ca-crt: %s", err) } @@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if *sf.inPubPath != "" { var pubCurve cert.Curve - rawPub, err := os.ReadFile(*sf.inPubPath) + rawPub, err := readInput("in-pub", *sf.inPubPath, &claims) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } @@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) 
pub, rawPriv = newKeypair(curve) } - if *sf.outKeyPath == "" { - *sf.outKeyPath = *sf.name + ".key" - } - - if *sf.outCertPath == "" { - *sf.outCertPath = *sf.name + ".crt" - } - - if _, err := os.Stat(*sf.outCertPath); err == nil { - return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) + if !isStdio(*sf.outCertPath) { + if _, err := os.Stat(*sf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) + } } var crts []cert.Certificate @@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } if !isP11 && *sf.inPubPath == "" { - if _, err := os.Stat(*sf.outKeyPath); err == nil { - return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) + if !isStdio(*sf.outKeyPath) { + if _, err := os.Stat(*sf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) + } } - err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) b = append(b, sb...) 
} - err = os.WriteFile(*sf.outCertPath, b, 0600) + err = writeOutput(*sf.outCertPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*sf.outQRPath, b, 0600) + err = writeOutput(*sf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -440,6 +461,7 @@ func signSummary() string { func signHelp(out io.Writer) { sf := newSignFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n")) + out.Write([]byte(stdioHelpText)) sf.set.SetOutput(out) sf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index f5f8cbb0..64d5c7d9 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" sign : create and sign a certificate\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -ca-crt string\n"+ " \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+ " -ca-key string\n"+ @@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, testpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) // test with the proper password in the environment os.Remove(crtF.Name()) os.Remove(keyF.Name()) args = []string{"-version", "1", "-ca-crt", 
caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) + ob.Reset() + eb.Reset() require.NoError(t, signCert(args, ob, eb, testpw)) + assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Setenv("NEBULA_CA_PASSPHRASE", "") @@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) { testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, testpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) // test with the wrong password in environment ob.Reset() @@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) { args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these - assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String()) // test an error condition ob.Reset() @@ -425,6 +429,106 @@ func Test_signCert(t 
*testing.T) { args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, errpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) +} + +func Test_signCert_stdio(t *testing.T) { + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv) + + ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + rawCACrt, _ := ca.MarshalPEM() + + caCrtF, err := os.CreateTemp("", "sign-cert.crt") + require.NoError(t, err) + defer os.Remove(caCrtF.Name()) + caCrtF.Write(rawCACrt) + + caKeyF, err := os.CreateTemp("", "sign-cert.key") + require.NoError(t, err) + defer os.Remove(caKeyF.Name()) + caKeyF.Write(rawCAKey) + + keyF, err := os.CreateTemp("", "sign.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + // ca-key on stdin, cert to stdout + withStdin(t, bytes.NewReader(rawCAKey)) + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + + lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, "stdin-test", lCrt.Name()) + assert.True(t, lCrt.CheckSignature(caPub)) + + // two flags reading from stdin should error before any read attempt; + // otherwise an interactive 
shell would hang on io.ReadAll + stdinIn := bytes.NewReader(rawCAKey) + withStdin(t, stdinIn) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), + `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) + assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front") + + // two flags writing to stdout should error before any output is written + // AND before stdin is consumed + stdinR := bytes.NewReader(rawCAKey) + withStdin(t, stdinR) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + // stdin should be untouched because the conflict was caught up front + assert.Equal(t, len(rawCAKey), stdinR.Len()) + + // out-key on stdout, cert on disk + keyF2, err := os.CreateTemp("", "sign.key") + require.NoError(t, err) + os.Remove(keyF2.Name()) + defer os.Remove(keyF2.Name()) + crtF, err := os.CreateTemp("", "sign.crt") + require.NoError(t, err) + os.Remove(crtF.Name()) + defer os.Remove(crtF.Name()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + _, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + + // in-pub on stdin (caller already has a keypair, only the cert 
is generated) + inPub, _ := x25519Keypair() + rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub) + + withStdin(t, bytes.NewReader(rawInPub)) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, "in-pub-test", stdinCrt.Name()) + assert.Equal(t, inPub, stdinCrt.PublicKey()) } diff --git a/cmd/nebula-cert/stdio.go b/cmd/nebula-cert/stdio.go new file mode 100644 index 00000000..3f71d52f --- /dev/null +++ b/cmd/nebula-cert/stdio.go @@ -0,0 +1,117 @@ +package main + +import ( + "fmt" + "io" + "os" +) + +// stdioPath is the special path value that selects stdin (for inputs) or +// stdout (for outputs) instead of a file on disk. +const stdioPath = "-" + +// stdioHelpText is rendered just under the Usage line of each subcommand +// help so the - convention is documented once instead of on every flag. +const stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n" + +// stdinReader is the source used when an input flag is set to "-". +// It is a package level var so tests can swap in a deterministic reader. +// Tests that mutate stdinReader cannot run with t.Parallel(). +var stdinReader io.Reader = os.Stdin + +// ioClaims tracks which flags have claimed stdin and stdout during a single +// command invocation so we can refuse a second flag asking for the same +// stream. 
+type ioClaims struct { + in string + out string +} + +func (c *ioClaims) claimIn(flagName string) error { + if c.in != "" && c.in != flagName { + return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath) + } + c.in = flagName + return nil +} + +func (c *ioClaims) claimOut(flagName string) error { + if c.out != "" && c.out != flagName { + return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath) + } + c.out = flagName + return nil +} + +// reserveInputs walks alternating (flagName, path) pairs and claims stdin +// for any path equal to stdioPath. It must be called before any input is +// read so a conflict can be reported immediately instead of blocking on +// io.ReadAll while waiting for input that will never arrive. +func reserveInputs(claims *ioClaims, pairs ...string) error { + return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs) +} + +// reserveOutputs walks alternating (flagName, path) pairs and claims stdout +// for any path equal to stdioPath. It must be called before any output is +// written so a conflict cannot leave one stream half written before the +// second flag fails. +func reserveOutputs(claims *ioClaims, pairs ...string) error { + return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs) +} + +func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error { + if len(pairs)%2 != 0 { + panic(who + " requires alternating name, path pairs") + } + for i := 0; i < len(pairs); i += 2 { + name, path := pairs[i], pairs[i+1] + if path != stdioPath { + continue + } + if err := claim(claims, name); err != nil { + return err + } + } + return nil +} + +// readInput returns the bytes referenced by path, reading from stdin when +// path is stdioPath. 
+func readInput(flagName, path string, claims *ioClaims) ([]byte, error) { + if path == stdioPath { + if err := claims.claimIn(flagName); err != nil { + return nil, err + } + return io.ReadAll(stdinReader) + } + return os.ReadFile(path) +} + +// openInput returns a reader for path. When path is stdioPath the returned +// reader wraps stdin and Close is a no-op. +func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) { + if path == stdioPath { + if err := claims.claimIn(flagName); err != nil { + return nil, err + } + return io.NopCloser(stdinReader), nil + } + return os.Open(path) +} + +// writeOutput writes data to path, or to stdout when path is stdioPath. perm +// is only used for file output. The caller must have already claimed stdout +// via reserveOutputs before invoking with stdioPath. +func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error { + if path == stdioPath { + _, err := stdout.Write(data) + return err + } + return os.WriteFile(path, data, perm) +} + +// isStdio reports whether path is the stdio sentinel and so should skip +// existence checks like "refuse to overwrite". +func isStdio(path string) bool { + return path == stdioPath +} diff --git a/cmd/nebula-cert/stdio_test.go b/cmd/nebula-cert/stdio_test.go new file mode 100644 index 00000000..dc87a597 --- /dev/null +++ b/cmd/nebula-cert/stdio_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// withStdin temporarily replaces stdinReader for the duration of t. 
+func withStdin(t *testing.T, r io.Reader) { + t.Helper() + prev := stdinReader + stdinReader = r + t.Cleanup(func() { stdinReader = prev }) +} + +func Test_readInput_stdin(t *testing.T) { + withStdin(t, bytes.NewBufferString("hello")) + var claims ioClaims + + got, err := readInput("path", "-", &claims) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), got) + assert.Equal(t, "path", claims.in) +} + +func Test_readInput_file(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + require.NoError(t, os.WriteFile(p, []byte("file"), 0600)) + var claims ioClaims + + got, err := readInput("path", p, &claims) + require.NoError(t, err) + assert.Equal(t, []byte("file"), got) + assert.Empty(t, claims.in) +} + +func Test_readInput_doubleStdinErrors(t *testing.T) { + withStdin(t, bytes.NewBufferString("hello")) + var claims ioClaims + + _, err := readInput("ca-key", "-", &claims) + require.NoError(t, err) + + _, err = readInput("ca-crt", "-", &claims) + require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) +} + +func Test_openInput_stdin(t *testing.T) { + withStdin(t, bytes.NewBufferString("hi")) + var claims ioClaims + + r, err := openInput("ca", "-", &claims) + require.NoError(t, err) + defer r.Close() + b, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, []byte("hi"), b) +} + +func Test_openInput_doubleStdinErrors(t *testing.T) { + withStdin(t, bytes.NewBufferString("hi")) + var claims ioClaims + + r, err := openInput("ca", "-", &claims) + require.NoError(t, err) + r.Close() + + _, err = openInput("crt", "-", &claims) + require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`) +} + +func Test_writeOutput_stdout(t *testing.T) { + out := &bytes.Buffer{} + + err := writeOutput("-", []byte("payload"), 0600, out) + require.NoError(t, err) + assert.Equal(t, "payload", out.String()) +} + +func Test_writeOutput_file(t *testing.T) { + dir := t.TempDir() + p 
:= filepath.Join(dir, "f") + out := &bytes.Buffer{} + + err := writeOutput(p, []byte("payload"), 0600, out) + require.NoError(t, err) + assert.Empty(t, out.String()) + got, err := os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("payload"), got) +} + +func Test_reserveOutputs_noConflict(t *testing.T) { + var claims ioClaims + require.NoError(t, reserveOutputs(&claims, + "out-key", "/tmp/key", + "out-crt", "-", + "out-qr", "", + )) + assert.Equal(t, "out-crt", claims.out) +} + +func Test_reserveOutputs_conflict(t *testing.T) { + var claims ioClaims + err := reserveOutputs(&claims, + "out-key", "-", + "out-crt", "-", + ) + require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`) +} + +func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + }() + var claims ioClaims + _ = reserveOutputs(&claims, "out-key") +} + +func Test_reserveInputs_noConflict(t *testing.T) { + var claims ioClaims + require.NoError(t, reserveInputs(&claims, + "ca-key", "/tmp/ca.key", + "ca-crt", "-", + "in-pub", "", + )) + assert.Equal(t, "ca-crt", claims.in) +} + +func Test_reserveInputs_conflict(t *testing.T) { + var claims ioClaims + err := reserveInputs(&claims, + "ca-key", "-", + "ca-crt", "-", + ) + require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) +} + +func Test_claimIn_idempotent(t *testing.T) { + // pre-claim then a lazy re-claim of the same flag should be a no-op + var claims ioClaims + require.NoError(t, claims.claimIn("ca-key")) + require.NoError(t, claims.claimIn("ca-key")) + assert.Equal(t, "ca-key", claims.in) +} + +func Test_claimOut_idempotent(t *testing.T) { + var claims ioClaims + require.NoError(t, claims.claimOut("out-crt")) + require.NoError(t, claims.claimOut("out-crt")) + assert.Equal(t, "out-crt", claims.out) +} + +func Test_isStdio(t *testing.T) { + assert.True(t, isStdio("-")) + assert.False(t, 
isStdio("")) + assert.False(t, isStdio("./-")) + assert.False(t, isStdio("foo")) +} diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 36258dd8..76d3dbe6 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - caFile, err := os.Open(*vf.caPath) + var claims ioClaims + if err := reserveInputs(&claims, + "ca", *vf.caPath, + "crt", *vf.certPath, + ); err != nil { + return err + } + + caReader, err := openInput("ca", *vf.caPath, &claims) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } - defer caFile.Close() + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEMReader(caFile) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if err != nil && !errors.Is(err, cert.ErrExpired) { return fmt.Errorf("error while adding ca cert to pool: %w", err) } - rawCert, err := os.ReadFile(*vf.certPath) + rawCert, err := readInput("crt", *vf.certPath, &claims) if err != nil { return fmt.Errorf("unable to read crt: %w", err) } @@ -85,6 +93,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte(stdioHelpText)) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 1aa5e8e6..aa089d0e 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" verify : verifies a certificate isn't expired and was signed by a trusted authority.\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -ca string\n"+ " \tRequired: path to a file containing one or more ca certificates\n"+ " -crt string\n"+ @@ -122,3 +123,46 @@ func Test_verify(t *testing.T) { assert.Empty(t, 
eb.String()) require.NoError(t, err) } + +func Test_verify_stdio(t *testing.T) { + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) + caPEM, _ := ca.MarshalPEM() + + crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + crtPEM, _ := crt.MarshalPEM() + + caFile, err := os.CreateTemp("", "verify-ca") + require.NoError(t, err) + defer os.Remove(caFile.Name()) + caFile.Write(caPEM) + + // crt on stdin, ca on disk + withStdin(t, bytes.NewReader(crtPEM)) + require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // ca on stdin, crt on disk + certFile, err := os.CreateTemp("", "verify-cert") + require.NoError(t, err) + defer os.Remove(certFile.Name()) + certFile.Write(crtPEM) + + withStdin(t, bytes.NewReader(caPEM)) + ob.Reset() + eb.Reset() + require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // both flags on stdin should error + withStdin(t, bytes.NewReader(caPEM)) + ob.Reset() + eb.Reset() + require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb), + `-ca and -crt both set to "-", only one input may read from stdin`) +} From 99c5854e5c87525a1ebfca3be233e9e25ef2b573 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:26 -0500 Subject: [PATCH 18/27] Prime some critical stats before the first scrape (#1715) --- interface.go | 38 ++++++++++++++---------- interface_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 15 deletions(-) create mode 100644 interface_test.go diff --git a/interface.go b/interface.go index 5fedcdd3..32f5c2a6 100644 --- a/interface.go +++ 
b/interface.go @@ -491,26 +491,34 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) + emit := func() { + f.firewall.EmitStats() + f.handshakeManager.EmitStats() + udpStats() + + certState := f.pki.getCertState() + defaultCrt := certState.GetDefaultCertificate() + certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certInitiatingVersion.Update(int64(defaultCrt.Version())) + + // Report the max certificate version we are capable of using + if certState.v2Cert != nil { + certMaxVersion.Update(int64(certState.v2Cert.Version())) + } else { + certMaxVersion.Update(int64(certState.v1Cert.Version())) + } + } + + // Prime gauges so a Prometheus scrape that lands before the first tick + // sees real values instead of the zero defaults (issue #907). + emit() + for { select { case <-ctx.Done(): return case <-ticker.C: - f.firewall.EmitStats() - f.handshakeManager.EmitStats() - udpStats() - - certState := f.pki.getCertState() - defaultCrt := certState.GetDefaultCertificate() - certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) - certInitiatingVersion.Update(int64(defaultCrt.Version())) - - // Report the max certificate version we are capable of using - if certState.v2Cert != nil { - certMaxVersion.Update(int64(certState.v2Cert.Version())) - } else { - certMaxVersion.Update(int64(certState.v1Cert.Version())) - } + emit() } } } diff --git a/interface_test.go b/interface_test.go new file mode 100644 index 00000000..b0a9d025 --- /dev/null +++ b/interface_test.go @@ -0,0 +1,73 @@ +//go:build linux || darwin + +package nebula + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + 
"github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that +// landed before the first ticker fire used to read 0 for the cert gauges. +// emitStats now primes the gauges before entering the ticker loop. We assert +// the gauge is zero before the first call and non-zero after. +func Test_emitStats_primesGauges(t *testing.T) { + defer metrics.DefaultRegistry.UnregisterAll() + + l := test.NewLogger() + hostMap := newHostMap(l) + preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + hostMap.preferredRanges.Store(&preferredRanges) + + notAfter := time.Now().Add(time.Hour) + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, + v1Credential: nil, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &overlaytest.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to + // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, + // returns an error, and the emitter falls through to a no-op. + writers: []udp.Conn{&udp.StdConn{}}, + } + ifce.pki.cs.Store(cs) + + ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + + // Pre-cancel the context so emitStats returns after priming the gauges + // without ever reading from ticker.C. The one hour interval is just a + // belt-and-suspenders, the test does not expect the ticker to fire. 
+ ctx, cancel := context.WithCancel(context.Background()) + cancel() + ifce.emitStats(ctx, time.Hour) + + ttl := ttlGauge.Value() + assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") + assert.LessOrEqual(t, ttl, int64(3600)) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) +} From 625f58b84adc778895b20a3dd74b2e2190c83132 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:44 -0500 Subject: [PATCH 19/27] Record my local details in the dns server if enabled (#1716) --- dns_server.go | 118 +++++++++++++++++++++++++++++++++++---------- dns_server_test.go | 89 ++++++++++++++++++++++++++++++++++ main.go | 2 +- 3 files changed, 182 insertions(+), 27 deletions(-) diff --git a/dns_server.go b/dns_server.go index ff1369ab..a80630b5 100644 --- a/dns_server.go +++ b/dns_server.go @@ -11,19 +11,21 @@ import ( "sync" "sync/atomic" - "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/slackhq/nebula/config" ) type dnsServer struct { sync.RWMutex - l *slog.Logger - ctx context.Context - dnsMap4 map[string]netip.Addr - dnsMap6 map[string]netip.Addr - hostMap *HostMap - myVpnAddrsTable *bart.Lite + l *slog.Logger + ctx context.Context + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + pki *PKI + + // selfHost is the cached FQDN we last seeded for ourselves + selfHost string mux *dns.ServeMux @@ -55,14 +57,14 @@ type dnsServer struct { // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // watcher that tears the listener down on nebula shutdown. The returned // pointer is always non-nil, even on error. 
-func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { +func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) { ds := &dnsServer{ - l: l, - ctx: ctx, - dnsMap4: make(map[string]netip.Addr), - dnsMap6: make(map[string]netip.Addr), - hostMap: hostMap, - myVpnAddrsTable: cs.myVpnAddrsTable, + l: l, + ctx: ctx, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + pki: pki, } ds.mux = dns.NewServeMux() ds.mux.HandleFunc(".", ds.handleDnsRequest) @@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, if err := ds.reload(c, true); err != nil { return ds, err } + ds.seedSelf() return ds, nil } @@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { d.Stop() } // Drop any records that accumulated while enabled; a later re-enable - // will repopulate from fresh handshakes. + // will repopulate from fresh handshakes and a fresh seedSelf. d.clearRecords() return nil } @@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { if running == nil { // Was disabled (or never started); bring it up now. go d.Start() - return nil + } else if !sameAddr { + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the new address. + go d.Start() } - if sameAddr { - return nil - } - - d.shutdownServer(running, runningStarted, "reload") - // Old Start goroutine has now exited; bring up a fresh listener on the - // new address. - go d.Start() + // Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up. 
+ d.seedSelf() return nil } @@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string { return "" } + // The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves. + // Answer self lookups straight from the local cert state. + if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { + c := cs.GetDefaultCertificate() + if c == nil { + return "" + } + b, err := c.MarshalJSON() + if err != nil { + return "" + } + return string(b) + } + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" @@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string { return string(b) } -// clearRecords drops all DNS records. +// clearRecords drops all DNS records, including the self entry. func (d *dnsServer) clearRecords() { d.Lock() defer d.Unlock() clear(d.dnsMap4) clear(d.dnsMap6) + d.selfHost = "" +} + +// seedSelf inserts (or refreshes) a record for our own cert name pointing at our VPN addresses, +// so a single-lighthouse network can resolve the lighthouse's own hostname without the two-process workaround. +func (d *dnsServer) seedSelf() { + if !d.enabled.Load() { + return + } + cs := d.certState() + if cs == nil { + return + } + c := cs.GetDefaultCertificate() + if c == nil { + return + } + newHost := strings.ToLower(c.Name()) + "." 
+ + d.Lock() + defer d.Unlock() + if d.selfHost != "" && d.selfHost != newHost { + delete(d.dnsMap4, d.selfHost) + delete(d.dnsMap6, d.selfHost) + } + d.selfHost = newHost + delete(d.dnsMap4, newHost) + delete(d.dnsMap6, newHost) + haveV4, haveV6 := false, false + for _, addr := range cs.myVpnAddrs { + if addr.Is4() && !haveV4 { + d.dnsMap4[newHost] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[newHost] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } +} + +func (d *dnsServer) certState() *CertState { + if d.pki == nil { + return nil + } + return d.pki.getCertState() } // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` @@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { return true } + cs := d.certState() + if cs == nil || cs.myVpnAddrsTable == nil { + return false + } //if we found it in this table, it's good - return d.myVpnAddrsTable.Contains(b) + return cs.myVpnAddrsTable.Contains(b) } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { diff --git a/dns_server_test.go b/dns_server_test.go index dcea046c..58646937 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -9,7 +9,10 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/miekg/dns" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { } } +// newTestPKI builds a minimal *PKI with a single v1 cert whose name and +// VPN addresses are caller-provided, suitable for exercising seedSelf and +// QueryCert self handling. 
+func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI { + t.Helper() + networks := make([]netip.Prefix, 0, len(addrs)) + for _, a := range addrs { + bits := 32 + if a.Is6() { + bits = 128 + } + networks = append(networks, netip.PrefixFrom(a, bits)) + } + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil) + + addrsTable := new(bart.Lite) + for _, a := range addrs { + addrsTable.Insert(netip.PrefixFrom(a, a.BitLen())) + } + + cs := &CertState{ + v2Cert: c, + initiatingVersion: cert.Version2, + myVpnAddrs: addrs, + myVpnAddrsTable: addrsTable, + } + pki := &PKI{} + pki.cs.Store(cs) + return pki +} + +func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) { + ds, c := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + myV6 := netip.MustParseAddr("fd00::1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4, myV6}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + got4, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV4, got4) + got6, exists := ds.Query(dns.TypeAAAA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV6, got6) +} + +func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, false) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + 
require.NoError(t, ds.reload(c, true)) + ds.seedSelf() + require.NotEmpty(t, ds.selfHost) + + ds.clearRecords() + assert.Empty(t, ds.selfHost) + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) { + ds, _ := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4}) + + got := ds.QueryCert(myV4.String() + ".") + assert.NotEmpty(t, got, "TXT lookup of our own VPN address should return our cert") + + other := netip.MustParseAddr("10.0.0.99") + assert.Empty(t, ds.QueryCert(other.String()+"."), "unknown peer IP should return nothing") +} + func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { port := freeUDPPort(t) ds, c := newTestDnsServer(t) diff --git a/main.go b/main.go index 37aa24d1..7d7a0f72 100644 --- a/main.go +++ b/main.go @@ -194,7 +194,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c) if err != nil { l.Warn("Failed to start DNS responder", "error", err) } From ffd5249cf522a1dd582c707888776f5f54264d32 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:37:01 -0500 Subject: [PATCH 20/27] Search for config.yaml/yml in both service and cli mode (#1717) --- cmd/nebula-service/main.go | 9 +++-- cmd/nebula-service/service.go | 17 ++------- cmd/nebula/main.go | 9 +++-- config/default.go | 29 +++++++++++++++ config/default_test.go | 67 +++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 21 deletions(-) create mode 100644 config/default.go create mode 100644 config/default_test.go diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 
19fb3a9f..724c0c6a 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -61,9 +61,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } c := config.NewC(l) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 6551ceb4..7c2b39c8 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -3,8 +3,6 @@ package main import ( "fmt" "log" - "os" - "path/filepath" "github.com/kardianos/service" "github.com/slackhq/nebula" @@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error { return nil } -func fileExists(filename string) bool { - _, err := os.Stat(filename) - if os.IsNotExist(err) { - return false - } - return true -} - func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { - ex, err := os.Executable() + p, err := config.DefaultPath() if err != nil { return err } - *configPath = filepath.Dir(ex) + "/config.yaml" - if !fileExists(*configPath) { - *configPath = filepath.Dir(ex) + "/config.yml" - } + *configPath = p } svcConfig := &service.Config{ diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index d7f0de93..219519c2 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -50,9 +50,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } l := logging.NewLogger(os.Stdout) diff --git a/config/default.go b/config/default.go new file mode 100644 index 00000000..9494c655 --- /dev/null +++ b/config/default.go @@ -0,0 +1,29 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml 
over config.yml. +// If neither file exists an error is returned that names both paths checked. +func DefaultPath() (string, error) { + ex, err := os.Executable() + if err != nil { + return "", err + } + return defaultPathInDir(filepath.Dir(ex)) +} + +func defaultPathInDir(dir string) (string, error) { + yamlPath := filepath.Join(dir, "config.yaml") + if _, err := os.Stat(yamlPath); err == nil { + return yamlPath, nil + } + ymlPath := filepath.Join(dir, "config.yml") + if _, err := os.Stat(ymlPath); err == nil { + return ymlPath, nil + } + return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath) +} diff --git a/config/default_test.go b/config/default_test.go new file mode 100644 index 00000000..a4d56f59 --- /dev/null +++ b/config/default_test.go @@ -0,0 +1,67 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultPathInDir(t *testing.T) { + t.Run("prefers config.yaml when both exist", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yaml") + other := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + require.NoError(t, os.WriteFile(other, []byte("a: 2"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("returns config.yaml when only it exists", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("falls back to config.yml when only it exists", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + 
t.Run("errors when neither exists and names both paths", func(t *testing.T) { + dir := t.TempDir() + got, err := defaultPathInDir(dir) + assert.Empty(t, got) + require.Error(t, err) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml")) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml")) + }) +} + +func TestDefaultPath(t *testing.T) { + got, err := DefaultPath() + if err != nil { + ex, exErr := os.Executable() + require.NoError(t, exErr) + assert.Contains(t, err.Error(), filepath.Dir(ex)) + return + } + ex, err := os.Executable() + require.NoError(t, err) + assert.Equal(t, filepath.Dir(ex), filepath.Dir(got)) + assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got)) +} From 0d23377c6575bd716448920269f8142a789097ca Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 18 May 2026 11:10:30 -0500 Subject: [PATCH 21/27] Fix flakey cert tests (#1728) --- cert/helper_test.go | 14 ++++++++++---- cert_test/cert.go | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/cert/helper_test.go b/cert/helper_test.go index 1b72a0ff..9becfa5c 100644 --- a/cert/helper_test.go +++ b/cert/helper_test.go @@ -13,6 +13,12 @@ import ( "golang.org/x/crypto/ed25519" ) +// testCertNow is the reference "now" used to derive default before/after times +// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the +// test binary keeps CA and leaf defaults aligned at the same second, so a leaf +// signed with default times can never expire after its CA on a rounding race. 
+var testCertNow = time.Now().Round(time.Second) + // NewTestCaCert will create a new ca certificate func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { var err error @@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ } if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } t := &TBSCertificate{ @@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ // Expiry times are defaulted if you do not pass them in func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } if len(networks) == 0 { diff --git a/cert_test/cert.go b/cert_test/cert.go index c3759f12..4c440aff 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -14,6 +14,12 @@ import ( "golang.org/x/crypto/ed25519" ) +// testCertNow is the reference "now" used to derive default before/after times +// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the +// test binary keeps CA and leaf defaults aligned at the same second, so a leaf +// signed with default times can never expire after its CA on a rounding race. 
+var testCertNow = time.Now().Round(time.Second) + // NewTestCaCert will create a new ca certificate func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { var err error @@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti } if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } t := &cert.TBSCertificate{ @@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti // Expiry times are defaulted if you do not pass them in func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } var pub, priv []byte From 04dea41f7495d09c9ee3d7c03b1bae00adb25ba4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 18 May 2026 11:25:34 -0500 Subject: [PATCH 22/27] Make firewall reload when unsafe networks in the cert changes (#1719) --- firewall.go | 30 ++++---- interface.go | 17 ++++- interface_emit_test.go | 73 ++++++++++++++++++++ interface_test.go | 151 +++++++++++++++++++++++++++-------------- 4 files changed, 201 insertions(+), 70 deletions(-) create mode 100644 interface_emit_test.go diff --git a/firewall.go b/firewall.go index adecbe81..904c71b2 100644 --- a/firewall.go +++ b/firewall.go @@ -58,8 +58,9 @@ type Firewall struct { 
routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. - assignedNetworks []netip.Prefix - hasUnsafeNetworks bool + assignedNetworks []netip.Prefix + // unsafeNetworks is the list of unsafe networks issued to us in the certificate + unsafeNetworks []netip.Prefix rules string rulesVersion uint16 @@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur assignedNetworks = append(assignedNetworks, network) } - hasUnsafeNetworks := false - for _, n := range c.UnsafeNetworks() { + unsafeNetworks := c.UnsafeNetworks() + for _, n := range unsafeNetworks { routableNetworks.Insert(n) - hasUnsafeNetworks = true } return &Firewall{ @@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - routableNetworks: routableNetworks, - assignedNetworks: assignedNetworks, - hasUnsafeNetworks: hasUnsafeNetworks, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, + unsafeNetworks: unsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), @@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { } if localCidr == "" { - if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { + if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny { flc.Any = true return nil } diff --git a/interface.go b/interface.go index 32f5c2a6..f96e431a 100644 --- a/interface.go +++ b/interface.go @@ -7,6 +7,7 @@ import ( "io" 
"log/slog" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -14,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { } func (f *Interface) reloadFirewall(c *config.C) { - //TODO: need to trigger/detect if the certificate changed too - if c.HasChanged("firewall") == false { + cs := f.pki.getCertState() + curCert := cs.getCertificate(cert.Version2) + if curCert == nil { + curCert = cs.getCertificate(cert.Version1) + } + + // The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction. + // Check to see if that set has changed, and if so, rebuild the firewall. + certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks) + + if !c.HasChanged("firewall") && !certUnsafeChanged { f.l.Debug("No firewall config change detected") return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, cs, c) if err != nil { f.l.Error("Error while creating firewall during reload", "error", err) return diff --git a/interface_emit_test.go b/interface_emit_test.go new file mode 100644 index 00000000..b0a9d025 --- /dev/null +++ b/interface_emit_test.go @@ -0,0 +1,73 @@ +//go:build linux || darwin + +package nebula + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that +// landed before the first ticker fire used to 
read 0 for the cert gauges. +// emitStats now primes the gauges before entering the ticker loop. We assert +// the gauge is zero before the first call and non-zero after. +func Test_emitStats_primesGauges(t *testing.T) { + defer metrics.DefaultRegistry.UnregisterAll() + + l := test.NewLogger() + hostMap := newHostMap(l) + preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + hostMap.preferredRanges.Store(&preferredRanges) + + notAfter := time.Now().Add(time.Hour) + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, + v1Credential: nil, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &overlaytest.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to + // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, + // returns an error, and the emitter falls through to a no-op. + writers: []udp.Conn{&udp.StdConn{}}, + } + ifce.pki.cs.Store(cs) + + ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + + // Pre-cancel the context so emitStats returns after priming the gauges + // without ever reading from ticker.C. The one hour interval is just a + // belt-and-suspenders, the test does not expect the ticker to fire. 
+ ctx, cancel := context.WithCancel(context.Background()) + cancel() + ifce.emitStats(ctx, time.Hour) + + ttl := ttlGauge.Value() + assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") + assert.LessOrEqual(t, ttl, int64(3600)) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) +} diff --git a/interface_test.go b/interface_test.go index b0a9d025..1b912bbb 100644 --- a/interface_test.go +++ b/interface_test.go @@ -1,73 +1,120 @@ -//go:build linux || darwin - package nebula import ( - "context" "net/netip" "testing" - "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that -// landed before the first ticker fire used to read 0 for the cert gauges. -// emitStats now primes the gauges before entering the ticker loop. We assert -// the gauge is zero before the first call and non-zero after. -func Test_emitStats_primesGauges(t *testing.T) { - defer metrics.DefaultRegistry.UnregisterAll() - +// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall +// rebuilds the firewall when only the certificate's UnsafeNetworks have changed, +// even if the firewall section of the YAML has not. 
+func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) { l := test.NewLogger() - hostMap := newHostMap(l) - preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} - hostMap.preferredRanges.Store(&preferredRanges) - notAfter := time.Now().Add(time.Hour) - cs := &CertState{ - initiatingVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, - v1Credential: nil, + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + // dummyCert avoids dragging the real signing pipeline into a unit test. + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: initialUnsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + require.Equal(t, initialUnsafe, fw.unsafeNetworks) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, } - lh := newTestLighthouse() - ifce := &Interface{ - hostMap: hostMap, - inside: &overlaytest.NoopTun{}, - outside: &udp.NoopConn{}, - firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, - lightHouse: lh, - pki: &PKI{}, - handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), - l: l, - // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to - // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, - // returns an error, and the emitter falls through to a no-op. - writers: []udp.Conn{&udp.StdConn{}}, + // Swap the cert with a different UnsafeNetworks set. 
+ newUnsafe := []netip.Prefix{ + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("203.0.113.0/24"), } - ifce.pki.cs.Store(cs) + c2 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: newUnsafe, + } + pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2}) - ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) - require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + // Reload with the same YAML so HasChanged("firewall") reports false. + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + require.False(t, cfg.HasChanged("firewall")) - // Pre-cancel the context so emitStats returns after priming the gauges - // without ever reading from ticker.C. The one hour interval is just a - // belt-and-suspenders, the test does not expect the ticker to fire. - ctx, cancel := context.WithCancel(context.Background()) - cancel() - ifce.emitStats(ctx, time.Hour) + f.reloadFirewall(cfg) - ttl := ttlGauge.Value() - assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") - assert.LessOrEqual(t, ttl, int64(3600)) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) + assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced") + assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks) + assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5"))) +} + +// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when +// neither the firewall config nor the cert's UnsafeNetworks have changed. 
+func TestReloadFirewall_NoChange(t *testing.T) { + l := test.NewLogger() + + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: unsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, + } + + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + f.reloadFirewall(cfg) + + assert.Same(t, fw, f.firewall, "firewall should not have been replaced") } From 074a123a4bb51e6dba649f309c713eaab0af96c2 Mon Sep 17 00:00:00 2001 From: randomizedcoder <64496590+randomizedcoder@users.noreply.github.com> Date: Mon, 18 May 2026 10:23:10 -0700 Subject: [PATCH 23/27] Reject port numbers outside [0, 65535] in firewall rule parsing (#1724) --- firewall.go | 39 +++++++++++++++++++-------- firewall_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 11 deletions(-) diff --git a/firewall.go b/firewall.go index 904c71b2..eb120fa6 100644 --- a/firewall.go +++ b/firewall.go @@ -1055,7 +1055,6 @@ func (r *rule) sanity() error { } func parsePort(s string) (int32, int32, error) { - var err error const notAPort int32 = -2 if s == "any" { return firewall.PortAny, firewall.PortAny, nil @@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) { return firewall.PortFragment, firewall.PortFragment, nil } if !strings.Contains(s, `-`) { - rPort, err := strconv.Atoi(s) + rPort, err := parsePortValue("", s) if err != nil { - return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) + 
return notAPort, notAPort, err } - return int32(rPort), int32(rPort), nil + return rPort, rPort, nil } sPorts := strings.SplitN(s, `-`, 2) @@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) { return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) } - rStartPort, err := strconv.Atoi(sPorts[0]) + startPort, err := parsePortValue("beginning range ", sPorts[0]) if err != nil { - return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + return notAPort, notAPort, err } - rEndPort, err := strconv.Atoi(sPorts[1]) + endPort, err := parsePortValue("ending range ", sPorts[1]) if err != nil { - return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + return notAPort, notAPort, err } - startPort := int32(rStartPort) - endPort := int32(rEndPort) - if startPort == firewall.PortAny { endPort = firewall.PortAny } return startPort, endPort, nil } + +// parsePortValue accepts a base-10 decimal in [0, 65535] and returns it +// widened to int32. Using strconv.ParseUint with bitSize 16 rejects +// negative input, out-of-range input (>65535), and any non-decimal byte +// by construction, so the int32 widening that follows is provably safe +// and cannot collide with firewall.PortAny (0) or firewall.PortFragment +// (-1) via integer truncation. +// +// prefix is prepended to both error messages so callers can disambiguate +// the single-port path (prefix="") from the range bounds (prefix="beginning +// range " / "ending range "), preserving the historical error strings. 
+func parsePortValue(prefix, s string) (int32, error) { + n, err := strconv.ParseUint(s, 10, 16) + if err == nil { + return int32(n), nil + } + if errors.Is(err, strconv.ErrRange) { + return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s) + } + return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s) +} diff --git a/firewall_test.go b/firewall_test.go index 40b57477..9373f1fd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) { require.NoError(t, err) } +// Test_parsePort_invalid covers inputs that must error. The named bug is +// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny, +// silently turning a typo into a match-all-ports rule; the rest are +// representative syntax/range probes. +func Test_parsePort_invalid(t *testing.T) { + tests := []struct { + name string + input string + wantErrContains string + }{ + // Numeric overflow (the named bug + boundary). + {"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"}, + {"just above max real port", "65536", "out of range"}, + + // Negatives route through the range branch and hit the empty-half + // guard; included as defense in depth so a future refactor cannot + // accidentally reach the int32 cast. + {"negative", "-1", "could not be parsed"}, + + // Syntax probes. + {"NUL between digits", "4\x002", "was not a number"}, + {"hex notation", "0x10", "was not a number"}, + {"scientific notation", "1e3", "was not a number"}, + {"leading whitespace", " 42", "was not a number"}, + {"fullwidth digits", "４２", "was not a number"}, + + // Range branch.
+ {"range upper out of range", "1-65536", "ending range out of range"}, + {"range lower out of range", "65536-65537", "beginning range out of range"}, + {"range with negative upper", "1--1", "ending range was not a number"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, err := parsePort(tc.input) + require.Error(t, err, "input %q must error", tc.input) + require.ErrorContains(t, err, tc.wantErrContains) + }) + } +} + +// Test_parsePort_valid_boundaries locks in success cases at 0, 1, and 65535 +// so a future refactor cannot regress the boundaries. +func Test_parsePort_valid_boundaries(t *testing.T) { + tests := []struct { + name string + input string + wantStart int32 + wantEnd int32 + }{ + {"zero is PortAny", "0", 0, 0}, + {"min real port", "1", 1, 1}, + {"max real port", "65535", 65535, 65535}, + {"range zero to max forces end to zero", "0-65535", 0, 0}, + {"range max to max", "65535-65535", 65535, 65535}, + {"range one to max", "1-65535", 1, 65535}, + {"range with whitespace inside", " 1 - 2 ", 1, 2}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, e, err := parsePort(tc.input) + require.NoError(t, err) + assert.Equal(t, tc.wantStart, s, "start port") + assert.Equal(t, tc.wantEnd, e, "end port") + }) + } +} + func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition From 0c1ad9bb48e8e1c289d92299b75ee3e7ebeb5805 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 19 May 2026 08:35:04 -0500 Subject: [PATCH 24/27] Parallelize the tests a bit more (#1730) --- .github/workflows/gofmt.yml | 34 -------- .github/workflows/test.yml | 170 ++++++++++++++++++++---------------- Makefile | 43 ++++++++- 3 files changed, 136 insertions(+), 111 deletions(-) delete mode 100644 .github/workflows/gofmt.yml diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml deleted file mode 100644 index 4d57c7b2..00000000 --- a/.github/workflows/gofmt.yml +++ /dev/null @@ 
-1,34 +0,0 @@ -name: gofmt -on: - push: - branches: - - master - pull_request: - paths: - - '.github/workflows/gofmt.yml' - - '**.go' -jobs: - - gofmt: - name: Run gofmt - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Install goimports - run: | - go install golang.org/x/tools/cmd/goimports@latest - - - name: gofmt - run: | - if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] - then - find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d - exit 1 - fi diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 009c22a9..2abb3740 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,8 +13,8 @@ on: - 'go.sum' jobs: - test-linux: - name: Build all and test on ubuntu-linux + static: + name: Static checks runs-on: ubuntu-latest steps: @@ -25,8 +25,16 @@ jobs: go-version: '1.25' check-latest: true - - name: Build - run: make all + - name: Install goimports + run: go install golang.org/x/tools/cmd/goimports@latest + + - name: gofmt + run: | + if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] + then + find . 
-iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d + exit 1 + fi - name: Vet run: make vet @@ -36,66 +44,38 @@ jobs: with: version: v2.5 - - name: Test - run: make test - - - name: End 2 end - run: make e2evv - - - name: Build test mobile - run: make build-test-mobile - - - uses: actions/upload-artifact@v7 - with: - name: e2e packet flow linux-latest - path: e2e/mermaid/linux-latest - if-no-files-found: warn - - test-linux-boringcrypto: - name: Build and test on linux with boringcrypto - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Build - run: make bin-boringcrypto - - - name: Test - run: make test-boringcrypto - - - name: End 2 end - run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" - - test-linux-pkcs11: - name: Build and test on linux with pkcs11 - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Build - run: make bin-pkcs11 - - - name: Test - run: make test-pkcs11 - test: - name: Build and test on ${{ matrix.os }} + name: Test ${{ matrix.name }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - os: [windows-latest, macos-latest] + include: + - name: linux + os: ubuntu-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv + - name: linux-boringcrypto + os: ubuntu-latest + build-cmd: make bin-boringcrypto + test-cmd: make test-boringcrypto + e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" + - name: linux-pkcs11 + os: ubuntu-latest + build-cmd: make bin-pkcs11 + test-cmd: make test-pkcs11 + e2e-cmd: '' + - name: macos + os: macos-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv + 
- name: windows + os: windows-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv steps: - uses: actions/checkout@v6 @@ -105,28 +85,66 @@ jobs: go-version: '1.25' check-latest: true - - name: Build nebula - run: go build ./cmd/nebula + - name: Build + run: ${{ matrix.build-cmd }} - - name: Build nebula-cert - run: go build ./cmd/nebula-cert - - - name: Vet - run: make vet - - - name: golangci-lint - uses: golangci/golangci-lint-action@v9 - with: - version: v2.5 + - name: Cross-build darwin-amd64 + if: matrix.name == 'macos' + run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert - name: Test - run: make test + run: ${{ matrix.test-cmd }} - name: End 2 end - run: make e2evv + if: matrix.e2e-cmd != '' + run: ${{ matrix.e2e-cmd }} - uses: actions/upload-artifact@v7 + if: matrix.e2e-cmd != '' && always() with: - name: e2e packet flow ${{ matrix.os }} - path: e2e/mermaid/${{ matrix.os }} + name: e2e packet flow ${{ matrix.name }} + path: e2e/mermaid/ if-no-files-found: warn + + cross-build: + name: Cross-build ${{ matrix.name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - {name: linux-arm, make-target: all-cross-linux-arm} + - {name: linux-mips, make-target: all-cross-linux-mips} + - {name: linux-other, make-target: all-cross-linux-other} + - {name: freebsd, make-target: all-freebsd} + - {name: openbsd, make-target: all-openbsd} + - {name: netbsd, make-target: all-netbsd} + - {name: windows, make-target: all-cross-windows} + - {name: mobile, make-target: build-test-mobile} + steps: + + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + - name: Build ${{ matrix.name }} + run: make -j"$(nproc)" ${{ matrix.make-target }} + + finish: + name: CI status + if: always() + needs: [static, test, cross-build] + runs-on: ubuntu-latest + steps: + + - name: 
Fail if any upstream job failed + if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: | + echo "upstream results: ${{ toJSON(needs) }}" + exit 1 + + - name: All upstream jobs passed + run: echo "ok" diff --git a/Makefile b/Makefile index 0b199a5a..892c8eb0 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \ windows-amd64 \ windows-arm64 +# Cross-build shards used by .github/workflows/test.yml — same as ALL_* +# but with the arch that has a native CI runner removed, so the cross-build +# job is not duplicating coverage the native test jobs already give. +ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX)) + +# ALL_CROSS_LINUX further split into family sub-shards so each can run on +# its own CI runner in parallel. Union of the three must equal +# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family. +ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64 +ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat +ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64 + e2e: $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e @@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) +all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert) + +all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert) + +all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert) + +all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert) + +all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert + +all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe + +# CI cross-build shards. 
darwin-arm64 is covered by the native macos-latest +# job; windows-amd64 is covered by the native windows-latest job; both are +# omitted here to avoid building them a second time. darwin-amd64 stays in +# all-cross-darwin because intel mac is only a labeled/master-time native +# job, so PRs still need cross-build coverage for it. +all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert) + +all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert) + +all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert) + +all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert) + +all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert + +all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe + docker: docker/linux-$(shell go env GOARCH) release: $(ALL:%=build/nebula-%.tar.gz) @@ -236,5 +277,5 @@ smoke-vagrant/%: bin-docker build/%/nebula cd .github/workflows/smoke/ && ./smoke-vagrant.sh $* .FORCE: -.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% +.PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% .DEFAULT_GOAL := bin From 72bad1603a92373e1ae8da7b8fd95feb1efc9561 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 08:53:50 -0500 Subject: [PATCH 25/27] Bump github.com/gaissmai/bart from 0.26.1 to 0.27.1 (#1732) Bumps 
[github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.26.1 to 0.27.1. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.26.1...v0.27.1) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-version: 0.27.1 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ee51151f..bd1c0c57 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.26.1 + github.com/gaissmai/bart v0.27.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 diff --git a/go.sum b/go.sum index 5640bd46..8ab36d34 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo= -github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= +github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk= +github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log 
v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 873f94f4655098e3df133ba8b9eb2633bb594fc9 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 22 May 2026 10:19:06 -0500 Subject: [PATCH 26/27] Reduce relay log spam (#1733) --- handshake_manager.go | 3 +- relay_manager.go | 55 ++++++++++++++++-------- relay_manager_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 19 deletions(-) create mode 100644 relay_manager_test.go diff --git a/handshake_manager.go b/handshake_manager.go index 87257028..d03814da 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -83,6 +83,7 @@ type HandshakeHostInfo struct { initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? counter int64 // How many attempts have we made so far lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + lastRelays []netip.Addr // Relays we attempted to use during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo @@ -323,7 +324,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered ) } - hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0) + hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0) // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { diff --git a/relay_manager.go b/relay_manager.go index 25e65871..1fd98963 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "net/netip" + "slices" "sync/atomic" "github.com/slackhq/nebula/cert" @@ -57,14 +58,25 @@ func (rm *relayManager) GetUseRelays() bool { // For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits // one that may have been lost, or, once the relay is Established, forwards the in-progress // stage 0 
handshake packet for vpnIp through it. -func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) { +func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) { + hostinfo := hh.hostinfo if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 { + hh.lastRelays = nil return } - hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + relays := hostinfo.remotes.relays + listLevel := slog.LevelDebug + prior := hh.lastRelays + if !slices.Equal(relays, prior) { + listLevel = slog.LevelInfo + hh.lastRelays = slices.Clone(relays) + } + hl := hostinfo.logger(rm.l) + hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays) + // Send a RelayRequest to all known Relay IP's - for _, relay := range hostinfo.remotes.relays { + for _, relay := range relays { // Don't relay through the host I'm trying to connect to if relay == vpnIp { continue @@ -75,12 +87,19 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho continue } + // Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that. 
+ level := slog.LevelInfo + if slices.Contains(prior, relay) { + level = slog.LevelDebug + } + relayHostInfo := rm.hostmap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String()) + hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String()) f.Handshake(relay) continue } + // Check the relay HostInfo to see if we already established a relay through existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) if !ok { @@ -88,7 +107,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho if relayHostInfo.remote.IsValid() { idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) + hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) } m := NebulaControl{ @@ -99,12 +118,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !f.myVpnAddrs[0].Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -116,16 +135,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) 
default: - hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + hl.Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { - hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + hl.Error("Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.Info("send CreateRelayRequest", + rm.l.Log(context.Background(), level, "send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, "initiatorRelayIndex", idx, @@ -138,14 +157,14 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch existingRelay.State { case Established: - hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) + hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String()) f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: - hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String()) + hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String()) // Re-send the CreateRelay request, in case the previous one was lost. 
m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, @@ -155,12 +174,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !f.myVpnAddrs[0].Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -172,16 +191,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: - hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + hl.Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { - hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + hl.Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.Info("send CreateRelayRequest", + rm.l.Log(context.Background(), level, "send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, "initiatorRelayIndex", existingRelay.LocalIndex, @@ -192,7 +211,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. 
fallthrough default: - hostinfo.logger(rm.l).Error("Relay unexpected state", + hl.Error("Relay unexpected state", "vpnIp", vpnIp, "state", existingRelay.State, "relay", relay, diff --git a/relay_manager_test.go b/relay_manager_test.go new file mode 100644 index 00000000..8da38940 --- /dev/null +++ b/relay_manager_test.go @@ -0,0 +1,97 @@ +package nebula + +import ( + "bytes" + "log/slog" + "net/netip" + "testing" + + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// TestStartRelaysLogDedupe verifies that repeated attempts with the same relay set drop the log +// chatter to Debug, mirroring how the normal handshake retry loop quiets down once it's already +// announced its targets. +func TestStartRelaysLogDedupe(t *testing.T) { + vpnIp := netip.MustParseAddr("100.64.99.4") + otherRelay := netip.MustParseAddr("100.64.99.5") + + newHH := func() *HandshakeHostInfo { + // Use the target's own vpnIp as the "relay" so the loop body skips it without + // touching any sender-side state. That isolates the test to the level-selection + // behavior of the top-level "Attempt to relay through hosts" log. + hostinfo := &HostInfo{ + vpnAddrs: []netip.Addr{vpnIp}, + localIndexId: 1, + remotes: NewRemoteList([]netip.Addr{vpnIp}, nil), + } + hostinfo.remotes.relays = []netip.Addr{vpnIp} + return &HandshakeHostInfo{hostinfo: hostinfo} + } + + // Park any extra relay addresses we'll introduce mid-test in myVpnAddrsTable so the loop + // body always skips before touching f.Handshake (which would need a real handshakeManager). 
+ addrTable := new(bart.Lite) + addrTable.Insert(netip.PrefixFrom(otherRelay, otherRelay.BitLen())) + f := &Interface{myVpnAddrsTable: addrTable} + + newRM := func(buf *bytes.Buffer) *relayManager { + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + rm := &relayManager{l: l, hostmap: newHostMap(l)} + rm.useRelays.Store(true) + return rm + } + + const msg = `msg="Attempt to relay through hosts"` + + t.Run("first attempt logs at Info", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, []netip.Addr{vpnIp}, hh.lastRelays, "lastRelays should record the relay set we just attempted") + assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info level on first attempt") + }) + + t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + first := append([]netip.Addr(nil), hh.lastRelays...) + buf.Reset() + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, first, hh.lastRelays) + assert.Contains(t, buf.String(), "level=DEBUG "+msg, "expected Debug level on identical retry") + assert.NotContains(t, buf.String(), "level=INFO "+msg, "Info should not fire on identical retry") + }) + + t.Run("changed relay list bumps back to Info", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + buf.Reset() + + // The lighthouse handed us a new set this round. 
+ hh.hostinfo.remotes.relays = []netip.Addr{vpnIp, otherRelay} + + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, []netip.Addr{vpnIp, otherRelay}, hh.lastRelays) + assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info when the relay list changes") + }) + + t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + rm.useRelays.Store(false) + hh := newHH() + hh.lastRelays = []netip.Addr{vpnIp} + + rm.StartRelays(f, vpnIp, hh, nil) + assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared") + assert.NotContains(t, buf.String(), msg, "should not log when we shortcut out") + }) +} From 3a95495c6355dffeb83607eeedcf5a96eb5d484f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 22 May 2026 10:19:53 -0500 Subject: [PATCH 27/27] Fix duplicate log fields which slog duplicates (#1734) --- handshake_manager.go | 3 --- inside.go | 1 - outside.go | 5 ++--- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index d03814da..e04886b5 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -218,7 +218,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered fields := []any{ "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "initiatorIndex", hh.hostinfo.localIndexId, - "remoteIndex", hh.hostinfo.remoteIndexId, "durationNs", time.Since(hh.startTime).Nanoseconds(), } // hh.machine can be nil here if buildStage0Packet never succeeded @@ -466,7 +465,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. 
hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", - "remoteIndex", hostinfo.remoteIndexId, "collision", existingRemoteIndex.vpnAddrs, ) } @@ -489,7 +487,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", - "remoteIndex", hostinfo.remoteIndexId, "collision", existingRemoteIndex.vpnAddrs, ) } diff --git a/inside.go b/inside.go index 68cb38ec..27a6f758 100644 --- a/inside.go +++ b/inside.go @@ -391,7 +391,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType "error", err, "udpAddr", remote, "counter", c, - "attemptedCounter", c, ) return } diff --git a/outside.go b/outside.go index 17013ed3..4c0c935e 100644 --- a/outside.go +++ b/outside.go @@ -194,8 +194,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, + "relayRemoteIndex", h.RemoteIndex, ) return } @@ -218,8 +217,8 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, if err != nil { hostinfo.logger(f.l).Info("Failed to find target host info by ip", "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, ) return }