From 33c2d7277c3a6f43b3ef63dc87f7e4754722ca29 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 1 May 2026 13:21:38 -0500 Subject: [PATCH 01/31] Reduce HandshakeManager complexity a little bit (#1701) --- handshake_manager.go | 144 +------------------------------------ main.go | 10 +-- relay_manager.go | 166 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 163 insertions(+), 157 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index 9fc69ff4..87257028 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -23,7 +23,6 @@ const ( DefaultHandshakeTryInterval = time.Millisecond * 100 DefaultHandshakeRetries = 10 DefaultHandshakeTriggerBuffer = 64 - DefaultUseRelays = true // maxCachedPackets is how many unsent packets we'll buffer per pending // handshake before dropping further ones. @@ -43,7 +42,6 @@ var ( tryInterval: DefaultHandshakeTryInterval, retries: DefaultHandshakeRetries, triggerBuffer: DefaultHandshakeTriggerBuffer, - useRelays: DefaultUseRelays, } ) @@ -51,7 +49,6 @@ type HandshakeConfig struct { tryInterval time.Duration retries int64 triggerBuffer int - useRelays bool messageMetrics *MessageMetrics } @@ -326,146 +323,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered ) } - if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 { - hostinfo.logger(hm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) - // Send a RelayRequest to all known Relay IP's - for _, relay := range hostinfo.remotes.relays { - // Don't relay through the host I'm trying to connect to - if relay == vpnIp { - continue - } - - // Don't relay to myself - if hm.f.myVpnAddrsTable.Contains(relay) { - continue - } - - relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) - if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(hm.l).Info("Establish tunnel to relay target", "relay", relay.String()) - hm.f.Handshake(relay) - continue - } - // Check the relay HostInfo to see if we already established a relay through - existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) - if !ok { - // No relays exist or requested yet. - if relayHostInfo.remote.IsValid() { - idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) - if err != nil { - hostinfo.logger(hm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) - } - - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: idx, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) - } else { - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.Info("send CreateRelayRequest", - "relayFrom", hm.f.myVpnAddrs[0], - "relayTo", vpnIp, - "initiatorRelayIndex", idx, - "relay", relay, - ) - } - } - continue - } - - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).Info("Send handshake via relay", "relay", relay.String()) - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[handshakePacketStage0], make([]byte, 12), make([]byte, mtu), false) - case Disestablished: - // Mark this relay as 'requested' - relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) - fallthrough - case Requested: - hostinfo.logger(hm.l).Info("Re-send CreateRelay request", "relay", relay.String()) - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - } - - switch relayHostInfo.GetCert().Certificate.Version() { - case cert.Version1: - if !hm.f.myVpnAddrs[0].Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") - continue - } - - if !vpnIp.Is4() { - hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") - continue - } - - b := hm.f.myVpnAddrs[0].As4() - m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) - b = vpnIp.As4() - m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) - case cert.Version2: - m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) - m.RelayToAddr = netAddrToProtoAddr(vpnIp) - default: - hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") - continue - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l).Error("Failed to marshal Control message to create relay", "error", err) - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.Info("send CreateRelayRequest", - "relayFrom", hm.f.myVpnAddrs[0], - "relayTo", vpnIp, - "initiatorRelayIndex", existingRelay.LocalIndex, - "relay", relay, - ) - } - case PeerRequested: - // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. - fallthrough - default: - hostinfo.logger(hm.l).Error("Relay unexpected state", - "vpnIp", vpnIp, - "state", existingRelay.State, - "relay", relay, - ) - - } - } - } + hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0) // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { diff --git a/main.go b/main.go index eef13c97..d5e5dcc8 100644 --- a/main.go +++ b/main.go @@ -184,14 +184,10 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev messageMetrics = newMessageMetricsOnlyRecvError() } - useRelays := c.GetBool("relay.use_relays", DefaultUseRelays) && !c.GetBool("relay.am_relay", false) - handshakeConfig := HandshakeConfig{ - tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), - retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), - triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), - useRelays: useRelays, - + tryInterval: c.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval), + retries: int64(c.GetInt("handshakes.retries", DefaultHandshakeRetries)), + triggerBuffer: c.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer), messageMetrics: messageMetrics, } diff --git a/relay_manager.go b/relay_manager.go index 919bb2b6..25e65871 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -15,9 +15,10 @@ import ( ) type relayManager struct { - l *slog.Logger - hostmap *HostMap - amRelay atomic.Bool + l *slog.Logger + hostmap *HostMap + amRelay atomic.Bool + useRelays atomic.Bool } func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *config.C) *relayManager { @@ -36,8 +37,10 @@ func NewRelayManager(ctx context.Context, l *slog.Logger, hostmap *HostMap, c *c } func (rm *relayManager) reload(c *config.C, initial bool) error { - if initial || c.HasChanged("relay.am_relay") { - rm.setAmRelay(c.GetBool("relay.am_relay", false)) + if initial || c.HasChanged("relay.am_relay") || c.HasChanged("relay.use_relays") { + amRelay := c.GetBool("relay.am_relay", false) + rm.amRelay.Store(amRelay) + rm.useRelays.Store(c.GetBool("relay.use_relays", true) && !amRelay) } return nil } @@ -46,8 +49,157 @@ func (rm *relayManager) GetAmRelay() bool { return rm.amRelay.Load() } -func (rm *relayManager) setAmRelay(v bool) { - rm.amRelay.Store(v) +func (rm *relayManager) GetUseRelays() bool { + return rm.useRelays.Load() +} + +// StartRelays drives the relay-establishment side of an outbound handshake attempt. +// For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits +// one that may have been lost, or, once the relay is Established, forwards the in-progress +// stage 0 handshake packet for vpnIp through it. +func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) { + if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 { + return + } + + hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + // Send a RelayRequest to all known Relay IP's + for _, relay := range hostinfo.remotes.relays { + // Don't relay through the host I'm trying to connect to + if relay == vpnIp { + continue + } + + // Don't relay to myself + if f.myVpnAddrsTable.Contains(relay) { + continue + } + + relayHostInfo := rm.hostmap.QueryVpnAddr(relay) + if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { + hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String()) + f.Handshake(relay) + continue + } + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { + // No relays exist or requested yet. + if relayHostInfo.remote.IsValid() { + idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested) + if err != nil { + hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) + } + + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: idx, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", idx, + "relay", relay, + ) + } + } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) + f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String()) + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !f.myVpnAddrs[0].Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.Info("send CreateRelayRequest", + "relayFrom", f.myVpnAddrs[0], + "relayTo", vpnIp, + "initiatorRelayIndex", existingRelay.LocalIndex, + "relay", relay, + ) + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(rm.l).Error("Relay unexpected state", + "vpnIp", vpnIp, + "state", existingRelay.State, + "relay", relay, + ) + + } + } } // AddRelay finds an available relay index on the hostmap, and associates the relay info with it. From b7e9939e921aab000699115330fb31f33c6449b9 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 4 May 2026 10:12:58 -0500 Subject: [PATCH 02/31] More stable e2e test harness, better for benchmarking (#1702) --- control_tester.go | 72 ++-------- e2e/handshake_manager_test.go | 24 ++-- e2e/handshakes_test.go | 76 ++++++----- e2e/helpers_test.go | 59 ++++++++- e2e/router/router.go | 242 ++++++++++++++++++++++++++-------- e2e/tunnels_test.go | 4 +- overlay/tun_tester.go | 54 +++++++- udp/udp_tester.go | 67 +++++++--- 8 files changed, 418 insertions(+), 180 deletions(-) diff --git a/control_tester.go b/control_tester.go index f927140b..728ac649 100644 --- a/control_tester.go +++ b/control_tester.go @@ -5,8 +5,6 @@ package nebula import ( "net/netip" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" @@ -22,7 +20,9 @@ func (c *Control) WaitForType(msgType header.MessageType, subType header.Message panic(err) } pipeTo.InjectUDPPacket(p) - if h.Type == msgType && h.Subtype == subType { + match := h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -38,7 +38,9 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, panic(err) } pipeTo.InjectUDPPacket(p) - if h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType { + match := h.RemoteIndex == toIndex && h.Type == msgType && h.Subtype == subType + p.Release() + if match { return } } @@ -90,65 +92,15 @@ func (c *Control) GetTunTxChan() <-chan []byte { return c.f.inside.(*overlay.TestTun).TxPackets } -// InjectUDPPacket will inject a packet into the udp side of nebula +// InjectUDPPacket injects a packet into the udp side. We copy internally so the caller keeps ownership of p. +// The copy comes from the freelist so steady-state alloc is zero. func (c *Control) InjectUDPPacket(p *udp.Packet) { - c.f.outside.(*udp.TesterConn).Send(p) + c.f.outside.(*udp.TesterConn).Send(p.Copy()) } -// InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { - serialize := make([]gopacket.SerializableLayer, 0) - var netLayer gopacket.NetworkLayer - if toAddr.Is6() { - if !fromAddr.Is6() { - panic("Cant send ipv6 to ipv4") - } - ip := &layers.IPv6{ - Version: 6, - NextHeader: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } else { - if !fromAddr.Is4() { - panic("Cant send ipv4 to ipv6") - } - - ip := &layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: fromAddr.Unmap().AsSlice(), - DstIP: toAddr.Unmap().AsSlice(), - } - serialize = append(serialize, ip) - netLayer = ip - } - - udp := layers.UDP{ - SrcPort: layers.UDPPort(fromPort), - DstPort: layers.UDPPort(toPort), - } - err := udp.SetNetworkLayerForChecksum(netLayer) - if err != nil { - panic(err) - } - - buffer := gopacket.NewSerializeBuffer() - opt := gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } - - serialize = append(serialize, &udp, gopacket.Payload(data)) - err = gopacket.SerializeLayers(buffer, opt, serialize...) - if err != nil { - panic(err) - } - - c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) +// InjectTunPacket pushes an IP packet onto the tun interface. +func (c *Control) InjectTunPacket(packet []byte) { + c.f.inside.(*overlay.TestTun).Send(packet) } func (c *Control) GetVpnAddrs() []netip.Addr { diff --git a/e2e/handshake_manager_test.go b/e2e/handshake_manager_test.go index 1c6ebacc..b06564d1 100644 --- a/e2e/handshake_manager_test.go +++ b/e2e/handshake_manager_test.go @@ -47,7 +47,7 @@ func TestHandshakeRetransmitDuplicate(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Grab my msg1") msg1 := myControl.GetFromUDP(true) @@ -97,7 +97,7 @@ func TestHandshakeTruncatedPacketRecovery(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Get msg1 and deliver to responder") msg1 := myControl.GetFromUDP(true) @@ -146,7 +146,7 @@ func TestHandshakeOrphanedMsg2Dropped(t *testing.T) { defer r.RenderFlow() t.Log("Complete a normal handshake") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) r.RouteForAllUntilTxTun(theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) @@ -248,7 +248,7 @@ func TestHandshakeLateResponse(t *testing.T) { theirControl.Start() t.Log("Trigger handshake from me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) t.Log("Grab msg1 but don't deliver") msg1 := myControl.GetFromUDP(true) @@ -292,7 +292,7 @@ func TestHandshakeSelfConnectionRejected(t *testing.T) { myControl.Start() t.Log("Trigger handshake from me") - myControl.InjectTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(netip.MustParseAddr("10.128.0.2"), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) msg1 := myControl.GetFromUDP(true) t.Log("Drain any handshake retransmits before injecting") @@ -375,7 +375,7 @@ func TestHandshakeRemoteAllowList(t *testing.T) { defer r.RenderFlow() t.Log("Trigger handshake from them") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi"))) msg1 := theirControl.GetFromUDP(true) t.Log("Rewrite the source to a blocked IP and inject") @@ -426,7 +426,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { defer r.RenderFlow() t.Log("Complete a normal handshake via the router") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi"))) r.RouteForAllUntilTxTun(theirControl) assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) @@ -437,7 +437,7 @@ func TestHandshakeAlreadySeenPreferredRemote(t *testing.T) { originalRemote := hi.CurrentRemote t.Log("Re-trigger traffic to cause a new handshake attempt (ErrAlreadySeen)") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("roam"))) r.RouteForAllUntilTxTun(theirControl) t.Log("Verify tunnel still works") @@ -475,8 +475,8 @@ func TestHandshakeWrongResponderPacketStore(t *testing.T) { evilControl.Start() t.Log("Send multiple packets to them (cached during handshake)") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1")) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet1"))) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("packet2"))) t.Log("Route until evil tunnel is closed") h := &header.H{} @@ -540,7 +540,7 @@ func TestHandshakeRelayComplete(t *testing.T) { theirControl.Start() t.Log("Trigger handshake via relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi via relay"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi via relay"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -568,7 +568,7 @@ func TestHandshakeRelayComplete(t *testing.T) { } // NOTE: Relay V1 cert + IPv6 rejection is not tested here because -// InjectTunUDPPacket from a V4 node to a V6 address panics in the test +// BuildTunUDPPacket from a V4 node to a V6 address panics in the test // framework. The check is in handshake_manager.go handleOutbound relay // logic (lines ~304-313): if the relay host has a V1 cert and either // address is IPv6, the relay is skipped. diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 43fa72f2..d0b9543c 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -16,6 +16,7 @@ import ( "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/overlay" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,11 +40,22 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + // Pre-build the IP packet bytes once so the bench measures the data plane, + // not gopacket SerializeLayers overhead. + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + // EnableFanIn switches the router to a 0-alloc routing path. Required + // for hot-path benchmarks; would conflict with GetFromUDP-using tests. + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + // Release the TUN-side bytes back to the harness freelist; the bench + // just confirms a packet arrived, the contents aren't inspected. + overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -71,11 +83,15 @@ func BenchmarkHotPathRelay(b *testing.B) { theirControl.Start() assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) + + prebuilt := BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + r.EnableFanIn() + b.ResetTimer() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - _ = r.RouteForAllUntilTxTun(theirControl) + myControl.InjectTunPacket(prebuilt) + overlay.ReleaseTunBuf(r.RouteForAllUntilTxTun(theirControl)) } myControl.Stop() @@ -97,7 +113,7 @@ func TestGoodHandshake(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -191,7 +207,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -273,7 +289,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -352,8 +368,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -430,7 +446,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -441,7 +457,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again"))) p = r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -480,7 +496,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -492,7 +508,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) @@ -535,7 +551,7 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -565,7 +581,7 @@ func TestRelaysDontCareAboutIps(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -595,14 +611,14 @@ func TestReestablishRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Ensure packet traversal from them to me via the relay") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -617,7 +633,7 @@ func TestReestablishRelays(t *testing.T) { for curIndexes >= start { curIndexes = len(myControl.GetHostmap().Indexes) r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { return router.RouteAndExit @@ -634,7 +650,7 @@ func TestReestablishRelays(t *testing.T) { myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p = r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -669,7 +685,7 @@ func TestReestablishRelays(t *testing.T) { t.Log("Assert the tunnel works the other way, too") for { t.Log("RouteForAllUntilTxTun") - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) r.Log("Assert the tunnel works") @@ -739,8 +755,8 @@ func TestStage1RaceRelays(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -787,8 +803,8 @@ func TestStage1RaceRelays2(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -852,7 +868,7 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -957,7 +973,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -1259,8 +1275,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them"))) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -1476,7 +1492,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(spookyDest, 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -1504,7 +1520,7 @@ func TestGoodHandshakeUnsafeDest(t *testing.T) { assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), spookyDest, 80, 80) //reply - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnIpNet[0].Addr(), 80, spookyDest, 80, []byte("Hi from the spookyman"))) //wait for reply theirControl.WaitForType(1, 0, myControl) theirCachedPacket := myControl.GetFromTun(true) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 381ae897..b555fbc4 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -294,12 +294,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) + controlB.InjectTunPacket(BuildTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) + controlA.InjectTunPacket(BuildTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A"))) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } @@ -408,3 +408,58 @@ func testLogLevelName() string { } return "info" } + +// BuildTunUDPPacket assembles an IP+UDP packet suitable for Control.InjectTunPacket. +// Using UDP here because it's a simpler protocol. +func BuildTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) []byte { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(fromPort), + DstPort: layers.UDPPort(toPort), + } + if err := udp.SetNetworkLayerForChecksum(netLayer); err != nil { + panic(err) + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + serialize = append(serialize, &udp, gopacket.Payload(data)) + if err := gopacket.SerializeLayers(buffer, opt, serialize...); err != nil { + panic(err) + } + + return buffer.Bytes() +} diff --git a/e2e/router/router.go b/e2e/router/router.go index c8264ab7..72012073 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -13,6 +13,7 @@ import ( "regexp" "sort" "sync" + "sync/atomic" "testing" "time" @@ -24,6 +25,19 @@ import ( "golang.org/x/exp/maps" ) +// outNatKey is the (from, to) pair used by outNat. Comparable struct, so it works as a map key without the +// allocation cost of a string-concat key. +type outNatKey struct { + from, to netip.AddrPort +} + +// fannedPacket pairs a UDP TX packet with its source control so the router can route it after popping from +// the fan-in channel. +type fannedPacket struct { + from *nebula.Control + pkt *udp.Packet +} + type R struct { // Simple map of the ip:port registered on a control to the control // Basically a router, right? @@ -34,12 +48,28 @@ type R struct { // A last used map, if an inbound packet hit the inNat map then // all return packets should use the same last used inbound address for the outbound sender - // map[from address + ":" + to address] => ip:port to rewrite in the udp packet to receiver - outNat map[string]netip.AddrPort + outNat map[outNatKey]netip.AddrPort // A map of vpn ip to the nebula control it belongs to vpnControls map[netip.Addr]*nebula.Control + // Cached select infrastructure for RouteForAllUntilTxTun. + // The controls map is immutable after NewR so the cases are good for the test lifetime. + // We only rebuild if a different receiver is asked. + selRecvCtl *nebula.Control + selCases []reflect.SelectCase + selCtls []*nebula.Control + + // Optional fan-in mode for hot-path benchmarks: one forwarder goroutine per control drains UDP TX into udpFanIn, + // so RouteForAllUntilTxTun can do a fixed 2-way native select instead of paying reflect.Select per call. + // Off by default (would otherwise interleave with tests that use GetFromUDP directly on the same control). + // Enabled by EnableFanIn. + udpFanIn chan fannedPacket + stopFanIn chan struct{} + fanInWG sync.WaitGroup + fanInMu sync.Mutex + fanInOn atomic.Bool + ignoreFlows []ignoreFlow flow []flowEntry @@ -119,7 +149,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { controls: make(map[netip.AddrPort]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control), inNat: make(map[netip.AddrPort]*nebula.Control), - outNat: make(map[string]netip.AddrPort), + outNat: make(map[outNatKey]netip.AddrPort), flow: []flowEntry{}, ignoreFlows: []ignoreFlow{}, fn: filepath.Join("mermaid", fmt.Sprintf("%s.md", t.Name())), @@ -153,8 +183,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { case <-ctx.Done(): return case <-clockSource.C: + r.Lock() r.renderHostmaps("clock tick") r.renderFlow() + r.Unlock() } } }() @@ -180,15 +212,21 @@ func (r *R) AddRoute(ip netip.Addr, port uint16, c *nebula.Control) { // RenderFlow renders the packet flow seen up until now and stops further automatic renders from happening. func (r *R) RenderFlow() { r.cancelRender() + r.Lock() + defer r.Unlock() r.renderFlow() } // CancelFlowLogs stops flow logs from being tracked and destroys any logs already collected func (r *R) CancelFlowLogs() { r.cancelRender() + r.Lock() r.flow = nil + r.Unlock() } +// renderFlow writes the flow log to disk. Caller must hold r.Lock. renderFlow reads r.flow / r.additionalGraphs and +// the *packet pointers stashed inside, all of which are mutated under the same lock by routing paths. func (r *R) renderFlow() { if r.flow == nil { return @@ -434,68 +472,157 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) - c.InjectUDPPacket(p) + c.InjectUDPPacket(p) // copies internally; original is ours to release fp.WasReceived() r.Unlock() + p.Release() } } } -// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on receivers tun -// If the router doesn't have the nebula controller for that address, we panic +// RouteForAllUntilTxTun will route for everyone and return when a packet is seen on the receiver's tun. +// If a control's UDP TX address can't be matched to a registered control, we panic. +// +// For allocation-sensitive callers (hot-path benchmarks, in particular relay +// benches with 3+ controls), call EnableFanIn() first. func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { + if r.fanInOn.Load() { + return r.routeFanIn(receiver) + } + return r.routeReflect(receiver) +} + +// routeFanIn is the alloc-free path used when EnableFanIn is in effect. +func (r *R) routeFanIn(receiver *nebula.Control) []byte { + tunTx := receiver.GetTunTxChan() + for { + select { + case p := <-tunTx: + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(receiver, receiver, &np, true) + } + r.Unlock() + return p + case fp := <-r.udpFanIn: + r.routeUDP(fp.from, fp.pkt) + } + } +} + +// routeReflect is the default reflect.Select-based path. Pays the boxing allocation per call but doesn't interfere +// with tests that pull packets directly from controls' UDP TX channels via GetFromUDP. +func (r *R) routeReflect(receiver *nebula.Control) []byte { + sc, cm := r.selectCasesFor(receiver) + for { + x, rx, _ := reflect.Select(sc) + if x == 0 { + p := rx.Interface().([]byte) + r.Lock() + if r.flow != nil { + np := udp.Packet{Data: make([]byte, len(p))} + copy(np.Data, p) + r.unlockedInjectFlow(cm[x], cm[x], &np, true) + } + r.Unlock() + return p + } + r.routeUDP(cm[x], rx.Interface().(*udp.Packet)) + } +} + +// EnableFanIn switches RouteForAllUntilTxTun to the alloc-free fan-in path. +// One forwarder goroutine per registered control drains UDP TX into a shared channel that RouteForAllUntilTxTun selects +// on alongside the receiver's TUN TX channel. +func (r *R) EnableFanIn() { + r.fanInMu.Lock() + defer r.fanInMu.Unlock() + if r.fanInOn.Load() { + return + } + r.udpFanIn = make(chan fannedPacket, 32) + r.stopFanIn = make(chan struct{}) + for _, c := range r.controls { + r.startFanInWorker(c) + } + r.fanInOn.Store(true) + r.t.Cleanup(r.stopFanInWorkers) +} + +// startFanInWorker spawns a goroutine that drains c's UDP TX into r.udpFanIn. +func (r *R) startFanInWorker(c *nebula.Control) { + r.fanInWG.Add(1) + udpTx := c.GetUDPTxChan() + go func() { + defer r.fanInWG.Done() + for { + select { + case <-r.stopFanIn: + return + case p := <-udpTx: + select { + case <-r.stopFanIn: + p.Release() + return + case r.udpFanIn <- fannedPacket{from: c, pkt: p}: + } + } + } + }() +} + +// stopFanInWorkers signals the fan-in goroutines to exit and waits for them. +func (r *R) stopFanInWorkers() { + r.fanInMu.Lock() + wasOn := r.fanInOn.Swap(false) + r.fanInMu.Unlock() + if !wasOn { + return + } + close(r.stopFanIn) + r.fanInWG.Wait() +} + +// routeUDP forwards a UDP TX packet from the named source control to the destination control derived from p.To, +// releasing the source packet after InjectUDPPacket has copied its bytes into a fresh pool slot. +func (r *R) routeUDP(from *nebula.Control, p *udp.Packet) { + r.Lock() + defer r.Unlock() + a := from.GetUDPAddr() + c := r.getControl(a, p.To, p) + if c == nil { + panic(fmt.Sprintf("No control for udp tx %s", p.To)) + } + fp := r.unlockedInjectFlow(from, c, p, false) + c.InjectUDPPacket(p) // copies internally; original is ours to release + fp.WasReceived() + p.Release() +} + +// selectCasesFor returns the SelectCase array used by routeReflect: one slot for the receiver's TUN TX channel followed +// by one per control's UDP TX channel. Cached for the test lifetime, only rebuilt if the receiver changes. +func (r *R) selectCasesFor(receiver *nebula.Control) ([]reflect.SelectCase, []*nebula.Control) { + r.Lock() + defer r.Unlock() + if r.selRecvCtl == receiver && r.selCases != nil { + return r.selCases, r.selCtls + } sc := make([]reflect.SelectCase, len(r.controls)+1) cm := make([]*nebula.Control, len(r.controls)+1) - - i := 0 - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(receiver.GetTunTxChan()), - Send: reflect.Value{}, - } - cm[i] = receiver - - i++ + sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(receiver.GetTunTxChan())} + cm[0] = receiver + i := 1 for _, c := range r.controls { - sc[i] = reflect.SelectCase{ - Dir: reflect.SelectRecv, - Chan: reflect.ValueOf(c.GetUDPTxChan()), - Send: reflect.Value{}, - } - + sc[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(c.GetUDPTxChan())} cm[i] = c i++ } - - for { - x, rx, _ := reflect.Select(sc) - r.Lock() - - if x == 0 { - // we are the tun tx, we can exit - p := rx.Interface().([]byte) - np := udp.Packet{Data: make([]byte, len(p))} - copy(np.Data, p) - - r.unlockedInjectFlow(cm[x], cm[x], &np, true) - r.Unlock() - return p - - } else { - // we are a udp tx, route and continue - p := rx.Interface().(*udp.Packet) - a := cm[x].GetUDPAddr() - c := r.getControl(a, p.To, p) - if c == nil { - r.Unlock() - panic(fmt.Sprintf("No control for udp tx %s", p.To)) - } - fp := r.unlockedInjectFlow(cm[x], c, p, false) - c.InjectUDPPacket(p) - fp.WasReceived() - } - r.Unlock() - } + r.selRecvCtl = receiver + r.selCases = sc + r.selCtls = cm + return sc, cm } // RouteExitFunc will call the whatDo func with each udp packet from sender. @@ -522,6 +649,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -529,6 +657,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -541,6 +670,7 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) { } r.Unlock() + p.Release() } } @@ -641,6 +771,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { switch e { case ExitNow: r.Unlock() + p.Release() return case RouteAndExit: @@ -648,6 +779,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { receiver.InjectUDPPacket(p) fp.WasReceived() r.Unlock() + p.Release() return case KeepRouting: @@ -659,6 +791,7 @@ func (r *R) RouteForAllExitFunc(whatDo ExitFunc) { panic(fmt.Sprintf("Unknown exitFunc return: %v", e)) } r.Unlock() + p.Release() } } @@ -702,19 +835,20 @@ func (r *R) FlushAll() { } receiver.InjectUDPPacket(p) r.Unlock() + p.Release() } } // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change // This is an internal router function, the caller must hold the lock func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.Control { - if newAddr, ok := r.outNat[fromAddr.String()+":"+toAddr.String()]; ok { + if newAddr, ok := r.outNat[outNatKey{from: fromAddr, to: toAddr}]; ok { p.From = newAddr } c, ok := r.inNat[toAddr] if ok { - r.outNat[c.GetUDPAddr().String()+":"+fromAddr.String()] = toAddr + r.outNat[outNatKey{from: c.GetUDPAddr(), to: fromAddr}] = toAddr return c } diff --git a/e2e/tunnels_test.go b/e2e/tunnels_test.go index 63c655f3..697f25af 100644 --- a/e2e/tunnels_test.go +++ b/e2e/tunnels_test.go @@ -355,14 +355,14 @@ func TestCrossStackRelaysWork(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me")) + myControl.InjectTunPacket(BuildTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80) t.Log("reply?") - theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them")) + theirControl.InjectTunPacket(BuildTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80) diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index b2c2a0ea..8acd83f0 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -15,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/routing" + "github.com/slackhq/nebula/udp" ) type TestTun struct { @@ -54,9 +55,12 @@ func newTunFromFd(_ *config.C, _ *slog.Logger, _ int, _ []netip.Prefix) (*TestTu return nil, fmt.Errorf("newTunFromFd not supported") } -// Send will place a byte array onto the receive queue for nebula to consume +// Send will place a byte array onto the receive queue for nebula to consume. // These are unencrypted ip layer frames destined for another nebula node. -// packets should exit the udp side, capture them with udpConn.Get +// packets should exit the udp side, capture them with udpConn.Get. +// +// Send copies the input via the freelist, so the caller is free to mutate +// or reuse it after the call returns. func (t *TestTun) Send(packet []byte) { if t.closed.Load() { return @@ -65,7 +69,9 @@ func (t *TestTun) Send(packet []byte) { if t.l.Enabled(context.Background(), slog.LevelDebug) { t.l.Debug("Tun receiving injected packet", "dataLen", len(packet)) } - t.rxPackets <- packet + buf := acquireTunBuf(len(packet)) + copy(buf, packet) + t.rxPackets <- buf } // Get will pull an unencrypted ip layer frame from the transmit queue @@ -110,12 +116,44 @@ func (t *TestTun) Write(b []byte) (n int, err error) { return 0, io.ErrClosedPipe } - packet := make([]byte, len(b), len(b)) + packet := acquireTunBuf(len(b)) copy(packet, b) t.TxPackets <- packet return len(b), nil } +// ReleaseTunBuf returns a slice from TxPackets to the harness freelist, don't use the bytes after the call. +// Channel-backed instead of sync.Pool because putting a []byte in a sync.Pool escapes the slice header to heap. +func ReleaseTunBuf(b []byte) { + if b == nil { + return + } + select { + case tunBufFreelist <- b: + default: + // Freelist full; drop the buffer for the GC. + } +} + +// tunBufFreelist retains the backing arrays for TestTun.Write so steady-state allocation drops to zero once the +// freelist has saturated for the current MTU. +var tunBufFreelist = make(chan []byte, 64) + +func acquireTunBuf(n int) []byte { + var b []byte + select { + case b = <-tunBufFreelist: + default: + b = make([]byte, 0, udp.MTU) + } + if cap(b) < n { + b = make([]byte, n) + } else { + b = b[:n] + } + return b +} + func (t *TestTun) Close() error { if t.closed.CompareAndSwap(false, true) { close(t.rxPackets) @@ -129,8 +167,14 @@ func (t *TestTun) Read(b []byte) (int, error) { if !ok { return 0, os.ErrClosed } + n := len(p) copy(b, p) - return len(p), nil + // Send always pushes a freelist-acquired slice, return it once we've copied the bytes into the caller's buffer. + select { + case tunBufFreelist <- p: + default: + } + return n, nil } func (t *TestTun) SupportsMultiqueue() bool { diff --git a/udp/udp_tester.go b/udp/udp_tester.go index fcd0967c..f872e32a 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -21,17 +21,48 @@ type Packet struct { Data []byte } +// Copy returns a fresh *Packet (from the freelist) with a duplicate Data buffer. func (u *Packet) Copy() *Packet { - n := &Packet{ - To: u.To, - From: u.From, - Data: make([]byte, len(u.Data)), + n := acquirePacket() + n.To = u.To + n.From = u.From + if cap(n.Data) < len(u.Data) { + n.Data = make([]byte, len(u.Data)) + } else { + n.Data = n.Data[:len(u.Data)] } - copy(n.Data, u.Data) return n } +// Release returns p to the harness packet freelist. +// Callers that pull a *Packet from Get / TxPackets must Release when done. +// Channel-backed instead of sync.Pool because sync.Pool's per-P caches drain badly under cross-goroutine Get/Put, +// and putting a []byte in a Pool escapes the slice header to heap. +func (p *Packet) Release() { + if p == nil { + return + } + p.Data = p.Data[:0] + select { + case packetFreelist <- p: + default: + // Freelist full; drop the *Packet for the GC. + } +} + +// packetFreelist retains *Packet structs (and their backing Data arrays) so steady-state allocation drops to zero. +var packetFreelist = make(chan *Packet, 64) + +func acquirePacket() *Packet { + select { + case p := <-packetFreelist: + return p + default: + return &Packet{} + } +} + type TesterConn struct { Addr netip.AddrPort @@ -64,13 +95,15 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, _ bool, _ int) (Conn, // this is an encrypted packet or a handshake message in most cases // packets were transmitted from another nebula node, you can send them with Tun.Send func (u *TesterConn) Send(packet *Packet) { - h := &header.H{} - if err := h.Parse(packet.Data); err != nil { - panic(err) - } if u.l.Enabled(context.Background(), slog.LevelDebug) { + // Parse the header only under debug logging, otherwise the + // allocation would show up in every Send call. + var h header.H + if err := h.Parse(packet.Data); err != nil { + panic(err) + } u.l.Debug("UDP receiving injected packet", - "header", h, + "header", &h, "udpAddr", packet.From, "dataLen", len(packet.Data), ) @@ -107,15 +140,18 @@ func (u *TesterConn) Get(block bool) *Packet { //********************************************************************************************************************// func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { - p := &Packet{ - Data: make([]byte, len(b), len(b)), - From: u.Addr, - To: addr, + p := acquirePacket() + if cap(p.Data) < len(b) { + p.Data = make([]byte, len(b)) + } else { + p.Data = p.Data[:len(b)] } - copy(p.Data, b) + p.From = u.Addr + p.To = addr select { case <-u.done: + p.Release() return io.ErrClosedPipe case u.TxPackets <- p: return nil @@ -129,6 +165,7 @@ func (u *TesterConn) ListenOut(r EncReader) error { return os.ErrClosed case p := <-u.RxPackets: r(p.From, p.Data) + p.Release() } } } From ff91c37529509ffb26137bff4d4ded9eac9113a6 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 6 May 2026 10:22:26 -0500 Subject: [PATCH 03/31] switch Bits to a packed u64 (#1705) --- bits.go | 209 +++++++++++++++++++++----- bits_test.go | 407 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 452 insertions(+), 164 deletions(-) diff --git a/bits.go b/bits.go index 5c8f902b..15bafd87 100644 --- a/bits.go +++ b/bits.go @@ -2,24 +2,42 @@ package nebula import ( "context" + "fmt" "log/slog" + "math" + mathbits "math/bits" "github.com/rcrowley/go-metrics" ) +const bitsPerWord = 64 + +// Bits is a sliding-window anti-replay tracker. The window is stored as a +// circular bitmap packed into uint64 words (8x denser than a []bool), so a +// length-N window costs N/8 bytes. length must be a power of two. type Bits struct { length uint64 + lengthMask uint64 current uint64 - bits []bool + bits []uint64 lostCounter metrics.Counter dupeCounter metrics.Counter outOfWindowCounter metrics.Counter } -func NewBits(bits uint64) *Bits { +func NewBits(length uint64) *Bits { + if length == 0 || length&(length-1) != 0 { + panic(fmt.Sprintf("Bits length must be a power of two, got %d", length)) + } + + nWords := length / bitsPerWord + if nWords == 0 { + nWords = 1 + } b := &Bits{ - length: bits, - bits: make([]bool, bits, bits), + length: length, + lengthMask: length - 1, + bits: make([]uint64, nWords), current: 0, lostCounter: metrics.GetOrRegisterCounter("network.packets.lost", nil), dupeCounter: metrics.GetOrRegisterCounter("network.packets.duplicate", nil), @@ -27,71 +45,194 @@ func NewBits(bits uint64) *Bits { } // There is no counter value 0, mark it to avoid counting a lost packet later. - b.bits[0] = true - b.current = 0 + b.bits[0] = 1 return b } +func (b *Bits) get(i uint64) bool { + pos := i & b.lengthMask + //bit-shifting by 6 because i is a bit index, not a u64 index, and we need to find the u64 without bit in it + return b.bits[pos>>6]&(uint64(1)<<(pos&63)) != 0 +} + +func (b *Bits) set(i uint64) { + pos := i & b.lengthMask + b.bits[pos>>6] |= uint64(1) << (pos & 63) +} + +// clearRange clears `count` bits starting at circular position `startPos` +// (already masked to [0, length)) and returns how many of them were set +// before the clear. count must be in [1, length]. +func (b *Bits) clearRange(startPos, count uint64) uint64 { + wasSet := uint64(0) + if count >= b.length { + for _, w := range b.bits { + wasSet += uint64(mathbits.OnesCount64(w)) + } + clear(b.bits) + return wasSet + } + + pos := startPos + remaining := count + + // handle the potential partial word before pos becomes u64 aligned + word := pos >> 6 + bit := pos & 63 + take := uint64(64) - bit + if take > remaining { + take = remaining + } + if take > b.length-pos { + take = b.length - pos + } + var mask uint64 + if take == 64 { + mask = math.MaxUint64 + } else { + mask = ((uint64(1) << take) - 1) << bit + } + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + remaining -= take + pos = (pos + take) & b.lengthMask + + // Clear whole words, keeping track of the number of set bits + for remaining >= 64 { + word = pos >> 6 + wasSet += uint64(mathbits.OnesCount64(b.bits[word])) + b.bits[word] = 0 + remaining -= 64 + pos = (pos + 64) & b.lengthMask + } + + // Clear the remaining partial word + if remaining > 0 { + word = pos >> 6 + mask = (uint64(1) << remaining) - 1 + wasSet += uint64(mathbits.OnesCount64(b.bits[word] & mask)) + b.bits[word] &^= mask + } + + return wasSet +} + +func (b *Bits) strictlyWithinWindow(i uint64) bool { + // Handle the case where the window hasn't slid yet. This avoids u64 underflow. + inWarmup := b.current < b.length + if i < b.length && inWarmup { + return true + } + + // Next, if the packet is in-window, see if we've seen it before + if i > b.current-b.length { + return true + } + return false //not within window! +} + +// Check returns true if i is within (or way out in front of) the window, and not a replay func (b *Bits) Check(l *slog.Logger, i uint64) bool { // If i is the next number, return true. if i > b.current { return true } - // If i is within the window, check if it's been set already. - if i > b.current-b.length || i < b.length && b.current < b.length { - return !b.bits[i%b.length] + if b.strictlyWithinWindow(i) { + return !b.get(i) } // Not within the window if l.Enabled(context.Background(), slog.LevelDebug) { - l.Debug("rejected a packet (top)", - "current", b.current, - "incoming", i, - ) + l.Debug("rejected a packet (top)", "current", b.current, "incoming", i) } return false } +// Update has three branches: +// - i == b.current+1: fast path; advance the cursor by one and lose-count +// the slot we just stomped (only past warmup; see the i > b.length guard +// below). +// - i > b.current+1: jump path; clear all slots between current and i +// (or up to a full window's worth, whichever is smaller) via clearRange, +// then mark i. Two arms here: a warmup arm that handles the very first +// window before the cursor has slid, and a steady-state arm that treats +// every cleared empty slot as a lost packet. +// - i <= b.current: in-window check for duplicates; out-of-window otherwise. +// +// NewBits seeds bits[0]=1 so counter 0 looks "received" — Update never +// clears that marker during warmup (clearRange skips position 0 when +// startPos=1), and once b.current >= b.length the marker is no longer +// consulted. The marker prevents a fictitious "lost" hit on the first real +// counter. func (b *Bits) Update(l *slog.Logger, i uint64) bool { - // If i is the next number, return true and update current. + // Fast path: i is the next expected counter. Split out so the function + // stays small and avoids paying for the slow paths' slog argument-build + // stack frame on every call. The bit read/test/write is inlined to + // touch the backing word once. if i == b.current+1 { - // Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter - // The very first window can only be tracked as lost once we are on the 2nd window or greater - if b.bits[i%b.length] == false && i > b.length { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if i > b.length && w&mask == 0 { b.lostCounter.Inc(1) } - b.bits[i%b.length] = true + b.bits[word] = w | mask b.current = i return true } + return b.updateSlow(l, i) +} +// updateSlow handles jumps, in-window backfill, dupes, and out-of-window. +func (b *Bits) updateSlow(l *slog.Logger, i uint64) bool { // If i is a jump, adjust the window, record lost, update current, and return true if i > b.current { - lost := int64(0) - // Zero out the bits between the current and the new counter value, limited by the window size, - // since the window is shifting - for n := b.current + 1; n <= min(i, b.current+b.length); n++ { - if b.bits[n%b.length] == false && n > b.length { - lost++ + end := i + if end > b.current+b.length { + end = b.current + b.length + } + count := end - b.current + startPos := (b.current + 1) & b.lengthMask + + var lost int64 + if b.current >= b.length { + // Steady state: every cleared slot is past warmup, so any unset + // bit we evict is a lost packet from the previous cycle. + wasSet := b.clearRange(startPos, count) + lost = int64(count) - int64(wasSet) + } else { + // Warmup (the very first window). Some cleared slots represent + // packets <= length where eviction is not "lost" in the usual + // sense. This branch is taken at most once per connection so we + // don't bother optimizing it. + for n := b.current + 1; n <= end; n++ { + if !b.get(n) && n > b.length { + lost++ + } } - b.bits[n%b.length] = false + b.clearRange(startPos, count) } - // Only record any skipped packets as a result of the window moving further than the window length - // Any loss within the new window will be accounted for in future calls - lost += max(0, int64(i-b.current-b.length)) + // Anything past the new window can never be backfilled, so it's lost. + if i > b.current+b.length { + lost += int64(i - b.current - b.length) + } b.lostCounter.Inc(lost) - b.bits[i%b.length] = true + b.set(i) b.current = i return true } - // If i is within the current window but below the current counter, - // Check to see if it's a duplicate - if i > b.current-b.length || i < b.length && b.current < b.length { - if b.current == i || b.bits[i%b.length] == true { + // If i is within the current window but below the current counter, check to see if it's a duplicate + if b.strictlyWithinWindow(i) { + pos := i & b.lengthMask + word := pos >> 6 + mask := uint64(1) << (pos & 63) + w := b.bits[word] + if b.current == i || w&mask != 0 { if l.Enabled(context.Background(), slog.LevelDebug) { l.Debug("Receive window", "accepted", false, @@ -104,7 +245,7 @@ func (b *Bits) Update(l *slog.Logger, i uint64) bool { return false } - b.bits[i%b.length] = true + b.bits[word] = w | mask return true } diff --git a/bits_test.go b/bits_test.go index 3504cefa..da44c92a 100644 --- a/bits_test.go +++ b/bits_test.go @@ -7,61 +7,79 @@ import ( "github.com/stretchr/testify/assert" ) +// snapshot returns the bitmap as a []bool of length b.length, for readable +// test assertions against the now-packed []uint64 storage. +func (b *Bits) snapshot() []bool { + out := make([]bool, b.length) + for i := uint64(0); i < b.length; i++ { + out[i] = b.get(i) + } + return out +} + +func TestBitsRequiresPowerOfTwo(t *testing.T) { + assert.Panics(t, func() { NewBits(10) }) + assert.Panics(t, func() { NewBits(0) }) + assert.NotPanics(t, func() { NewBits(1) }) + assert.NotPanics(t, func() { NewBits(16) }) + assert.NotPanics(t, func() { NewBits(1024) }) + assert.NotPanics(t, func() { NewBits(16384) }) +} + func TestBits(t *testing.T) { l := test.NewLogger() - b := NewBits(10) - - // make sure it is the right size - assert.Len(t, b.bits, 10) + b := NewBits(16) + assert.EqualValues(t, 16, b.length) // This is initialized to zero - receive one. This should work. assert.True(t, b.Check(l, 1)) assert.True(t, b.Update(l, 1)) assert.EqualValues(t, 1, b.current) - g := []bool{true, true, false, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g := []bool{true, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two assert.True(t, b.Check(l, 2)) assert.True(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - g = []bool{true, true, true, false, false, false, false, false, false, false} - assert.Equal(t, g, b.bits) + g = []bool{true, true, true, false, false, false, false, false, false, false, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) // Receive two again - it will fail assert.False(t, b.Check(l, 2)) assert.False(t, b.Update(l, 2)) assert.EqualValues(t, 2, b.current) - // Jump ahead to 15, which should clear everything and set the 6th element - assert.True(t, b.Check(l, 15)) - assert.True(t, b.Update(l, 15)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, false, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Jump ahead to 25, which clears the window and sets slot 25%16 = 9. + assert.True(t, b.Check(l, 25)) + assert.True(t, b.Update(l, 25)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 14, which is allowed because it is in the window - assert.True(t, b.Check(l, 14)) - assert.True(t, b.Update(l, 14)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + // Mark 24, which is in window (current 25, length 16, window covers [10,25]). + assert.True(t, b.Check(l, 24)) + assert.True(t, b.Update(l, 24)) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // Mark 5, which is not allowed because it is not in the window + // Mark 5, not allowed because 5 <= current-length (25-16=9). assert.False(t, b.Check(l, 5)) assert.False(t, b.Update(l, 5)) - assert.EqualValues(t, 15, b.current) - g = []bool{false, false, false, false, true, true, false, false, false, false} - assert.Equal(t, g, b.bits) + assert.EqualValues(t, 25, b.current) + g = []bool{false, false, false, false, false, false, false, false, true, true, false, false, false, false, false, false} + assert.Equal(t, g, b.snapshot()) - // make sure we handle wrapping around once to the current position - b = NewBits(10) + // Make sure we handle wrapping around once to the same slot. With + // length=16, packets 1 and 17 share slot 1. + b = NewBits(16) assert.True(t, b.Update(l, 1)) - assert.True(t, b.Update(l, 11)) - assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false}, b.bits) + assert.True(t, b.Update(l, 17)) + assert.Equal(t, []bool{false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false}, b.snapshot()) // Walk through a few windows in order - b = NewBits(10) + b = NewBits(16) for i := uint64(1); i <= 100; i++ { assert.True(t, b.Check(l, i), "Error while checking %v", i) assert.True(t, b.Update(l, i), "Error while updating %v", i) @@ -72,24 +90,31 @@ func TestBits(t *testing.T) { func TestBitsLargeJumps(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + + // length=16. Update(55) from current=0: + // warmup, per-bit loop sees no n>16 with unset bits (slot 0 was set by + // NewBits and gets re-evaluated when n=16; n=16 is not strictly > 16), + // so the loop contributes 0. The jump exceeds the window so we record + // 55 - 0 - 16 = 39 packets fell out the back. + b := NewBits(16) b.lostCounter.Clear() + assert.True(t, b.Update(l, 55)) + assert.Equal(t, int64(39), b.lostCounter.Count()) - b = NewBits(10) - b.lostCounter.Clear() - assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54 - assert.Equal(t, int64(45), b.lostCounter.Count()) + // Update(100): clears 16 slots starting at slot 56%16=8. Only slot 7 (for + // packet 55) was set, so 16 - 1 = 15 evicted slots had unset bits. + // Plus 100 - 55 - 16 = 29 packets fell past the window. Total 44. + assert.True(t, b.Update(l, 100)) + assert.Equal(t, int64(39+44), b.lostCounter.Count()) - assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99 - assert.Equal(t, int64(89), b.lostCounter.Count()) - - assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199 - assert.Equal(t, int64(188), b.lostCounter.Count()) + // Update(200): same shape: 16 - 1 = 15 evicted unset, plus 200 - 100 - 16 = 84 past window. Total 99. + assert.True(t, b.Update(l, 200)) + assert.Equal(t, int64(39+44+99), b.lostCounter.Count()) } func TestBitsDupeCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() @@ -114,120 +139,117 @@ func TestBitsDupeCounter(t *testing.T) { func TestBitsOutOfWindowCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Jump to 20 (warmup branch + 4 past-window packets). assert.True(t, b.Update(l, 20)) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) + // 9 single-step advances, each evicts a slot whose bit was cleared during + // the jump above and whose value was never seen, so each contributes 1 + // to lostCounter. + for n := uint64(21); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) + // 0 is below current-length (29-16=13) so it falls outside the window. assert.False(t, b.Update(l, 0)) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // 4 from the Update(20) jump + 9 from 21..29. + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(1), b.outOfWindowCounter.Count()) } func TestBitsLostCounter(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 20)) - assert.True(t, b.Update(l, 21)) - assert.True(t, b.Update(l, 22)) - assert.True(t, b.Update(l, 23)) - assert.True(t, b.Update(l, 24)) - assert.True(t, b.Update(l, 25)) - assert.True(t, b.Update(l, 26)) - assert.True(t, b.Update(l, 27)) - assert.True(t, b.Update(l, 28)) - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost + // Walk 20..29 like the original, just with a bigger window. Same + // reasoning as TestBitsOutOfWindowCounter: 4 past-window from Update(20), + // then 9 more from the unit advances. + for n := uint64(20); n <= 29; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(13), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) - b = NewBits(10) + b = NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() - assert.True(t, b.Update(l, 9)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 10 will set 0 index, 0 was already set, no lost packets - assert.True(t, b.Update(l, 10)) - assert.Equal(t, int64(0), b.lostCounter.Count()) - // 11 will set 1 index, 1 was missed, we should see 1 packet lost - assert.True(t, b.Update(l, 11)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - // Now let's fill in the window, should end up with 8 lost packets - assert.True(t, b.Update(l, 12)) - assert.True(t, b.Update(l, 13)) - assert.True(t, b.Update(l, 14)) + // Update(15) clears the warmup window (no lost), sets slot 15. assert.True(t, b.Update(l, 15)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(16): slot 0 was already set (NewBits seeded it), and 16 is not + // strictly > length, so nothing is recorded as lost. assert.True(t, b.Update(l, 16)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + + // Update(17): we jumped straight from 0 to 15, so slot 1 was cleared + // (and never re-set). 17 > 16 is past warmup, so packet 1 is recorded lost. assert.True(t, b.Update(l, 17)) - assert.True(t, b.Update(l, 18)) - assert.True(t, b.Update(l, 19)) - assert.Equal(t, int64(8), b.lostCounter.Count()) + assert.Equal(t, int64(1), b.lostCounter.Count()) - // Jump ahead by a window size - assert.True(t, b.Update(l, 29)) - assert.Equal(t, int64(8), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in - assert.True(t, b.Update(l, 30)) - assert.True(t, b.Update(l, 31)) - assert.True(t, b.Update(l, 32)) - assert.True(t, b.Update(l, 33)) - assert.True(t, b.Update(l, 34)) - assert.True(t, b.Update(l, 35)) - assert.True(t, b.Update(l, 36)) - assert.True(t, b.Update(l, 37)) - assert.True(t, b.Update(l, 38)) - // 39 packets tracked, 22 seen, 17 lost - assert.Equal(t, int64(17), b.lostCounter.Count()) + // Fill in 18..30 in single steps. Each i evicts slot i%16. Slots 2..14 + // were all cleared during Update(15), and we never re-set any of them, + // so each i in 18..30 is a fresh lost packet — 13 more. + for n := uint64(18); n <= 30; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(14), b.lostCounter.Count()) - // Jump ahead by 2 windows, should have recording 1 full window missing - assert.True(t, b.Update(l, 58)) - assert.Equal(t, int64(27), b.lostCounter.Count()) - // Now lets walk ahead normally through the window, the missed packets should fill in from this window - assert.True(t, b.Update(l, 59)) - assert.True(t, b.Update(l, 60)) - assert.True(t, b.Update(l, 61)) - assert.True(t, b.Update(l, 62)) - assert.True(t, b.Update(l, 63)) - assert.True(t, b.Update(l, 64)) - assert.True(t, b.Update(l, 65)) - assert.True(t, b.Update(l, 66)) - assert.True(t, b.Update(l, 67)) - // 68 packets tracked, 32 seen, 36 missed - assert.Equal(t, int64(36), b.lostCounter.Count()) + // Jump ahead by exactly one window size. + assert.True(t, b.Update(l, 46)) + // end = min(46, 30+16) = 46, count = 16, all slots cleared. Before the + // jump every slot 0..15 had been set (Update(15), (16), (17), 18..30), + // so wasSet=16 and 46 == current+length means no past-window slack: + // lost contribution = 0. + assert.Equal(t, int64(14), b.lostCounter.Count()) + + // Walk 47..55. The Update(46) jump cleared every slot, so only slot 14 + // (for packet 46) is set when we start. Each subsequent unit step lands + // on a slot that was cleared and is past warmup, so it counts as lost. + // 9 more = 23. + for n := uint64(47); n <= 55; n++ { + assert.True(t, b.Update(l, n)) + } + assert.Equal(t, int64(23), b.lostCounter.Count()) + + // Jump ahead by two windows: clears the window plus past-window loss. + assert.True(t, b.Update(l, 87)) + // current=55, length=16. end = min(87, 71) = 71. count=16, all slots + // cleared. Slots set before the clear are slots 14,15,0..7 (10 total). + // Lost from clear = 16 - 10 = 6. Past window: 87 - 55 - 16 = 16. +22. + assert.Equal(t, int64(45), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } func TestBitsLostCounterIssue1(t *testing.T) { l := test.NewLogger() - b := NewBits(10) + b := NewBits(16) b.lostCounter.Clear() b.dupeCounter.Clear() b.outOfWindowCounter.Clear() + // Receive 4, backfill 1, then 9, 2, 3, 5, 6, 7 (skip 8), 10, 11, 14. + // Then jump to 25 — slot 25%16=9 is being evicted, but it had been set + // (we received packet 9), so no spurious lost increment. The original + // regression was about double-counting a missing packet when its slot + // got cleared on a jump. With the jump path now using clearRange's + // word-level wasSet count, the same semantics hold. assert.True(t, b.Update(l, 4)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 1)) @@ -244,7 +266,7 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 7)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // assert.True(t, b.Update(l, 8)) + // Skip packet 8. assert.True(t, b.Update(l, 10)) assert.Equal(t, int64(0), b.lostCounter.Count()) assert.True(t, b.Update(l, 11)) @@ -252,9 +274,23 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.True(t, b.Update(l, 14)) assert.Equal(t, int64(0), b.lostCounter.Count()) - // Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter - assert.True(t, b.Update(l, 19)) + + // Jump to 25. With length=16, slot 25%16=9 corresponds to packet 9 + // (which we DID receive), so its bit is set and no lost++ from that + // eviction. The trace below shows the only loss is packet 8. + assert.True(t, b.Update(l, 25)) + // current was 14, i=25. end=min(25,30)=25. count=11. startPos=15. + // steady? current=14<16, so warmup branch: per-bit n=15..25, count those + // with !get(n) AND n>16. n=17..25 are >16. Among slots 17%16=1..25%16=9 + // did we set slots 1..9 (packets 1..9)? Yes for all but slot 8 (packet 8 + // was skipped). n=24 maps to slot 8 which is FALSE → lost++. All other + // n in 17..25 map to slots that are set. n=16 is not strictly > 16. So + // lost = 1. assert.Equal(t, int64(1), b.lostCounter.Count()) + + // Fill in 12, 13, 15, 16. Each is below current=25 (in-window). 16 must + // recheck slot 0 — it was set by NewBits and then cleared by the + // Update(25) jump, so 16 backfills cleanly. assert.True(t, b.Update(l, 12)) assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 13)) @@ -263,29 +299,140 @@ func TestBitsLostCounterIssue1(t *testing.T) { assert.Equal(t, int64(1), b.lostCounter.Count()) assert.True(t, b.Update(l, 16)) assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 17)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 18)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 20)) - assert.Equal(t, int64(1), b.lostCounter.Count()) - assert.True(t, b.Update(l, 21)) - // We missed packet 8 above + // We missed packet 8 above and that loss is still recorded once, never + // double-counted, never zeroed. assert.Equal(t, int64(1), b.lostCounter.Count()) assert.Equal(t, int64(0), b.dupeCounter.Count()) assert.Equal(t, int64(0), b.outOfWindowCounter.Count()) } -func BenchmarkBits(b *testing.B) { - z := NewBits(10) - for n := 0; n < b.N; n++ { - for i := range z.bits { - z.bits[i] = true - } - for i := range z.bits { - z.bits[i] = false - } +// TestBitsWarmupOvershoot exercises the jump path's warmup arm with an +// overshoot past one full window. NewBits leaves current=0 with only slot 0 +// "set" by the marker. Jumping straight to length+k must (a) clear every +// slot the jump straddles, (b) count only past-window slack (not the +// in-window slots, which never had a "lost" tenant during warmup), and +// (c) leave the cursor at the new counter so subsequent unit advances +// count from steady state. The marker bit at slot 0 is irrelevant once +// current >= length. +func TestBitsWarmupOvershoot(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + b.lostCounter.Clear() + // Jump from current=0 to i=20 (length=16, overshoot=4). + // Warmup arm: counts slots in [1..16] where bit unset and n>length. + // Only n=16 was unset and >length: but slot 16%16=0 is the marker, + // so b.get(16) reads bits[0]=1 and skips. Result: 0 lost from the loop. + // Past-window: i - current - length = 20 - 0 - 16 = 4 lost. + assert.True(t, b.Update(l, 20)) + assert.Equal(t, int64(4), b.lostCounter.Count()) + assert.Equal(t, uint64(20), b.current) + + // Steady state now (current=20 >= length=16). Unit advance to 21 + // stomps slot 21%16=5, which was cleared by the jump and not reset, + // so this is +1 lost. + assert.True(t, b.Update(l, 21)) + assert.Equal(t, int64(5), b.lostCounter.Count()) +} + +// TestBitsCheckAcrossWarmupBoundary pins the underflow trick in Check's +// in-window clause. While in warmup, b.current-b.length underflows uint64 +// to a huge value so the first OR-clause is always false; the second +// clause (i < length && current < length) carries the in-window check. +// Once current >= length the regimes flip cleanly. +func TestBitsCheckAcrossWarmupBoundary(t *testing.T) { + l := test.NewLogger() + b := NewBits(16) + + // Warmup: current=0. Check(0) must read the marker (set) and return false. + assert.False(t, b.Check(l, 0), "marker slot should look already-received") + // Warmup: any 0 < i < length is in-window and unset → accepted. + for i := uint64(1); i < 16; i++ { + assert.True(t, b.Check(l, i), "warmup in-window i=%d should be accepted", i) + } + // Warmup: i >= length but > current is "next number" so accepted. + assert.True(t, b.Check(l, 16)) + assert.True(t, b.Check(l, 1_000_000)) + + // Cross into steady state. + assert.True(t, b.Update(l, 100)) + // Now current=100, length=16. In-window range is [85..100]. + // 84 is just outside: the underflow clause activates; 84 > 100-16=84 is false. + // And the warmup clause is false (current >= length). So out of window. + assert.False(t, b.Check(l, 84)) + // 85 sits at the boundary. 85 > 84 is true → in window, unset → accept. + assert.True(t, b.Check(l, 85)) + // 100 is current itself; not strictly greater, in-window, but already set. + assert.False(t, b.Check(l, 100)) + // Way out: clearly out of window. + assert.False(t, b.Check(l, 50)) +} + +// TestBitsMarkerInvariant verifies the seeded bits[0]=1 marker behaves +// correctly across warmup and beyond. Update should never clear the marker +// during warmup (clearRange skips position 0 when startPos=1), and once +// current >= length the marker is no longer consulted by Check/Update on +// the live path — but it must still report counter 0 as a duplicate while +// we are in warmup. +func TestBitsMarkerInvariant(t *testing.T) { + l := test.NewLogger() + b := NewBits(8) + + // Counter 0 is the seeded marker; Check sees it as already received. + assert.False(t, b.Check(l, 0)) + // Update(0) at current=0 hits the duplicate branch. + b.dupeCounter.Clear() + assert.False(t, b.Update(l, 0)) + assert.Equal(t, int64(1), b.dupeCounter.Count()) + + // Walk forward through warmup; the marker must remain set. + for n := uint64(1); n <= 7; n++ { + assert.True(t, b.Update(l, n)) + } + // Position 0 (the marker) should still read as set because we never + // cleared it; Update(0) still looks like a duplicate. + assert.False(t, b.Check(l, 0)) + + // Cross into steady state with a unit advance to 8: pos=0, evicts the + // marker bit. The lost-counter guard (i > b.length) is false (8 == 8), + // so this advance does NOT charge a lost packet — exactly what the + // marker is there to prevent. + b.lostCounter.Clear() + assert.True(t, b.Update(l, 8)) + assert.Equal(t, int64(0), b.lostCounter.Count()) + // The slot at pos 0 is now occupied by counter 8. + assert.False(t, b.Check(l, 8)) +} + +// BenchmarkBitsUpdateInOrder is the steady-state hot path: each call is +// i == current+1. +func BenchmarkBitsUpdateInOrder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n)+1) + } +} + +// BenchmarkBitsUpdateReorder simulates light reorder within the window: +// every other packet arrives one slot behind its predecessor (forces the +// in-window backfill branch). +func BenchmarkBitsUpdateReorder(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + base := uint64(n) * 2 + z.Update(l, base+2) + z.Update(l, base+1) + } +} + +// BenchmarkBitsUpdateLargeJumps stresses the clearRange word-level path. +func BenchmarkBitsUpdateLargeJumps(b *testing.B) { + l := test.NewLogger() + z := NewBits(16384) + for n := 0; n < b.N; n++ { + z.Update(l, uint64(n+1)*1000) } } From 4fb5cdb4faaa1c47ef0c8e59fb46641db707dca9 Mon Sep 17 00:00:00 2001 From: Wade Simmons Date: Wed, 6 May 2026 12:23:27 -0400 Subject: [PATCH 04/31] refactor readOutsidePackets (#1642) * refactor readOutsidePackets They layout of this method is confusing and relys on certain parts to return early for things to work correctly. Change the ordering of the logic so that we do this: - Handle unencrypted packets - Decrypt packet - Handle encrypted packets This way, nothing can sneak through unencrypted to where it shouldn't be. * fix comment * code review comments * check for expected type/subtype * check header version * log header * need to handle TestReply * clean roaming / connectionManager * dont need to roam here now, we do it earlier * cleanup metrics and errors * rxInvalid * debug logger checks * ErrOutOfWindow --- header/header.go | 14 ++ message_metrics.go | 8 + outside.go | 413 +++++++++++++++++++++------------------------ 3 files changed, 210 insertions(+), 225 deletions(-) diff --git a/header/header.go b/header/header.go index f22509b8..b973141f 100644 --- a/header/header.go +++ b/header/header.go @@ -174,6 +174,10 @@ func (h *H) SubTypeName() string { return SubTypeName(h.Type, h.Subtype) } +func (h *H) IsValidSubType() bool { + return IsValidSubType(h.Type, h.Subtype) +} + // SubTypeName will transform a nebula message sub type into a human string func SubTypeName(t MessageType, s MessageSubType) string { if n, ok := subTypeMap[t]; ok { @@ -185,6 +189,16 @@ func SubTypeName(t MessageType, s MessageSubType) string { return "unknown" } +func IsValidSubType(t MessageType, s MessageSubType) bool { + if n, ok := subTypeMap[t]; ok { + if _, ok := (*n)[s]; ok { + return true + } + } + + return false +} + // NewHeader turns bytes into a header func NewHeader(b []byte) (*H, error) { h := new(H) diff --git a/message_metrics.go b/message_metrics.go index 10e8472c..45de9a5c 100644 --- a/message_metrics.go +++ b/message_metrics.go @@ -13,6 +13,8 @@ type MessageMetrics struct { rxUnknown metrics.Counter txUnknown metrics.Counter + + rxInvalid metrics.Counter } func (m *MessageMetrics) Rx(t header.MessageType, s header.MessageSubType, i int64) { @@ -33,6 +35,11 @@ func (m *MessageMetrics) Tx(t header.MessageType, s header.MessageSubType, i int } } } +func (m *MessageMetrics) RxInvalid(i int64) { + if m != nil && m.rxInvalid != nil { + m.rxInvalid.Inc(i) + } +} func newMessageMetrics() *MessageMetrics { gen := func(t string) [][]metrics.Counter { @@ -56,6 +63,7 @@ func newMessageMetrics() *MessageMetrics { rxUnknown: metrics.GetOrRegisterCounter("messages.rx.other", nil), txUnknown: metrics.GetOrRegisterCounter("messages.tx.other", nil), + rxInvalid: metrics.GetOrRegisterCounter("messages.rx.invalid", nil), } } diff --git a/outside.go b/outside.go index 1e00a0a9..17013ed3 100644 --- a/outside.go +++ b/outside.go @@ -20,23 +20,46 @@ const ( minFwPacketLen = 4 ) +var ErrOutOfWindow = errors.New("out of window packet") + func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors + // TODO: record metrics for rx holepunch/punchy packets? if len(packet) > 1 { - f.l.Info("Error while parsing inbound packet", - "from", via, - "error", err, - "packet", packet, - ) + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Error while parsing inbound packet", + "from", via, + "error", err, + "packet", packet, + ) + } + } + return + } + + if h.Version != header.Version { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected header version received", "from", via) + } + return + } + + // Check before processing to see if this is a expected type/subtype + if !h.IsValidSubType() { + f.messageMetrics.RxInvalid(1) + if f.l.Enabled(context.Background(), slog.LevelDebug) { + f.l.Debug("Unexpected packet received", "from", via) } return } - //l.Error("in packet ", header, packet[HeaderLen:]) if !via.IsRelayed { if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) { + f.messageMetrics.RxInvalid(1) if f.l.Enabled(context.Background(), slog.LevelDebug) { f.l.Debug("Refusing to process double encrypted packet", "from", via) } @@ -44,215 +67,192 @@ func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, } } + // don't keep Rx metrics for message type, since you can see those in the tun metrics + if h.Type != header.Message { + f.messageMetrics.Rx(h.Type, h.Subtype, 1) + } + + // Unencrypted packets + switch h.Type { + case header.Handshake: + f.handshakeManager.HandleIncoming(via, packet, h) + return + + case header.RecvError: + f.handleRecvError(via.UdpAddr, h) + return + } + + // Relay packets are special + isMessageRelay := (h.Type == header.Message && h.Subtype == header.MessageRelay) + var hostinfo *HostInfo - // verify if we've seen this index before, otherwise respond to the handshake initiation - if h.Type == header.Message && h.Subtype == header.MessageRelay { + if isMessageRelay { hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex) } else { hostinfo = f.hostMap.QueryIndex(h.RemoteIndex) } - var ci *ConnectionState - if hostinfo != nil { - ci = hostinfo.ConnectionState + // At this point we should have a valid existing tunnel, verify and send + // recvError if necessary + if hostinfo == nil || hostinfo.ConnectionState == nil { + if !via.IsRelayed { + f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) + } + return } + // All remaining packets are encrypted + ci := hostinfo.ConnectionState + if !ci.window.Check(f.l, h.MessageCounter) { + return + } + + // Relay packets are special + if isMessageRelay { + f.handleOutsideRelayPacket(hostinfo, via, out, packet, h, fwPacket, lhf, nb, q, localCache) + + return + } + + out, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) + if err != nil { + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Failed to decrypt packet", + "error", err, + "from", via, + "header", h, + ) + } + return + } + + // Roam before we respond + f.handleHostRoaming(hostinfo, via) + f.connectionManager.In(hostinfo) + switch h.Type { case header.Message: - if !f.handleEncrypted(ci, via, h) { - return - } - switch h.Subtype { case header.MessageNone: - if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) { - return - } - case header.MessageRelay: - // The entire body is sent as AD, not encrypted. - // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. - // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's - // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice - // which will gracefully fail in the DecryptDanger call. - signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] - signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) - if err != nil { - return - } - // Successfully validated the thing. Get rid of the Relay header. - signedPayload = signedPayload[header.Len:] - // Pull the Roaming parts up here, and return in all call paths. - f.handleHostRoaming(hostinfo, via) - // Track usage of both the HostInfo and the Relay for the received & authenticated packet - f.connectionManager.In(hostinfo) - f.connectionManager.RelayUsed(h.RemoteIndex) - - relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) - if !ok { - // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing - // its internal mapping. This should never happen. - hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, - ) - return - } - - switch relay.Type { - case TerminalType: - // If I am the target of this relay, process the unwrapped packet - // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. - via = ViaSender{ - UdpAddr: via.UdpAddr, - relayHI: hostinfo, - remoteIdx: relay.RemoteIndex, - relay: relay, - IsRelayed: true, - } - f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) - return - case ForwardingType: - // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) - if err != nil { - hostinfo.logger(f.l).Info("Failed to find target host info by ip", - "relayTo", relay.PeerAddr, - "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, - ) - return - } - - // If that relay is Established, forward the payload through it - if targetRelay.State == Established { - switch targetRelay.Type { - case ForwardingType: - // Forward this packet through the relay tunnel - // Find the target HostInfo - f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) - return - case TerminalType: - hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") - } - } else { - hostinfo.logger(f.l).Info("Unexpected target relay state", - "relayTo", relay.PeerAddr, - "relayFrom", hostinfo.vpnAddrs[0], - "targetRelayState", targetRelay.State, - ) - return - } - } + f.handleOutsideMessagePacket(hostinfo, out, packet, fwPacket, nb, q, localCache) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message subtype seen", "from", via, "header", h) + return } case header.LightHouse: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt lighthouse packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - //TODO: assert via is not relayed - lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f) - - // Fallthrough to the bottom to record incoming traffic + lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, out, f) case header.Test: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { + switch h.Subtype { + case header.TestReply: + // No-op, useful for the Roaming and connectionManager side-effects above + case header.TestRequest: + f.send(header.Test, header.TestReply, ci, hostinfo, out, nb, out) + default: + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected test subtype seen", "from", via, "header", h) return } - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt test packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - if h.Subtype == header.TestRequest { - // This testRequest might be from TryPromoteBest, so we should roam - // to the new IP address before responding - f.handleHostRoaming(hostinfo, via) - f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out) - } - - // Fallthrough to the bottom to record incoming traffic - - // Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they - // are unauthenticated - - case header.Handshake: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handshakeManager.HandleIncoming(via, packet, h) - return - - case header.RecvError: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - f.handleRecvError(via.UdpAddr, h) - return - case header.CloseTunnel: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if !f.handleEncrypted(ci, via, h) { - return - } - _, err = f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt CloseTunnel packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - hostinfo.logger(f.l).Info("Close tunnel received, tearing down.", "from", via) - f.closeTunnel(hostinfo) - return case header.Control: - if !f.handleEncrypted(ci, via, h) { - return - } - - d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt Control packet", - "error", err, - "from", via, - "packet", packet, - ) - return - } - - f.relayManager.HandleControlMsg(hostinfo, d, f) + f.relayManager.HandleControlMsg(hostinfo, out, f) default: - f.messageMetrics.Rx(h.Type, h.Subtype, 1) - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("Unexpected packet received", "from", via) - } + hostinfo.logger(f.l).Error("IsValidSubType was true, but unexpected message type seen", "from", via, "header", h) + } +} + +func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { + // The entire body is sent as AD, not encrypted. + // The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value. + // The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's + // otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice + // which will gracefully fail in the DecryptDanger call. + signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()] + signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():] + var err error + out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb) + if err != nil { + return + } + // Successfully validated the thing. Get rid of the Relay header. + signedPayload = signedPayload[header.Len:] + // Pull the Roaming parts up here, and return in all call paths. + f.handleHostRoaming(hostinfo, via) + // Track usage of both the HostInfo and the Relay for the received & authenticated packet + f.connectionManager.In(hostinfo) + f.connectionManager.RelayUsed(h.RemoteIndex) + + relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex) + if !ok { + // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing + // its internal mapping. This should never happen. + hostinfo.logger(f.l).Error("HostInfo missing remote relay index", + "vpnAddrs", hostinfo.vpnAddrs, + "remoteIndex", h.RemoteIndex, + ) return } - f.handleHostRoaming(hostinfo, via) + switch relay.Type { + case TerminalType: + // If I am the target of this relay, process the unwrapped packet + // From this recursive point, all these variables are 'burned'. We shouldn't rely on them again. + via = ViaSender{ + UdpAddr: via.UdpAddr, + relayHI: hostinfo, + remoteIdx: relay.RemoteIndex, + relay: relay, + IsRelayed: true, + } + f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache) + case ForwardingType: + // Find the target HostInfo relay object + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) + if err != nil { + hostinfo.logger(f.l).Info("Failed to find target host info by ip", + "relayTo", relay.PeerAddr, + "error", err, + "hostinfo.vpnAddrs", hostinfo.vpnAddrs, + ) + return + } - f.connectionManager.In(hostinfo) + // If that relay is Established, forward the payload through it + if targetRelay.State == Established { + switch targetRelay.Type { + case ForwardingType: + // Forward this packet through the relay tunnel + // Find the target HostInfo + f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false) + case TerminalType: + hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") + return + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected targetRelay Type", "from", via, "relayType", targetRelay.Type) + } + return + } + } else { + hostinfo.logger(f.l).Info("Unexpected target relay state", + "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], + "targetRelayState", targetRelay.State, + ) + return + } + default: + if f.l.Enabled(context.Background(), slog.LevelDebug) { + hostinfo.logger(f.l).Debug("Unexpected relay type", "from", via, "relayType", relay.Type) + } + } } // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote @@ -300,23 +300,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) { } -// handleEncrypted returns true if a packet should be processed, false otherwise -func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool { - // If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect - if ci == nil { - if !via.IsRelayed { - f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex) - } - return false - } - // If the window check fails, refuse to process the packet, but don't send a recv error - if !ci.window.Check(f.l, h.MessageCounter) { - return false - } - - return true -} - var ( ErrPacketTooShort = errors.New("packet is too short") ErrUnknownIPVersion = errors.New("packet is an unknown ip version") @@ -523,38 +506,20 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet [] } if !hostinfo.ConnectionState.window.Update(f.l, mc) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "header", h) - } - return nil, errors.New("out of window packet") + return nil, ErrOutOfWindow } return out, nil } -func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool { - var err error - - out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb) - if err != nil { - hostinfo.logger(f.l).Error("Failed to decrypt packet", "error", err) - return false - } - - err = newPacket(out, true, fwPacket) +func (f *Interface) handleOutsideMessagePacket(hostinfo *HostInfo, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) { + err := newPacket(out, true, fwPacket) if err != nil { hostinfo.logger(f.l).Warn("Error while validating inbound packet", "error", err, "packet", out, ) - return false - } - - if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) { - if f.l.Enabled(context.Background(), slog.LevelDebug) { - hostinfo.logger(f.l).Debug("dropping out of window packet", "fwPacket", fwPacket) - } - return false + return } dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache) @@ -568,15 +533,13 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out "reason", dropReason, ) } - return false + return } - f.connectionManager.In(hostinfo) _, err = f.readers[q].Write(out) if err != nil { f.l.Error("Failed to write to tun", "error", err) } - return true } func (f *Interface) maybeSendRecvError(endpoint netip.AddrPort, index uint32) { From 213dd46588d516f0151d0cc54e16a4cd042f9ba4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 16:21:16 -0500 Subject: [PATCH 05/31] Stop leaking goroutines past Control.Stop, consolidate punching in Punchy (#1708) --- connection_manager.go | 54 +++-------- connection_manager_test.go | 8 +- e2e/leak_test.go | 10 +- examples/config.yml | 4 + lighthouse.go | 52 ++-------- main.go | 6 +- punchy.go | 193 ++++++++++++++++++++++++++++++------- punchy_test.go | 81 ++++++++-------- scheduler.go | 84 ++++++++++++++++ scheduler_test.go | 79 +++++++++++++++ sshd/server.go | 25 ++++- timeout.go | 17 ++++ 12 files changed, 434 insertions(+), 179 deletions(-) create mode 100644 scheduler.go create mode 100644 scheduler_test.go diff --git a/connection_manager.go b/connection_manager.go index e7fc04cd..ee6d1eaf 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -11,7 +11,6 @@ import ( "sync/atomic" "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -45,19 +44,16 @@ type connectionManager struct { inactivityTimeout atomic.Int64 dropInactive atomic.Bool - metricsTxPunchy metrics.Counter - l *slog.Logger } func newConnectionManagerFromConfig(l *slog.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager { cm := &connectionManager{ - hostMap: hm, - l: l, - punchy: p, - relayUsed: make(map[uint32]struct{}), - relayUsedLock: &sync.RWMutex{}, - metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + hostMap: hm, + l: l, + punchy: p, + relayUsed: make(map[uint32]struct{}), + relayUsedLock: &sync.RWMutex{}, } cm.reload(c, true) @@ -369,7 +365,7 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim if !outTraffic { // Send a punch packet to keep the NAT state alive - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) } return decision, hostinfo, primary @@ -400,17 +396,16 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. // Just maintain NAT state if configured to do so. - cm.sendPunch(hostinfo) + cm.punchy.SendPunch(hostinfo) cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval) return doNothing, nil, nil } - if cm.punchy.GetTargetEverything() { - // This is similar to the old punchy behavior with a slight optimization. - // We aren't receiving traffic but we are sending it, punch on all known - // ips in case we need to re-prime NAT state - cm.sendPunch(hostinfo) - } + // We aren't receiving traffic but we are sending it. The outbound + // traffic itself refreshes the primary remote's NAT state; this + // fans out to non-primary remotes, but only if target_all_remotes + // is configured. + cm.punchy.SendPunchToAll(hostinfo) if cm.l.Enabled(context.Background(), slog.LevelDebug) { hostinfo.logger(cm.l).Debug("Tunnel status", @@ -512,31 +507,6 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI } } -func (cm *connectionManager) sendPunch(hostinfo *HostInfo) { - if !cm.punchy.GetPunch() { - // Punching is disabled - return - } - - if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { - // Do not punch to lighthouses, we assume our lighthouse update interval is good enough. - // In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse - // would lose the ability to notify us and punchy.respond would become unreliable. - return - } - - if cm.punchy.GetTargetEverything() { - hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, addr) - }) - - } else if hostinfo.remote.IsValid() { - cm.metricsTxPunchy.Inc(1) - cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote) - } -} - func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) { cs := cm.intf.pki.getCertState() curCrt := hostinfo.ConnectionState.myCert diff --git a/connection_manager_test.go b/connection_manager_test.go index 7dc08a45..e167e5f2 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -64,7 +64,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -146,7 +146,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce p := []byte("") @@ -233,7 +233,7 @@ func Test_NewConnectionManager_DisconnectInactive(t *testing.T) { conf.Settings["tunnels"] = map[string]any{ "drop_inactive": true, } - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) assert.True(t, nc.dropInactive.Load()) nc.intf = ifce @@ -358,7 +358,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { // Create manager conf := config.NewC(test.NewLogger()) - punchy := NewPunchyFromConfig(test.NewLogger(), conf) + punchy := NewPunchyFromConfig(test.NewLogger(), conf, nil) nc := newConnectionManagerFromConfig(test.NewLogger(), conf, hostMap, punchy) nc.intf = ifce ifce.connectionManager = nc diff --git a/e2e/leak_test.go b/e2e/leak_test.go index ffb024fe..576d67a8 100644 --- a/e2e/leak_test.go +++ b/e2e/leak_test.go @@ -18,14 +18,10 @@ import ( // retry mechanism gives the wg.Wait()-driven goroutines a moment to drain // before failing the assertion. // -// IgnoreCurrent is necessary in the parallelized suite: other tests can -// leave goroutines mid-shutdown when this one runs (Stop is async, the -// wg.Wait() drain is not blocking on test return). We're checking that -// *this* test's setup tears down cleanly, not that the whole suite is -// idle at this moment. Intentionally NOT t.Parallel()'d for the same -// reason — concurrent test goroutines would always show up. +// Intentionally NOT t.Parallel()'d: concurrent tests would have their own +// goroutines running and trip the assertion. func TestNoGoroutineLeaks(t *testing.T) { - defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + defer goleak.VerifyNone(t) ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) diff --git a/examples/config.yml b/examples/config.yml index f5752ae4..ac4810e6 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -163,17 +163,21 @@ listen: punchy: # Continues to punch inbound/outbound at a regular interval to avoid expiration of firewall nat mappings + # This setting is reloadable. punch: true # respond means that a node you are trying to reach will connect back out to you if your hole punching fails # this is extremely useful if one node is behind a difficult nat, such as a symmetric NAT # Default is false + # This setting is reloadable. #respond: true # delays a punch response for misbehaving NATs, default is 1 second. + # This setting is reloadable. #delay: 1s # set the delay before attempting punchy.respond. Default is 5 seconds. respond must be true to take effect. + # This setting is reloadable. #respond_delay: 5s # Cipher allows you to choose between the available ciphers for your network. Options are chachapoly or aes diff --git a/lighthouse.go b/lighthouse.go index 6034e68c..1a136a1b 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -15,7 +15,6 @@ import ( "time" "github.com/gaissmai/bart" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" @@ -35,7 +34,6 @@ type LightHouse struct { myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Lite - punchConn udp.Conn punchy *Punchy // Local cache of answers from light houses @@ -75,9 +73,8 @@ type LightHouse struct { calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote - metrics *MessageMetrics - metricHolepunchTx metrics.Counter - l *slog.Logger + metrics *MessageMetrics + l *slog.Logger } // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object @@ -105,7 +102,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c myVpnNetworksTable: cs.myVpnNetworksTable, addrMap: make(map[netip.Addr]*RemoteList), nebulaPort: nebulaPort, - punchConn: pc, punchy: p, updateTrigger: make(chan struct{}, 1), queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), @@ -118,9 +114,6 @@ func NewLightHouseFromConfig(ctx context.Context, l *slog.Logger, c *config.C, c if c.GetBool("stats.lighthouse_metrics", false) { h.metrics = newLighthouseMetrics() - h.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) - } else { - h.metricHolepunchTx = metrics.NilCounter{} } err := h.reload(c, true) @@ -1406,58 +1399,25 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpn return } - empty := []byte{0} - punch := func(vpnPeer netip.AddrPort, logVpnAddr netip.Addr) { - if !vpnPeer.IsValid() { - return - } - - go func() { - time.Sleep(lhh.lh.punchy.GetDelay()) - lhh.lh.metricHolepunchTx.Inc(1) - lhh.lh.punchConn.WriteTo(empty, vpnPeer) - }() - - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Punching", - "vpnPeer", vpnPeer, - "logVpnAddr", logVpnAddr, - ) - } - } - remoteAllowList := lhh.lh.GetRemoteAllowList() for _, a := range n.Details.V4AddrPorts { b := protoV4AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } for _, a := range n.Details.V6AddrPorts { b := protoV6AddrPortToNetAddrPort(a) if remoteAllowList.Allow(detailsVpnAddr, b.Addr()) { - punch(b, detailsVpnAddr) + lhh.lh.punchy.Schedule(b, detailsVpnAddr) } } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish - // a tunnel. - if lhh.lh.punchy.GetRespond() { - go func() { - time.Sleep(lhh.lh.punchy.GetRespondDelay()) - if lhh.l.Enabled(context.Background(), slog.LevelDebug) { - lhh.l.Debug("Sending a nebula test packet", - "vpnAddr", detailsVpnAddr, - ) - } - //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine - // for each punchBack packet. We should move this into a timerwheel or a single goroutine - // managed by a channel. - w.SendMessageToVpnAddr(header.Test, header.TestRequest, detailsVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) - }() - } + // a tunnel. ScheduleRespond is a no-op when punchy.respond is disabled. + lhh.lh.punchy.ScheduleRespond(detailsVpnAddr) } func protoAddrToNetAddr(addr *Addr) netip.Addr { diff --git a/main.go b/main.go index d5e5dcc8..37aa24d1 100644 --- a/main.go +++ b/main.go @@ -55,7 +55,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev } l.Info("Firewall started", "firewallHashes", fw.GetRuleHashes()) - ssh, err := sshd.NewSSHServer(l.With("subsystem", "sshd")) + ssh, err := sshd.NewSSHServer(ctx, l.With("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) } @@ -170,7 +170,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev } hostMap := NewHostMapFromConfig(l, c) - punchy := NewPunchyFromConfig(l, c) + punchy := NewPunchyFromConfig(l, c, udpConns[0]) connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy) lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { @@ -240,6 +240,8 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager.f = ifce go handshakeManager.Run(ctx) + + punchy.Start(ctx, ifce, hostMap, lightHouse) } stats, err := newStatsServerFromConfig(ctx, l, c, buildVersion, configTest) diff --git a/punchy.go b/punchy.go index 6ecf4f85..38a0e1ca 100644 --- a/punchy.go +++ b/punchy.go @@ -1,24 +1,70 @@ package nebula import ( + "context" "log/slog" + "net/netip" "sync/atomic" "time" + "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/header" + "github.com/slackhq/nebula/udp" ) +// holepunchQueueSize buffers the channel that pending holepunchJobs land on after their delay timer fires. +const holepunchQueueSize = 64 + +// holepunchJob is one scheduled item delivered to the worker goroutine. +// - target valid -> send a UDP punch to target. vpnAddr, if set, is the peer's vpn addr carried for log context. +// - target invalid, vpnAddr valid -> send an encrypted test packet to vpnAddr (a "punchback"). +type holepunchJob struct { + target netip.AddrPort + vpnAddr netip.Addr +} + +// lighthouseChecker is the slice of LightHouse that Punchy actually needs. +// Defined here so Punchy doesn't take a *LightHouse dependency (LightHouse +// already holds a *Punchy, and the bidirectional pointer reference is awkward +// even within the same package). Tests can also substitute a fake. +type lighthouseChecker interface { + IsAnyLighthouseAddr(vpnAddrs []netip.Addr) bool +} + type Punchy struct { punch atomic.Bool respond atomic.Bool delay atomic.Int64 respondDelay atomic.Int64 punchEverything atomic.Bool - l *slog.Logger + + sched *Scheduler[holepunchJob] + punchConn udp.Conn + metricHolepunchTx metrics.Counter + metricPunchyTx metrics.Counter + + ctx context.Context + ifce EncWriter + hm *HostMap + lh lighthouseChecker + + l *slog.Logger } -func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { - p := &Punchy{l: l} +func NewPunchyFromConfig(l *slog.Logger, c *config.C, punchConn udp.Conn) *Punchy { + p := &Punchy{ + l: l, + punchConn: punchConn, + sched: NewScheduler[holepunchJob](holepunchQueueSize), + metricPunchyTx: metrics.GetOrRegisterCounter("messages.tx.punchy", nil), + } + + if c.GetBool("stats.lighthouse_metrics", false) { + p.metricHolepunchTx = metrics.GetOrRegisterCounter("messages.tx.holepunch", nil) + } else { + p.metricHolepunchTx = metrics.NilCounter{} + } p.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { @@ -29,7 +75,7 @@ func NewPunchyFromConfig(l *slog.Logger, c *config.C) *Punchy { } func (p *Punchy) reload(c *config.C, initial bool) { - if initial { + if initial || c.HasChanged("punchy.punch") || c.HasChanged("punchy") { var yes bool if c.IsSet("punchy.punch") { yes = c.GetBool("punchy.punch", false) @@ -38,16 +84,15 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punchy", false) } - p.punch.Store(yes) - if yes { + old := p.punch.Swap(yes) + switch { + case initial && yes: p.l.Info("punchy enabled") - } else { + case initial: p.l.Info("punchy disabled") + case old != yes: + p.l.Info("punchy.punch changed", "punch", yes) } - - } else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") { - //TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here - p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.") } if initial || c.HasChanged("punchy.respond") || c.HasChanged("punch_back") { @@ -59,52 +104,132 @@ func (p *Punchy) reload(c *config.C, initial bool) { yes = c.GetBool("punch_back", false) } - p.respond.Store(yes) - - if !initial { - p.l.Info("punchy.respond changed", "respond", p.GetRespond()) + old := p.respond.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.respond changed", "respond", yes) } } //NOTE: this will not apply to any in progress operations, only the next one if initial || c.HasChanged("punchy.delay") { - p.delay.Store((int64)(c.GetDuration("punchy.delay", time.Second))) - if !initial { - p.l.Info("punchy.delay changed", "delay", p.GetDelay()) + newDelay := int64(c.GetDuration("punchy.delay", time.Second)) + old := p.delay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.delay changed", "delay", time.Duration(newDelay)) } } if initial || c.HasChanged("punchy.target_all_remotes") { - p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false)) - if !initial { - p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", p.GetTargetEverything()) + yes := c.GetBool("punchy.target_all_remotes", false) + old := p.punchEverything.Swap(yes) + if !initial && old != yes { + p.l.Info("punchy.target_all_remotes changed", "target_all_remotes", yes) } } if initial || c.HasChanged("punchy.respond_delay") { - p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second))) - if !initial { - p.l.Info("punchy.respond_delay changed", "respond_delay", p.GetRespondDelay()) + newDelay := int64(c.GetDuration("punchy.respond_delay", 5*time.Second)) + old := p.respondDelay.Swap(newDelay) + if !initial && old != newDelay { + p.l.Info("punchy.respond_delay changed", "respond_delay", time.Duration(newDelay)) } } } -func (p *Punchy) GetPunch() bool { - return p.punch.Load() +// Schedule queues a punch packet to target, to be sent after the configured delay. +// vpnAddr is the peer's vpn addr, used for log context when the packet actually fires. +// No-op if target is not a valid AddrPort or if Start has not yet been called. Safe to call from any goroutine. +func (p *Punchy) Schedule(target netip.AddrPort, vpnAddr netip.Addr) { + if !target.IsValid() || p.ctx == nil { + return + } + p.scheduleJob(holepunchJob{target: target, vpnAddr: vpnAddr}, time.Duration(p.delay.Load())) } -func (p *Punchy) GetRespond() bool { - return p.respond.Load() +// ScheduleRespond queues a punchback test packet to vpnAddr after the configured respond delay, +// gated on punchy.respond. No-op when respond is disabled or before Start has been called. +func (p *Punchy) ScheduleRespond(vpnAddr netip.Addr) { + if !p.respond.Load() || p.ctx == nil { + return + } + p.scheduleJob(holepunchJob{vpnAddr: vpnAddr}, time.Duration(p.respondDelay.Load())) } -func (p *Punchy) GetDelay() time.Duration { - return (time.Duration)(p.delay.Load()) +// scheduleJob delegates to the pooled Scheduler. +// The callback observes p.ctx so a job that becomes due after Stop is dropped instead of queued. +func (p *Punchy) scheduleJob(job holepunchJob, delay time.Duration) { + p.sched.Schedule(p.ctx, job, delay) } -func (p *Punchy) GetRespondDelay() time.Duration { - return (time.Duration)(p.respondDelay.Load()) +// SendPunch sends an immediate keepalive punch for an idle hostinfo. +// The configured punchy.target_all_remotes mode picks the targets. Gated on punchy.punch and the lighthouse-skip rule +// (lighthouses don't get keepalive punches because the regular update interval keeps their NAT state warm). +func (p *Punchy) SendPunch(hostinfo *HostInfo) { + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + + if p.punchEverything.Load() { + p.sendPunchToAllRemotes(hostinfo) + } else if hostinfo.remote.IsValid() { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, hostinfo.remote) + } } -func (p *Punchy) GetTargetEverything() bool { - return p.punchEverything.Load() +// SendPunchToAll punches every known remote for hostinfo, but only when punchy.target_all_remotes is enabled. +// The connection manager calls this during outbound-only traffic: the outbound traffic itself keeps the primary's +// NAT state warm, but non-primary remotes need separate refresh, so we fan out to all of them (the redundant +// primary punch is harmless). Gated on punchy.punch and the lighthouse-skip rule. +func (p *Punchy) SendPunchToAll(hostinfo *HostInfo) { + if !p.punchEverything.Load() { + return + } + if !p.punch.Load() { + return + } + if p.lh.IsAnyLighthouseAddr(hostinfo.vpnAddrs) { + return + } + p.sendPunchToAllRemotes(hostinfo) +} + +func (p *Punchy) sendPunchToAllRemotes(hostinfo *HostInfo) { + hostinfo.remotes.ForEach(p.hm.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { + p.metricPunchyTx.Inc(1) + p.punchConn.WriteTo([]byte{1}, addr) + }) +} + +// Start wires the runtime dependencies and spawns the scheduler worker. +func (p *Punchy) Start(ctx context.Context, ifce EncWriter, hm *HostMap, lh lighthouseChecker) { + p.ctx = ctx + p.ifce = ifce + p.hm = hm + p.lh = lh + + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + empty := []byte{0} + + go p.sched.Run(ctx, func(job holepunchJob) { + switch { + case job.target.IsValid(): + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Punching", "target", job.target, "vpnAddr", job.vpnAddr) + } + p.metricHolepunchTx.Inc(1) + p.punchConn.WriteTo(empty, job.target) + case job.vpnAddr.IsValid(): + // A nebula test packet to the host trying to contact us. + // In the case of a double nat or other difficult scenario, this may help establish a tunnel. + if p.l.Enabled(context.Background(), slog.LevelDebug) { + p.l.Debug("Sending a nebula test packet", "vpnAddr", job.vpnAddr) + } + p.ifce.SendMessageToVpnAddr(header.Test, header.TestRequest, job.vpnAddr, []byte(""), nb, out) + } + }) } diff --git a/punchy_test.go b/punchy_test.go index cbf9b17b..e56f3eff 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -17,42 +17,42 @@ func TestNewPunchyFromConfig(t *testing.T) { c := config.NewC(l) // Test defaults - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.False(t, p.GetPunch()) - assert.False(t, p.GetRespond()) - assert.Equal(t, time.Second, p.GetDelay()) - assert.Equal(t, 5*time.Second, p.GetRespondDelay()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.False(t, p.respond.Load()) + assert.Equal(t, time.Second, time.Duration(p.delay.Load())) + assert.Equal(t, 5*time.Second, time.Duration(p.respondDelay.Load())) // punchy deprecation c.Settings["punchy"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punchy.punch c.Settings["punchy"] = map[string]any{"punch": true} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetPunch()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.punch.Load()) // punch_back deprecation c.Settings["punch_back"] = true - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.respond c.Settings["punchy"] = map[string]any{"respond": true} c.Settings["punch_back"] = false - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.True(t, p.GetRespond()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.True(t, p.respond.Load()) // punchy.delay c.Settings["punchy"] = map[string]any{"delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.delay.Load())) // punchy.respond_delay c.Settings["punchy"] = map[string]any{"respond_delay": "1m"} - p = NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, time.Minute, p.GetRespondDelay()) + p = NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.Equal(t, time.Minute, time.Duration(p.respondDelay.Load())) } func TestPunchy_reload(t *testing.T) { @@ -61,35 +61,34 @@ func TestPunchy_reload(t *testing.T) { delay, _ := time.ParseDuration("1m") require.NoError(t, c.LoadString(` punchy: + punch: false delay: 1m respond: false `)) - p := NewPunchyFromConfig(test.NewLogger(), c) - assert.Equal(t, delay, p.GetDelay()) - assert.False(t, p.GetRespond()) + p := NewPunchyFromConfig(test.NewLogger(), c, nil) + assert.False(t, p.punch.Load()) + assert.Equal(t, delay, time.Duration(p.delay.Load())) + assert.False(t, p.respond.Load()) newDelay, _ := time.ParseDuration("10m") require.NoError(t, c.ReloadConfigString(` punchy: + punch: true delay: 10m respond: true `)) p.reload(c, false) - assert.Equal(t, newDelay, p.GetDelay()) - assert.True(t, p.GetRespond()) + assert.True(t, p.punch.Load()) + assert.Equal(t, newDelay, time.Duration(p.delay.Load())) + assert.True(t, p.respond.Load()) } // The tests below pin the shape of each log line Punchy produces so changes // cannot silently break whatever operators are grepping for. The assertions // are on the structured message + attrs (e.g. "punchy.respond changed" with -// a respond=true field) rather than a formatted string. -// -// Punchy.reload also emits a spurious "Changing punchy.punch with reload is -// not supported" warning whenever any key under punchy changes, because of -// the c.HasChanged("punchy") fallback kept for the deprecated top-level -// punchy form. The tests filter by message rather than asserting total -// entry counts so that warning is tolerated without being locked into -// the format. +// a respond=true field) rather than a formatted string. Tests filter by +// message rather than asserting total entry counts so unrelated info lines +// are tolerated without being locked into the format. type capturedEntry struct { Level slog.Level @@ -145,7 +144,7 @@ func TestPunchy_LogFormat_InitialEnabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: true}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy enabled") assert.Equal(t, slog.LevelInfo, entry.Level) @@ -157,32 +156,32 @@ func TestPunchy_LogFormat_InitialDisabled(t *testing.T) { c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) entry := findEntry(t, hook.entries, "punchy disabled") assert.Equal(t, slog.LevelInfo, entry.Level) assert.Empty(t, entry.Attrs) } -func TestPunchy_LogFormat_ReloadPunchUnsupported(t *testing.T) { +func TestPunchy_LogFormat_ReloadPunch(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {punch: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {punch: true}`)) - entry := findEntry(t, hook.entries, "Changing punchy.punch with reload is not supported, ignoring.") - assert.Equal(t, slog.LevelWarn, entry.Level) - assert.Empty(t, entry.Attrs) + entry := findEntry(t, hook.entries, "punchy.punch changed") + assert.Equal(t, slog.LevelInfo, entry.Level) + assert.Equal(t, map[string]any{"punch": true}, entry.Attrs) } func TestPunchy_LogFormat_ReloadRespond(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond: true}`)) @@ -196,7 +195,7 @@ func TestPunchy_LogFormat_ReloadDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {delay: 1s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {delay: 10s}`)) @@ -210,7 +209,7 @@ func TestPunchy_LogFormat_ReloadTargetAllRemotes(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {target_all_remotes: false}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {target_all_remotes: true}`)) @@ -224,7 +223,7 @@ func TestPunchy_LogFormat_ReloadRespondDelay(t *testing.T) { l, hook := newCapturingPunchyLogger(t) c := config.NewC(test.NewLogger()) require.NoError(t, c.LoadString(`punchy: {respond_delay: 5s}`)) - NewPunchyFromConfig(l, c) + NewPunchyFromConfig(l, c, nil) hook.entries = nil require.NoError(t, c.ReloadConfigString(`punchy: {respond_delay: 15s}`)) diff --git a/scheduler.go b/scheduler.go new file mode 100644 index 00000000..7733204a --- /dev/null +++ b/scheduler.go @@ -0,0 +1,84 @@ +package nebula + +import ( + "context" + "sync" + "time" +) + +// Scheduler is an allocation-conscious dispatch primitive for delayed work. +// Pending items are handed to time.AfterFunc, and ready items land on a worker +// channel for centralized dispatch in fire-time order. +// +// Pick a Scheduler when fire timing matters (exact deadlines, no bucketing) or when the scheduling +// rate is uneven enough that idle CPU matters. Each fire is a runtime-spawned goroutine running the callback before +// delivering to the worker, which is fine at sparse rates but adds up at line rate. +// +// Pick a TimerWheel when scheduling is high-rate and uniform: its O(1) insert, internal item cache, +// and bucket-batched dispatch are cheaper at scale. +// The caller drives the tick loop (Advance/Purge) and pays for fires at bucket boundaries rather than exact deadlines. +type Scheduler[T any] struct { + queue chan T + pool sync.Pool +} + +type schedItem[T any] struct { + val T + ctx context.Context + s *Scheduler[T] + timer *time.Timer + fire func() +} + +// NewScheduler builds a Scheduler whose worker channel is sized to queueSize. +// The buffer absorbs bursts of timers firing close together without +// blocking the runtime's callback goroutines on the worker. +func NewScheduler[T any](queueSize int) *Scheduler[T] { + s := &Scheduler[T]{ + queue: make(chan T, queueSize), + } + s.pool.New = func() any { + si := &schedItem[T]{s: s} + // fire is allocated exactly once per pool-resident item. + // The closure captures only `si`, which stays stable for the item's lifetime. + si.fire = func() { + select { + case si.s.queue <- si.val: + case <-si.ctx.Done(): + } + var zero T + si.val = zero + si.ctx = nil + si.s.pool.Put(si) + } + return si + } + return s +} + +// Schedule arranges item to be delivered to the worker after delay. +// The runtime's timer heap handles the wait, so the scheduler itself burns no CPU while idle. +// The callback observes ctx: if ctx is cancelled before the timer fires, the item is dropped instead of queued. +func (s *Scheduler[T]) Schedule(ctx context.Context, item T, delay time.Duration) { + si := s.pool.Get().(*schedItem[T]) + si.val = item + si.ctx = ctx + if si.timer == nil { + si.timer = time.AfterFunc(delay, si.fire) + } else { + si.timer.Reset(delay) + } +} + +// Run drains the worker queue, calling fn for each item. Returns when ctx is cancelled. +// Tests that want deterministic timing should drive the queue directly rather than going through Schedule + Run. +func (s *Scheduler[T]) Run(ctx context.Context, fn func(T)) { + for { + select { + case <-ctx.Done(): + return + case item := <-s.queue: + fn(item) + } + } +} diff --git a/scheduler_test.go b/scheduler_test.go new file mode 100644 index 00000000..085d523c --- /dev/null +++ b/scheduler_test.go @@ -0,0 +1,79 @@ +package nebula + +import ( + "context" + "testing" + "time" +) + +func TestScheduler_PooledReuse(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewScheduler[int](16) + delivered := make(chan int, 256) + go s.Run(ctx, func(item int) { delivered <- item }) + + const N = 100 + for i := 0; i < N; i++ { + s.Schedule(ctx, i, time.Millisecond) + } + + deadline := time.After(2 * time.Second) + got := 0 + for got < N { + select { + case <-delivered: + got++ + case <-deadline: + t.Fatalf("only %d/%d items delivered", got, N) + } + } +} + +// BenchmarkScheduler_Schedule reports allocations per Schedule call. +// In steady state the Scheduler's sync.Pool means we should see zero allocs per op once the pool warms up. +func BenchmarkScheduler_Schedule(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := NewScheduler[int](b.N) + go s.Run(ctx, func(int) {}) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Schedule(ctx, i, time.Microsecond) + } +} + +// BenchmarkBareAfterFunc is the comparison baseline. +// What we'd pay per Schedule if Punchy called time.AfterFunc directly without the pooled Scheduler. +// Allocates a *time.Timer plus a closure each call. +func BenchmarkBareAfterFunc(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + queue := make(chan int, b.N) + go func() { + for { + select { + case <-ctx.Done(): + return + case <-queue: + } + } + }() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + i := i + time.AfterFunc(time.Microsecond, func() { + select { + case queue <- i: + case <-ctx.Done(): + } + }) + } +} diff --git a/sshd/server.go b/sshd/server.go index 38886e53..ff954bf5 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -32,10 +32,12 @@ type SSHServer struct { cancel func() } -// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen -func NewSSHServer(l *slog.Logger) (*SSHServer, error) { +// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen. +// The ssh server's context is parented off the supplied ctx so cancelling it +// (e.g. on Control.Stop) tears down active sessions and closes the listener. +func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, @@ -153,6 +155,10 @@ func (s *SSHServer) RegisterCommand(c *Command) { // Run begins listening and accepting connections func (s *SSHServer) Run(addr string) error { + if s.ctx.Err() != nil { + return s.ctx.Err() + } + var err error s.listener, err = net.Listen("tcp", addr) if err != nil { @@ -161,8 +167,21 @@ func (s *SSHServer) Run(addr string) error { s.l.Info("SSH server is listening", "sshListener", addr) + // Per-invocation watcher: cancellation of the parent context (e.g. + // Control.Stop) closes the listener so Accept unblocks and run returns. + // Closing `done` on exit keeps the watcher from outliving this Run call. + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + s.Stop() + case <-done: + } + }() + // Run loops until there is an error s.run() + close(done) s.closeSessions() s.l.Info("SSH server stopped listening") diff --git a/timeout.go b/timeout.go index c1b4c398..96bf688b 100644 --- a/timeout.go +++ b/timeout.go @@ -8,6 +8,23 @@ import ( // How many timer objects should be cached const timerCacheMax = 50000 +// TimerWheel is a hashed timing wheel: a fixed slot array indexed by (now + delay) % wheelLen, +// with each slot a singly linked list of items due in that bucket. +// Adds are O(1), Purges return items in arrival-within-slot order, and an internal cache of TimeoutItems +// keeps steady-state inserts allocation-free. +// +// The TimerWheel does not handle concurrency or lifecycle on its own. +// Callers drive Advance/Purge from their own ticker loop, take their own locks (or use LockingTimerWheel), +// and decide whether to keep ticking when the wheel is empty. +// +// Pick a TimerWheel when scheduling is high-rate and uniform: line-rate conntrack inserts, +// per-tunnel traffic checks at fixed intervals. O(1) insert plus the item cache means the hot path doesn't allocate. +// Items added in the same tick are dispatched together when that slot rotates current, +// which amortizes the cost of waking the worker. +// +// Pick a Scheduler when delay precision matters or scheduling is sparse or uneven. +// The wheel rounds requested timeouts up to its tick resolution and clamps anything beyond its wheel duration; +// both are silent in this implementation. type TimerWheel[T any] struct { // Current tick current int From a82a8dc547dca7e0f4e30c4d6f6adaaa124babbc Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Wed, 6 May 2026 17:00:07 -0500 Subject: [PATCH 06/31] don't panic on bad ed25519 key lengths (#1601) * don't panic on bad ed25519 key lengths * don't allow mismatched curves * add test --- cert/ca_pool.go | 4 ++++ cert/ca_pool_test.go | 28 ++++++++++++++++++++++++++++ cert/cert_v1.go | 3 +++ cert/cert_v2.go | 3 +++ cert/errors.go | 1 + 5 files changed, 39 insertions(+) diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 792f8e66..966f78e3 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -217,6 +217,10 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certFp string, signerFp return nil, err } + if signer.Certificate.Curve() != c.Curve() { + return nil, ErrCurveMismatch + } + if signer.Certificate.Expired(now) { return nil, ErrRootExpired } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index ab173228..c246e770 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -654,3 +654,31 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { _, err = caPool.VerifyCertificate(time.Now(), c) require.NoError(t, err) } + +func TestCertificateV2_CurveMismatch(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + require.NoError(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + require.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.0.0.1/24") + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1}, nil, []string{"test"}) + + fp, _ := c.Fingerprint() + _, err = caPool.verify(c, time.Now(), fp, c.Issuer()) + require.NoError(t, err) + // + c2 := c.(*certificateV2) + c2.curve = Curve_CURVE25519 + fp, _ = c.Fingerprint() + _, err = caPool.verify(c, time.Now(), fp, c.Issuer()) + require.Error(t, err) +} diff --git a/cert/cert_v1.go b/cert/cert_v1.go index c32f409a..4df30032 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -112,6 +112,9 @@ func (c *certificateV1) CheckSignature(key []byte) bool { } switch c.details.curve { case Curve_CURVE25519: + if len(key) != ed25519.PublicKeySize { + return false //avoids a panic internal to ed25519 + } return ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 4648c496..c2b43a69 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -151,6 +151,9 @@ func (c *certificateV2) CheckSignature(key []byte) bool { switch c.curve { case Curve_CURVE25519: + if len(key) != ed25519.PublicKeySize { + return false //avoids a panic internal to ed25519 + } return ed25519.Verify(key, b, c.signature) case Curve_P256: pubKey, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), key) diff --git a/cert/errors.go b/cert/errors.go index 8c480a14..596cfe19 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -22,6 +22,7 @@ var ( ErrCaNotFound = errors.New("could not find ca for the certificate") ErrUnknownVersion = errors.New("certificate version unrecognized") ErrCertPubkeyPresent = errors.New("certificate has unexpected pubkey present") + ErrCurveMismatch = errors.New("certificate curve does not match CA") ErrInvalidPEMBlock = errors.New("input did not contain a valid PEM encoded block") ErrInvalidPEMCertificateBanner = errors.New("bytes did not contain a proper certificate banner") From eaf756ea6c90d97790f29503fb5e687a251ca8fb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:31:48 -0500 Subject: [PATCH 07/31] Bump Apple-Actions/import-codesign-certs from 6 to 7 (#1697) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index a5e8d397..b911bd52 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: - name: Import certificates if: env.HAS_SIGNING_CREDS == 'true' - uses: Apple-Actions/import-codesign-certs@v6 + uses: Apple-Actions/import-codesign-certs@v7 with: p12-file-base64: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_P12_BASE64 }} p12-password: ${{ secrets.APPLE_DEVELOPER_CERTIFICATE_PASSWORD }} From 76e82a5256f55f47107327cd6710536169483e16 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:32:21 -0500 Subject: [PATCH 08/31] Bump golang.org/x/net (#1664) Bumps the golang-x-dependencies group with 1 update in the / directory: [golang.org/x/net](https://github.com/golang/net). Updates `golang.org/x/net` from 0.52.0 to 0.53.0 - [Commits](https://github.com/golang/net/compare/v0.52.0...v0.53.0) --- updated-dependencies: - dependency-name: golang.org/x/net dependency-version: 0.53.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 24d901c5..bfbc987f 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( go.yaml.in/yaml/v3 v3.0.4 golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.52.0 + golang.org/x/net v0.53.0 golang.org/x/sync v0.20.0 golang.org/x/sys v0.43.0 golang.org/x/term v0.42.0 diff --git a/go.sum b/go.sum index aad164c7..10116c5b 100644 --- a/go.sum +++ b/go.sum @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= From dd2ac5d6550a37745f9047d8a230482c1bc8ad18 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:32:45 -0500 Subject: [PATCH 09/31] Bump docker/login-action from 3 to 4 (#1628) Bumps [docker/login-action](https://github.com/docker/login-action) from 3 to 4. - [Release notes](https://github.com/docker/login-action/releases) - [Commits](https://github.com/docker/login-action/compare/v3...v4) --- updated-dependencies: - dependency-name: docker/login-action dependency-version: '4' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b911bd52..8d4b62bc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -135,7 +135,7 @@ jobs: - name: Login to Docker Hub if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: docker/login-action@v3 + uses: docker/login-action@v4 with: username: ${{ vars.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} From dd3a7ad03c488860e39060a728959817f162112a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:33:16 -0500 Subject: [PATCH 10/31] Bump docker/setup-buildx-action from 3 to 4 (#1627) Bumps [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) from 3 to 4. - [Release notes](https://github.com/docker/setup-buildx-action/releases) - [Commits](https://github.com/docker/setup-buildx-action/compare/v3...v4) --- updated-dependencies: - dependency-name: docker/setup-buildx-action dependency-version: '4' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8d4b62bc..e323a2ca 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -142,7 +142,7 @@ jobs: - name: Set up Docker Buildx if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@v4 - name: Build and push images if: ${{ env.HAS_DOCKER_CREDS == 'true' }} From 23c67bd8d820d48f16892a94f77f747bb5b358c7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:33:47 -0500 Subject: [PATCH 11/31] Bump actions/upload-artifact from 6 to 7 (#1618) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 6 to 7. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 6 +++--- .github/workflows/test.yml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e323a2ca..e934d436 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,7 +24,7 @@ jobs: mv build/*.tar.gz release - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: linux-latest path: release @@ -55,7 +55,7 @@ jobs: mv dist\windows\wintun build\dist\windows\ - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: windows-latest path: build @@ -104,7 +104,7 @@ jobs: fi - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: darwin-latest path: ./release/* diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aeaea294..009c22a9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: - name: Build test mobile run: make build-test-mobile - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: e2e packet flow linux-latest path: e2e/mermaid/linux-latest @@ -125,7 +125,7 @@ jobs: - name: End 2 end run: make e2evv - - uses: actions/upload-artifact@v6 + - uses: actions/upload-artifact@v7 with: name: e2e packet flow ${{ matrix.os }} path: e2e/mermaid/${{ matrix.os }} From 83809a599a1414b57e715fe241d7204487eb9a9f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:34:06 -0500 Subject: [PATCH 12/31] Bump actions/download-artifact from 7 to 8 (#1617) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 7 to 8. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v7...v8) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '8' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e934d436..356ae363 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -128,7 +128,7 @@ jobs: - name: Download artifacts if: ${{ env.HAS_DOCKER_CREDS == 'true' }} - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: linux-latest path: artifacts @@ -163,7 +163,7 @@ jobs: - uses: actions/checkout@v6 - name: Download artifacts - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: path: artifacts From cba9ea5b1fb10fd7a7a00ce6a2adb7cf2f14fbc2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 17:36:07 -0500 Subject: [PATCH 13/31] Bump github.com/gaissmai/bart from 0.26.0 to 0.26.1 (#1604) Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.26.0 to 0.26.1. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.26.0...v0.26.1) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-version: 0.26.1 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bfbc987f..84728201 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.26.0 + github.com/gaissmai/bart v0.26.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 diff --git a/go.sum b/go.sum index 10116c5b..3b0b87df 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0= -github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= +github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo= +github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 5f920fdd7d5af2510516ef3e6dbd9543de8019ae Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 6 May 2026 17:37:03 -0500 Subject: [PATCH 14/31] Remove the global noiseEndianness var (#1707) --- connection_state.go | 9 +- handshake/machine.go | 2 + noise.go | 73 --------------- noiseutil/aesgcm.go | 53 +++++++++++ noiseutil/chachapoly.go | 52 +++++++++++ noiseutil/cipher_state.go | 40 ++++++++ noiseutil/cipher_state_test.go | 166 +++++++++++++++++++++++++++++++++ pki.go | 8 +- 8 files changed, 321 insertions(+), 82 deletions(-) delete mode 100644 noise.go create mode 100644 noiseutil/aesgcm.go create mode 100644 noiseutil/chachapoly.go create mode 100644 noiseutil/cipher_state.go create mode 100644 noiseutil/cipher_state_test.go diff --git a/connection_state.go b/connection_state.go index 47e23b5a..0ae2d9be 100644 --- a/connection_state.go +++ b/connection_state.go @@ -7,13 +7,14 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/handshake" + "github.com/slackhq/nebula/noiseutil" ) const ReplayWindow = 1024 type ConnectionState struct { - eKey *NebulaCipherState - dKey *NebulaCipherState + eKey noiseutil.CipherState + dKey noiseutil.CipherState myCert cert.Certificate peerCert *cert.CachedCertificate initiator bool @@ -31,8 +32,8 @@ func newConnectionStateFromResult(r *handshake.Result) *ConnectionState { myCert: r.MyCert, initiator: r.Initiator, peerCert: r.RemoteCert, - eKey: NewNebulaCipherState(r.EKey), - dKey: NewNebulaCipherState(r.DKey), + eKey: noiseutil.NewCipherState(r.EKey, r.Cipher), + dKey: noiseutil.NewCipherState(r.DKey, r.Cipher), window: NewBits(ReplayWindow), } ci.messageCounter.Add(r.MessageIndex) diff --git a/handshake/machine.go b/handshake/machine.go index 25ed3a5a..737358dc 100644 --- a/handshake/machine.go +++ b/handshake/machine.go @@ -31,6 +31,7 @@ type CertVerifier func(cert.Certificate) (*cert.CachedCertificate, error) type Result struct { EKey *noise.CipherState DKey *noise.CipherState + Cipher noise.CipherFunc // identifies which post-handshake CipherState the data plane should wrap EKey/DKey in MyCert cert.Certificate RemoteCert *cert.CachedCertificate RemoteIndex uint32 @@ -105,6 +106,7 @@ func NewMachine( myVersion: version, result: &Result{ Initiator: initiator, + Cipher: cred.cipherSuite, }, }, nil } diff --git a/noise.go b/noise.go deleted file mode 100644 index 0491da17..00000000 --- a/noise.go +++ /dev/null @@ -1,73 +0,0 @@ -package nebula - -import ( - "crypto/cipher" - "encoding/binary" - "errors" - - "github.com/flynn/noise" -) - -type endianness interface { - PutUint64(b []byte, v uint64) -} - -var noiseEndianness endianness = binary.BigEndian - -type NebulaCipherState struct { - c cipher.AEAD -} - -func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState { - x := s.Cipher() - return &NebulaCipherState{c: x.(cipher.AEAD)} -} - -// EncryptDanger encrypts and authenticates a given payload. -// -// out is a destination slice to hold the output of the EncryptDanger operation. -// - ad is additional data, which will be authenticated and appended to out, but not encrypted. -// - plaintext is encrypted, authenticated and appended to out. -// - n is a nonce value which must never be re-used with this key. -// - nb is a buffer used for temporary storage in the implementation of this call, which should -// be re-used by callers to minimize garbage collection. -func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - // TODO: Is this okay now that we have made messageCounter atomic? - // Alternative may be to split the counter space into ranges - //if n <= s.n { - // return nil, errors.New("CRITICAL: a duplicate counter value was used") - //} - //s.n = n - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - out = s.c.Seal(out, nb, plaintext, ad) - //l.Debugf("Encryption: outlen: %d, nonce: %d, ad: %s, plainlen %d", len(out), n, ad, len(plaintext)) - return out, nil - } else { - return nil, errors.New("no cipher state available to encrypt") - } -} - -func (s *NebulaCipherState) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { - if s != nil { - nb[0] = 0 - nb[1] = 0 - nb[2] = 0 - nb[3] = 0 - noiseEndianness.PutUint64(nb[4:], n) - return s.c.Open(out, nb, ciphertext, ad) - } else { - return []byte{}, nil - } -} - -func (s *NebulaCipherState) Overhead() int { - if s != nil { - return s.c.Overhead() - } - return 0 -} diff --git a/noiseutil/aesgcm.go b/noiseutil/aesgcm.go new file mode 100644 index 00000000..dcbd5693 --- /dev/null +++ b/noiseutil/aesgcm.go @@ -0,0 +1,53 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateAESGCM is the data-plane wrapper for the AES-GCM AEAD cipher. +// AES-GCM uses big-endian nonce encoding per the Noise spec. +type CipherStateAESGCM struct { + c cipher.AEAD +} + +// NewCipherStateAESGCM extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually AES-GCM, +// otherwise the type assertion still succeeds but the nonce endianness will be wrong on the wire. +func NewCipherStateAESGCM(s *noise.CipherState) *CipherStateAESGCM { + return &CipherStateAESGCM{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateAESGCM) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateAESGCM) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.BigEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateAESGCM) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/chachapoly.go b/noiseutil/chachapoly.go new file mode 100644 index 00000000..31ab3bfe --- /dev/null +++ b/noiseutil/chachapoly.go @@ -0,0 +1,52 @@ +package noiseutil + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "github.com/flynn/noise" +) + +// CipherStateChaChaPoly is the data-plane wrapper for the ChaCha20-Poly1305 AEAD cipher. +// ChaCha20-Poly1305 uses little-endian nonce encoding per the Noise spec. +type CipherStateChaChaPoly struct { + c cipher.AEAD +} + +// NewCipherStateChaChaPoly extracts the underlying AEAD from the post-handshake noise.CipherState. +// The caller is responsible for ensuring the noise cipher is actually ChaCha20-Poly1305. +func NewCipherStateChaChaPoly(s *noise.CipherState) *CipherStateChaChaPoly { + return &CipherStateChaChaPoly{c: s.Cipher().(cipher.AEAD)} +} + +func (s *CipherStateChaChaPoly) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return nil, errors.New("no cipher state available to encrypt") + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Seal(out, nb, plaintext, ad), nil +} + +func (s *CipherStateChaChaPoly) DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) { + if s == nil { + return []byte{}, nil + } + nb[0] = 0 + nb[1] = 0 + nb[2] = 0 + nb[3] = 0 + binary.LittleEndian.PutUint64(nb[4:], n) + return s.c.Open(out, nb, ciphertext, ad) +} + +func (s *CipherStateChaChaPoly) Overhead() int { + if s == nil { + return 0 + } + return s.c.Overhead() +} diff --git a/noiseutil/cipher_state.go b/noiseutil/cipher_state.go new file mode 100644 index 00000000..bb316385 --- /dev/null +++ b/noiseutil/cipher_state.go @@ -0,0 +1,40 @@ +package noiseutil + +import ( + "fmt" + + "github.com/flynn/noise" +) + +// CipherState is the post-handshake AEAD cipher used for the data plane. +// Each supported cipher has its own concrete implementation in this package with the nonce endianness hardcoded, +// so the encrypt/decrypt fast path avoids interface dispatch on the byte order. +type CipherState interface { + // EncryptDanger encrypts and authenticates a given payload. + // + // out is a destination slice to hold the output of the EncryptDanger operation. + // - ad is additional data, which will be authenticated and appended to out, but not encrypted. + // - plaintext is encrypted, authenticated and appended to out. + // - n is a nonce value which must never be re-used with this key. + // - nb is a scratch buffer used to assemble the nonce. + EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) + + // DecryptDanger authenticates and decrypts a given payload, with the same argument shape as EncryptDanger. + DecryptDanger(out, ad, ciphertext []byte, n uint64, nb []byte) ([]byte, error) + + // Overhead returns the AEAD tag size, or 0 if the receiver is nil. + Overhead() int +} + +// NewCipherState wraps the post-handshake noise.CipherState in the per-cipher type that matches cipherFunc. +// cipherFunc must be the same cipher used to build the noise CipherSuite that produced s. +func NewCipherState(s *noise.CipherState, cipherFunc noise.CipherFunc) CipherState { + switch cipherFunc.CipherName() { + case CipherAESGCM.CipherName(): + return NewCipherStateAESGCM(s) + case noise.CipherChaChaPoly.CipherName(): + return NewCipherStateChaChaPoly(s) + default: + panic(fmt.Sprintf("noiseutil: unsupported cipher %q", cipherFunc.CipherName())) + } +} diff --git a/noiseutil/cipher_state_test.go b/noiseutil/cipher_state_test.go new file mode 100644 index 00000000..a4df01e9 --- /dev/null +++ b/noiseutil/cipher_state_test.go @@ -0,0 +1,166 @@ +package noiseutil + +import ( + "testing" + + "github.com/flynn/noise" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCipherStateAESGCMRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, CipherAESGCM) + roundtrip(t, NewCipherStateAESGCM(enc), NewCipherStateAESGCM(dec)) +} + +func TestCipherStateChaChaPolyRoundtrip(t *testing.T) { + enc, dec := buildCipherStates(t, noise.CipherChaChaPoly) + roundtrip(t, NewCipherStateChaChaPoly(enc), NewCipherStateChaChaPoly(dec)) +} + +func TestNewCipherStateDispatch(t *testing.T) { + encA, _ := buildCipherStates(t, CipherAESGCM) + encC, _ := buildCipherStates(t, noise.CipherChaChaPoly) + + assert.IsType(t, &CipherStateAESGCM{}, NewCipherState(encA, CipherAESGCM)) + assert.IsType(t, &CipherStateChaChaPoly{}, NewCipherState(encC, noise.CipherChaChaPoly)) +} + +func TestNewCipherStateUnsupportedPanics(t *testing.T) { + enc, _ := buildCipherStates(t, CipherAESGCM) + assert.Panics(t, func() { + NewCipherState(enc, fakeCipher{}) + }) +} + +type fakeCipher struct{} + +func (fakeCipher) Cipher(k [32]byte) noise.Cipher { return nil } +func (fakeCipher) CipherName() string { return "Fake" } + +// buildCipherStates runs an in-memory NN handshake with the requested cipher +// to produce a pair of post-handshake CipherStates that share keys. +func buildCipherStates(t *testing.T, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + t.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + require.NoError(t, err) + + msg, _, _, err := hsI.WriteMessage(nil, nil) + require.NoError(t, err) + _, _, _, err = hsR.ReadMessage(nil, msg) + require.NoError(t, err) + + msg, dR, _, err := hsR.WriteMessage(nil, nil) + require.NoError(t, err) + _, eI, _, err := hsI.ReadMessage(nil, msg) + require.NoError(t, err) + require.NotNil(t, eI) + require.NotNil(t, dR) + + // noise returns (cs1, cs2) where cs1 is the initiator->responder cipher. + return eI, dR +} + +func roundtrip(t *testing.T, enc, dec CipherState) { + t.Helper() + plaintext := []byte("nebula cipher state roundtrip") + ad := []byte("aad") + nb := make([]byte, 12) + + ct, err := enc.EncryptDanger(nil, ad, plaintext, 1, nb) + require.NoError(t, err) + assert.NotEqual(t, plaintext, ct) + + pt, err := dec.DecryptDanger(nil, ad, ct, 1, nb) + require.NoError(t, err) + assert.Equal(t, plaintext, pt) + + // Wrong nonce must fail authentication. + _, err = dec.DecryptDanger(nil, ad, ct, 2, nb) + require.Error(t, err) + + assert.Equal(t, enc.Overhead(), dec.Overhead()) + assert.Equal(t, 16, enc.Overhead()) +} + +func BenchmarkCipherStateEncryptAESGCM(b *testing.B) { + enc, _ := buildCipherStatesB(b, CipherAESGCM) + benchEncryptCipherState(b, NewCipherState(enc, CipherAESGCM)) +} + +func BenchmarkCipherStateEncryptChaChaPoly(b *testing.B) { + enc, _ := buildCipherStatesB(b, noise.CipherChaChaPoly) + benchEncryptCipherState(b, NewCipherState(enc, noise.CipherChaChaPoly)) +} + +func benchEncryptCipherState(b *testing.B, cs CipherState) { + plaintext := make([]byte, 1280) + ad := make([]byte, 16) + nb := make([]byte, 12) + out := make([]byte, 0, len(plaintext)+cs.Overhead()) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var err error + out, err = cs.EncryptDanger(out[:0], ad, plaintext, uint64(i+1), nb) + if err != nil { + b.Fatal(err) + } + } +} + +func buildCipherStatesB(b *testing.B, c noise.CipherFunc) (*noise.CipherState, *noise.CipherState) { + b.Helper() + suite := noise.NewCipherSuite(noise.DH25519, c, noise.HashSHA256) + cfg := noise.Config{CipherSuite: suite, Pattern: noise.HandshakeNN} + cfg.Initiator = true + hsI, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + cfg.Initiator = false + hsR, err := noise.NewHandshakeState(cfg) + if err != nil { + b.Fatal(err) + } + msg, _, _, err := hsI.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + if _, _, _, err := hsR.ReadMessage(nil, msg); err != nil { + b.Fatal(err) + } + msg, dR, _, err := hsR.WriteMessage(nil, nil) + if err != nil { + b.Fatal(err) + } + _, eI, _, err := hsI.ReadMessage(nil, msg) + if err != nil { + b.Fatal(err) + } + return eI, dR +} + +func TestCipherStateNilSafety(t *testing.T) { + var aes *CipherStateAESGCM + _, err := aes.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err := aes.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, aes.Overhead()) + + var cc *CipherStateChaChaPoly + _, err = cc.EncryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.Error(t, err) + out, err = cc.DecryptDanger(nil, nil, nil, 0, make([]byte, 12)) + require.NoError(t, err) + assert.Empty(t, out) + assert.Equal(t, 0, cc.Overhead()) +} diff --git a/pki.go b/pki.go index acc80486..1bef5106 100644 --- a/pki.go +++ b/pki.go @@ -99,12 +99,10 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { var currentState *CertState if initial { cipher = c.GetString("cipher", "aes") - //TODO: this sucks and we should make it not a global switch cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian + case "aes", "chachapoly": + // Each post-handshake CipherState in noiseutil hardcodes its own + // nonce endianness now, so there's nothing to set up here. default: return util.NewContextualError( "unknown cipher", From 1ada3d4dd98659a425fb0196b2e34eead36f9914 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 10:30:29 -0500 Subject: [PATCH 15/31] Use DefinedNets fancy new netbsd10 vagrant box for smokes (#1711) --- .github/workflows/smoke-extra.yml | 48 ++++++++++++------- .../smoke/vagrant-netbsd-amd64/Vagrantfile | 2 +- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index 3734db75..cca7678b 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -14,10 +14,18 @@ on: - 'go.sum' jobs: - smoke-extra: + smoke-extra-libvirt: if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') - name: Run extra smoke tests + name: ${{ matrix.target }} runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + target: + - freebsd-amd64 + - openbsd-amd64 + - netbsd-amd64 + - linux-amd64-ipv6disable env: VAGRANT_DEFAULT_PROVIDER: libvirt steps: @@ -40,28 +48,36 @@ jobs: sudo chmod 666 /var/run/libvirt/libvirt-sock vagrant plugin install vagrant-libvirt - - name: freebsd-amd64 - run: make smoke-vagrant/freebsd-amd64 + - name: ${{ matrix.target }} + run: make smoke-vagrant/${{ matrix.target }} - - name: openbsd-amd64 - run: make smoke-vagrant/openbsd-amd64 + timeout-minutes: 30 - - name: netbsd-amd64 - run: make smoke-vagrant/netbsd-amd64 + # linux-386 needs VirtualBox, which conflicts with KVM/libvirt -- isolated job. + smoke-extra-virtualbox: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: linux-386 + runs-on: ubuntu-latest + env: + VAGRANT_DEFAULT_PROVIDER: virtualbox + steps: - - name: linux-amd64-ipv6disable - run: make smoke-vagrant/linux-amd64-ipv6disable + - uses: actions/checkout@v6 - # linux-386 runs last because it requires disabling KVM to use VirtualBox, - # which prevents libvirt (used by the other tests) from working after this point. - - name: install virtualbox for i386 test + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + - name: add hashicorp source + run: wget -O- https://apt.releases.hashicorp.com/gpg | gpg --dearmor | sudo tee /usr/share/keyrings/hashicorp-archive-keyring.gpg && echo "deb [signed-by=/usr/share/keyrings/hashicorp-archive-keyring.gpg] https://apt.releases.hashicorp.com $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/hashicorp.list + + - name: install vagrant and virtualbox run: | - sudo apt-get install -y virtualbox + sudo apt-get update && sudo apt-get install -y vagrant virtualbox sudo rmmod kvm_amd kvm_intel kvm 2>/dev/null || true - name: linux-386 - env: - VAGRANT_DEFAULT_PROVIDER: virtualbox run: make smoke-vagrant/linux-386 timeout-minutes: 30 diff --git a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile index 14ba2ce1..a3fa7ec2 100644 --- a/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile +++ b/.github/workflows/smoke/vagrant-netbsd-amd64/Vagrantfile @@ -1,7 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : Vagrant.configure("2") do |config| - config.vm.box = "generic/netbsd9" + config.vm.box = "DefinedNet/netbsd10" config.vm.synced_folder "../build", "/nebula", type: "rsync" end From c82db210ef7a31940412044b4cad0e372ea23658 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 11:30:26 -0500 Subject: [PATCH 16/31] Change windows unsafe routes to link routes, fix sshd reload bug (#1709) --- e2e/sshd_test.go | 125 +++++++++++++++++++++++++++++++++++++++++ overlay/tun_windows.go | 16 ++++-- sshd/server.go | 56 +++++++++--------- 3 files changed, 162 insertions(+), 35 deletions(-) create mode 100644 e2e/sshd_test.go diff --git a/e2e/sshd_test.go b/e2e/sshd_test.go new file mode 100644 index 00000000..e91f1bd0 --- /dev/null +++ b/e2e/sshd_test.go @@ -0,0 +1,125 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "net" + "strings" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestSSHDLifecycle(t *testing.T) { + // TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop. + ca, _, caKey, _ := cert_test.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, + time.Now(), time.Now().Add(10*time.Minute), + nil, nil, []string{}, + ) + + hostKeyPEM := generateSSHHostKey(t) + clientSigner, clientAuthKey := generateSSHClientKey(t) + sshdAddr := allocLoopbackPort(t) + + overrides := m{ + "sshd": m{ + "enabled": true, + "listen": sshdAddr, + "host_key": hostKeyPEM, + "authorized_users": []m{{ + "user": "tester", + "keys": []string{clientAuthKey}, + }}, + }, + } + control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides) + control.Start() + t.Cleanup(func() { control.Stop() }) + + // sshd binds in a goroutine after Start returns; wait for it. + require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd never started listening") + + for i := 1; i <= 3; i++ { + out := sshExecReload(t, sshdAddr, clientSigner) + assert.Contains(t, out, "Reloading config", "reload cycle %d", i) + require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd not listening after reload cycle %d", i) + } + + control.Stop() + require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd still listening after Control.Stop") +} + +func canDial(addr string) bool { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err != nil { + return false + } + _ = c.Close() + return true +} + +// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There +// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the +// port available long enough for the test to bind it. +func allocLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func generateSSHHostKey(t *testing.T) string { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + block, err := ssh.MarshalPrivateKey(priv, "nebula-e2e-host") + require.NoError(t, err) + return string(pem.EncodeToMemory(block)) +} + +func generateSSHClientKey(t *testing.T) (ssh.Signer, string) { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + auth := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + return signer, auth +} + +func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string { + t.Helper() + cfg := &ssh.ClientConfig{ + User: "tester", + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + client, err := ssh.Dial("tcp", addr, cfg) + require.NoError(t, err) + defer client.Close() + + sess, err := client.NewSession() + require.NoError(t, err) + defer sess.Close() + + // reload tears the channel down before sending exit-status, so Output returns an error on the + // channel close. The output buffer still has whatever the reload callback wrote before that. + out, _ := sess.Output("reload") + return string(out) +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 680dddb3..14c8d499 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -156,11 +156,8 @@ func (t *winTun) addRoutes(logErrors bool) error { continue } - // Add our unsafe route - // Windows does not support multipath routes natively, so we install only a single route. - // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. - // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. - err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) + // Add our unsafe route as an on-link route to the nebula tun device. + err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -206,7 +203,7 @@ func (t *winTun) removeRoutes(routes []Route) error { } // See comment on luid.AddRoute - err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) + err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr)) if err != nil { t.l.Error("Failed to remove route", "error", err, "route", r) } else { @@ -261,6 +258,13 @@ func (t *winTun) Close() error { return t.tun.Close() } +func unspecifiedNextHop(p netip.Prefix) netip.Addr { + if p.Addr().Is4() { + return netip.IPv4Unspecified() + } + return netip.IPv6Unspecified() +} + func generateGUIDByDeviceName(name string) (*windows.GUID, error) { // GUID is 128 bit hash := crypto.MD5.New() diff --git a/sshd/server.go b/sshd/server.go index ff954bf5..86c52961 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -27,23 +27,20 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Call the cancel() function to stop all active sessions - ctx context.Context - cancel func() + // ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even + // across reloads, since each Run derives a fresh child rather than reusing this one directly. + ctx context.Context } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen. // The ssh server's context is parented off the supplied ctx so cancelling it // (e.g. on Control.Stop) tears down active sessions and closes the listener. func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) { - - ctx, cancel := context.WithCancel(ctx) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), ctx: ctx, - cancel: cancel, } cc := ssh.CertChecker{ @@ -153,45 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) { s.commands.Insert(c.Name, c) } -// Run begins listening and accepting connections +// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context +// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean +// rather than carrying a permanently-cancelled context across runs. func (s *SSHServer) Run(addr string) error { if s.ctx.Err() != nil { return s.ctx.Err() } - var err error - s.listener, err = net.Listen("tcp", addr) + listener, err := net.Listen("tcp", addr) if err != nil { return err } + // s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what + // this run owns. They start equal but a fast reload may overwrite s.listener with the next run's + // listener before this run's watcher fires, so each run must close its own listener via the local + // reference. + s.listener = listener - s.l.Info("SSH server is listening", "sshListener", addr) + runCtx, cancel := context.WithCancel(s.ctx) + defer cancel() - // Per-invocation watcher: cancellation of the parent context (e.g. - // Control.Stop) closes the listener so Accept unblocks and run returns. - // Closing `done` on exit keeps the watcher from outliving this Run call. - done := make(chan struct{}) + // Close the listener when this run's context is cancelled. That can come from the parent + // (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling + // run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate + // close from Stop is benign. go func() { - select { - case <-s.ctx.Done(): - s.Stop() - case <-done: + <-runCtx.Done() + if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + s.l.Warn("Failed to close the sshd listener", "error", err) } }() + s.l.Info("SSH server is listening", "sshListener", addr) + // Run loops until there is an error - s.run() - close(done) - s.closeSessions() + s.run(runCtx, listener) s.l.Info("SSH server stopped listening") // We don't return an error because run logs for us return nil } -func (s *SSHServer) run() { +func (s *SSHServer) run(ctx context.Context, listener net.Listener) { for { - c, err := s.listener.Accept() + c, err := listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { s.l.Warn("Error in listener, shutting down", "error", err) @@ -203,7 +206,7 @@ func (s *SSHServer) run() { // Ensure that a bad client doesn't hurt us by checking for the parent context // cancellation before calling NewServerConn, and forcing the socket to close when // the context is cancelled. - sessionContext, sessionCancel := context.WithCancel(s.ctx) + sessionContext, sessionCancel := context.WithCancel(ctx) go func() { <-sessionContext.Done() c.Close() @@ -246,14 +249,9 @@ func (s *SSHServer) run() { } func (s *SSHServer) Stop() { - // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { s.l.Warn("Failed to close the sshd listener", "error", err) } } } - -func (s *SSHServer) closeSessions() { - s.cancel() -} From 696903d6d91be3751a576779916cb7d5701140f2 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 20:17:38 -0500 Subject: [PATCH 17/31] Add a way to set the network type on windows + tests (#1710) --- .github/workflows/smoke-extra.yml | 49 +++ .github/workflows/smoke/smoke-windows.ps1 | 272 ++++++++++++++++ examples/config.yml | 26 ++ overlay/network_category_windows.go | 358 ++++++++++++++++++++ overlay/network_category_windows_test.go | 109 +++++++ overlay/tun_bypass_windows.go | 23 ++ overlay/tun_bypass_windows_386.go | 11 + overlay/tun_windows.go | 54 +++- udp/udp_android.go | 3 +- udp/udp_bsd.go | 3 +- udp/udp_bypass_windows.go | 57 ++++ udp/udp_bypass_windows_386.go | 11 + udp/udp_netbsd.go | 3 +- udp/udp_windows.go | 13 +- wfp/wfp_windows.go | 377 ++++++++++++++++++++++ 15 files changed, 1349 insertions(+), 20 deletions(-) create mode 100644 .github/workflows/smoke/smoke-windows.ps1 create mode 100644 overlay/network_category_windows.go create mode 100644 overlay/network_category_windows_test.go create mode 100644 overlay/tun_bypass_windows.go create mode 100644 overlay/tun_bypass_windows_386.go create mode 100644 udp/udp_bypass_windows.go create mode 100644 udp/udp_bypass_windows_386.go create mode 100644 wfp/wfp_windows.go diff --git a/.github/workflows/smoke-extra.yml b/.github/workflows/smoke-extra.yml index cca7678b..e0428e9c 100644 --- a/.github/workflows/smoke-extra.yml +++ b/.github/workflows/smoke-extra.yml @@ -81,3 +81,52 @@ jobs: run: make smoke-vagrant/linux-386 timeout-minutes: 30 + + smoke-windows: + if: github.ref == 'refs/heads/master' || contains(github.event.pull_request.labels.*.name, 'smoke-test-extra') + name: Run windows smoke test + runs-on: windows-latest + steps: + + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + # WSL2 + Ubuntu so the smoke can run a real linux peer with its own + # netns. iputils-ping is needed for the in-WSL ping check. WSL1 has no + # real kernel and would lack /dev/net/tun, so we have to force WSL2. + - uses: Vampire/setup-wsl@v3 + with: + distribution: Ubuntu-24.04 + additional-packages: iputils-ping iproute2 + + # Vampire/setup-wsl provisions WSL1 even when the WSL2 platform is present. + # Convert the distro to WSL2 explicitly before we try to use /dev/net/tun. + - name: convert distro to WSL2 + shell: pwsh + run: | + wsl --set-version Ubuntu-24.04 2 + wsl --shutdown + wsl --list --verbose + + - name: build windows nebula + run: make bin-windows + + - name: build linux nebula for WSL + shell: bash + env: + GOOS: linux + GOARCH: amd64 + run: | + mkdir -p build/linux-amd64 + go build -o build/linux-amd64/nebula ./cmd/nebula + + - name: run smoke-windows + shell: pwsh + working-directory: ./.github/workflows/smoke + run: ./smoke-windows.ps1 + + timeout-minutes: 15 diff --git a/.github/workflows/smoke/smoke-windows.ps1 b/.github/workflows/smoke/smoke-windows.ps1 new file mode 100644 index 00000000..0436598d --- /dev/null +++ b/.github/workflows/smoke/smoke-windows.ps1 @@ -0,0 +1,272 @@ +#!/usr/bin/env pwsh +# Windows smoke test for the nebula tun + UDP + NLM code paths. +# +# Topology: +# - lighthouse runs natively on the Windows host (wintun + windows UDP) +# - peer runs inside WSL2 (Linux build of nebula, /dev/net/tun) +# +# WSL2 gives us a real netns boundary so the loopback fast-path on Windows +# does not short-circuit the overlay -- when WSL pings the lighthouse VPN IP, +# Linux has no idea that IP is local to the Windows host, so the packet is +# forced through nebula. Same in reverse. + +$ErrorActionPreference = 'Stop' + +# wsl.exe emits UTF-16 LE by default which PowerShell reads as bytes, mangling +# every captured string. WSL_UTF8 makes wsl.exe emit UTF-8 instead. +$env:WSL_UTF8 = '1' + +$RepoRoot = Resolve-Path "$PSScriptRoot\..\..\.." +$Nebula = Join-Path $RepoRoot 'nebula.exe' +$NebulaCert = Join-Path $RepoRoot 'nebula-cert.exe' +$NebulaLinux = Join-Path $RepoRoot 'build\linux-amd64\nebula' + +if (-not (Test-Path $Nebula)) { throw "missing $Nebula; run 'make bin-windows' first" } +if (-not (Test-Path $NebulaCert)) { throw "missing $NebulaCert; run 'make bin-windows' first" } +if (-not (Test-Path $NebulaLinux)) { throw "missing $NebulaLinux; build the linux nebula first" } + +# Matches the distro installed by Vampire/setup-wsl in smoke-extra.yml. +$Distro = 'Ubuntu-24.04' +$listed = (wsl --list --quiet 2>$null) -join "`n" +if ($listed -notmatch [regex]::Escape($Distro)) { + throw "WSL distro $Distro not registered. Got: $listed" +} +Write-Host "Using WSL distro: $Distro" + +# Windows host as seen from inside WSL: WSL's default-route gateway. We extract +# it with a regex rather than awk fields so PowerShell does not eat any '$N' +# tokens, and tabs/double-spaces in `ip route` output do not confuse a cut. +$ipCmd = 'ip route show default | grep -oE "([0-9]+\.){3}[0-9]+" | head -1' +$WindowsIp = (wsl -d $Distro -- bash -c $ipCmd).Trim() +if (-not $WindowsIp) { throw "could not determine Windows host IP from WSL" } +Write-Host "Windows host IP from WSL: $WindowsIp" + +$WorkDir = Join-Path $env:TEMP 'nebula-smoke-windows' +if (Test-Path $WorkDir) { Remove-Item -Recurse -Force $WorkDir } +New-Item -ItemType Directory -Path $WorkDir | Out-Null + +$WslDir = '/tmp/nebula-smoke' +wsl -d $Distro -- bash -c "rm -rf $WslDir && mkdir -p $WslDir" | Out-Null + +$DevName = 'nebula-smoke' +$Ip1 = '192.168.241.1' +$Ip2 = '192.168.241.2' +$Port = 4242 + +& $NebulaCert ca -name 'smoke-ca' -out-crt "$WorkDir\ca.crt" -out-key "$WorkDir\ca.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert ca failed (exit $LASTEXITCODE)" } + +& $NebulaCert sign -name 'lighthouse' -networks "$Ip1/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\lighthouse.crt" -out-key "$WorkDir\lighthouse.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign lighthouse failed (exit $LASTEXITCODE)" } + +& $NebulaCert sign -name 'peer' -networks "$Ip2/24" -ca-crt "$WorkDir\ca.crt" -ca-key "$WorkDir\ca.key" -out-crt "$WorkDir\peer.crt" -out-key "$WorkDir\peer.key" +if ($LASTEXITCODE -ne 0) { throw "nebula-cert sign peer failed (exit $LASTEXITCODE)" } + +# Windows lighthouse config. +@" +pki: + ca: $WorkDir\ca.crt + cert: $WorkDir\lighthouse.crt + key: $WorkDir\lighthouse.key +static_host_map: {} +lighthouse: + am_lighthouse: true + interval: 60 + hosts: [] +listen: + host: 0.0.0.0 + port: $Port +tun: + disabled: false + dev: $DevName + drop_local_broadcast: false + drop_multicast: false + tx_queue: 500 + mtu: 1300 + network_category: private +logging: + level: info + format: text +firewall: + outbound_action: drop + inbound_action: drop + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +"@ | Out-File -FilePath "$WorkDir\lighthouse.yml" -Encoding utf8 + +# WSL peer config (paths are POSIX, deliberately). +@" +pki: + ca: $WslDir/ca.crt + cert: $WslDir/peer.crt + key: $WslDir/peer.key +static_host_map: + "${Ip1}": ["${WindowsIp}:$Port"] +lighthouse: + am_lighthouse: false + interval: 60 + hosts: + - "${Ip1}" +listen: + host: 0.0.0.0 + port: 0 +tun: + disabled: false + dev: nebula1 + drop_local_broadcast: false + drop_multicast: false + tx_queue: 500 + mtu: 1300 +logging: + level: info + format: text +firewall: + outbound_action: drop + inbound_action: drop + conntrack: + tcp_timeout: 12m + udp_timeout: 3m + default_timeout: 10m + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +"@ | Out-File -FilePath "$WorkDir\peer.yml" -Encoding utf8 + +# Stage WSL artifacts. Convert Windows paths to WSL paths ourselves rather than +# calling `wslpath`, because PowerShell's argument-passing to external EXEs +# strips backslashes from path arguments in ways that are hard to escape around. +function ConvertTo-WslPath { + param([string]$WindowsPath) + if ($WindowsPath -notmatch '^([A-Za-z]):\\(.*)$') { + throw "cannot convert path to WSL: $WindowsPath" + } + return "/mnt/$($matches[1].ToLower())/$($matches[2].Replace('\','/'))" +} + +$WslWorkDir = ConvertTo-WslPath $WorkDir +$WslNebulaPath = ConvertTo-WslPath $NebulaLinux +wsl -d $Distro -- bash -c "cp '$WslWorkDir/ca.crt' '$WslWorkDir/peer.crt' '$WslWorkDir/peer.key' '$WslWorkDir/peer.yml' $WslDir/ && cp '$WslNebulaPath' $WslDir/nebula && chmod +x $WslDir/nebula" + +# Make sure WSL has tun support and /dev/net/tun is usable before starting +# nebula. Diagnostics first so a fail here points at the real problem (e.g. +# WSL1 distros do not have a real kernel and will not have tun). +Write-Host '=== WSL diagnostic ===' +wsl --version 2>&1 | Out-Host +wsl --list --verbose 2>&1 | Out-Host +wsl -d $Distro -u root -- uname -a | Out-Host +wsl -d $Distro -u root -- bash -c "modprobe tun 2>&1 || true; mkdir -p /dev/net; [ -c /dev/net/tun ] || mknod /dev/net/tun c 10 200; chmod 600 /dev/net/tun; ls -l /dev/net/tun" +if ($LASTEXITCODE -ne 0) { throw "failed to prepare /dev/net/tun in WSL (TUN support missing?)" } + +# Deliberately no New-NetFirewallRule calls here -- nebula's windows_bypass_wdf +# feature is supposed to install WFP permit filters that let inbound traffic +# through Windows Defender Firewall on its own. If this smoke regresses, that +# feature regressed. + +$lhOut = Join-Path $WorkDir 'lighthouse.out.log' +$lhErr = Join-Path $WorkDir 'lighthouse.err.log' +$lhProc = Start-Process -FilePath $Nebula -ArgumentList @('-config', "$WorkDir\lighthouse.yml") ` + -PassThru -NoNewWindow ` + -RedirectStandardOutput $lhOut ` + -RedirectStandardError $lhErr + +# Run nebula in WSL as root with no sudo + no shell wrapper. PowerShell's +# Start-Process arg quoting mangles `bash -c "..."` strings that contain +# spaces/redirections, so we skip bash entirely and let Start-Process do the +# stdout/stderr capture itself. +$peerOut = Join-Path $WorkDir 'peer.out.log' +$peerErr = Join-Path $WorkDir 'peer.err.log' +$peerProc = Start-Process -FilePath 'wsl' ` + -ArgumentList @('-d', $Distro, '-u', 'root', '--', "$WslDir/nebula", '-config', "$WslDir/peer.yml") ` + -PassThru -NoNewWindow ` + -RedirectStandardOutput $peerOut ` + -RedirectStandardError $peerErr + +function Wait-Until { + param([scriptblock]$Predicate, [int]$TimeoutSec, [string]$What) + $deadline = (Get-Date).AddSeconds($TimeoutSec) + while ((Get-Date) -lt $deadline) { + if (& $Predicate) { return } + Start-Sleep -Milliseconds 500 + } + throw "timed out waiting for: $What" +} + +try { + Wait-Until -TimeoutSec 30 -What "windows wintun adapter $DevName with NetworkCategory=Private" -Predicate { + if ($lhProc.HasExited) { throw "lighthouse exited (code $($lhProc.ExitCode)) before tun was ready" } + $p = Get-NetConnectionProfile -InterfaceAlias $DevName -ErrorAction SilentlyContinue + $p -and ("$($p.NetworkCategory)" -ieq 'Private') + } + Write-Host "OK: $DevName NetworkCategory=Private" + + Wait-Until -TimeoutSec 30 -What "WSL nebula1 with $Ip2" -Predicate { + if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before tun was ready" } + $r = wsl -d $Distro -u root -- bash -c "ip -o addr show nebula1 2>/dev/null | grep -q 'inet $Ip2' && echo yes" + ("$r").Trim() -eq 'yes' + } + Write-Host "OK: WSL nebula1 has $Ip2" + + Wait-Until -TimeoutSec 30 -What "ping from WSL peer to windows lighthouse ($Ip1)" -Predicate { + if ($peerProc.HasExited) { throw "peer exited (code $($peerProc.ExitCode)) before ping succeeded" } + $r = wsl -d $Distro -u root -- bash -c "ping -c1 -W1 $Ip1 >/dev/null 2>&1 && echo OK" + ("$r").Trim() -eq 'OK' + } + Write-Host "OK: WSL peer -> windows lighthouse" + + Wait-Until -TimeoutSec 30 -What "ping from windows lighthouse to WSL peer ($Ip2)" -Predicate { + $null = & ping.exe -n 1 -w 1000 $Ip2 + $LASTEXITCODE -eq 0 + } + Write-Host "OK: windows lighthouse -> WSL peer" + + Write-Host '' + Write-Host 'All smoke checks passed.' +} +catch { + Write-Host '' + Write-Host '=== lighthouse stdout ===' + Get-Content $lhOut -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== lighthouse stderr ===' + Get-Content $lhErr -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== peer stdout ===' + Get-Content $peerOut -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== peer stderr ===' + Get-Content $peerErr -ErrorAction SilentlyContinue | Out-Host + Write-Host '=== nebula WFP filters ===' + # Dump nebula-installed filters so we can verify they got registered with + # the conditions we expect. + $wfpDump = Join-Path $WorkDir 'wfp.xml' + netsh wfp show filters file=$wfpDump 2>&1 | Out-Null + if (Test-Path $wfpDump) { + Select-String -Path $wfpDump -Pattern 'Nebula' -Context 0,80 -ErrorAction SilentlyContinue | Out-Host + } + throw +} +finally { + if (-not $lhProc.HasExited) { + Stop-Process -Id $lhProc.Id -Force -ErrorAction SilentlyContinue + $lhProc.WaitForExit(5000) | Out-Null + } + wsl -d $Distro -u root -- bash -c "pkill -f $WslDir/nebula 2>/dev/null; true" | Out-Null + # pkill returns 1 when no match and wsl propagates that; the smoke is done + # so we don't want it to leak into the script's exit code. + $global:LASTEXITCODE = 0 + if ($peerProc -and -not $peerProc.HasExited) { + Stop-Process -Id $peerProc.Id -Force -ErrorAction SilentlyContinue + } +} diff --git a/examples/config.yml b/examples/config.yml index ac4810e6..6c7fb489 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -138,6 +138,14 @@ listen: # max, net.core.rmem_max and net.core.wmem_max #read_buffer: 10485760 #write_buffer: 10485760 + + # On Windows only + # When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to UDP at the listener port. + # WFP sits below Windows Defender Firewall, so this lets peer handshakes reach Nebula's outside socket regardless + # of WDF's inbound rules. + # Default true; set to false to leave WDF in charge of inbound decisions on the listener port. Not reloadable. + #windows_bypass_wdf: true + # By default, Nebula replies to packets it has no tunnel for with a "recv_error" packet. This packet helps speed up reconnection # in the case that Nebula on either side did not shut down cleanly. This response can be abused as a way to discover if Nebula is running # on a host though. This option lets you configure if you want to send "recv_error" packets always, never, or only to private network remotes. @@ -286,6 +294,24 @@ tun: # metric: 100 # install: true + # On Windows only, sets the network category of the nebula interface. Without this, Windows often + # leaves the network as "Unidentified" and treats it as Public, which makes the host firewall more + # restrictive than you usually want for an overlay between trusted peers. Valid values: + # private - treat the nebula network as a private/trusted network (default) + # public - treat it as a public/untrusted network + # domain - treat it as a domain-authenticated network + # unset - leave whatever Windows decided alone + # Not reloadable. + #network_category: private + + # On Windows only + # When true, Nebula installs a WFP (Windows Filtering Platform) PERMIT filter scoped to the nebula adapter LUID. + # WFP sits below Windows Defender Firewall, so this lets inbound traffic through regardless of WDF rules. + # Filters are auto-removed when the adapter goes away. + # See listen.windows_bypass_wdf for the matching control over inbound to nebula's outside UDP listener. + # Default true; set to false to leave WDF in charge of inbound decisions on the nebula interface. Not reloadable. + #windows_bypass_wdf: true + # On linux only, set to true to manage unsafe routes directly on the system route table with gateway routes instead of # in nebula configuration files. Default false, not reloadable. #use_system_route_table: false diff --git a/overlay/network_category_windows.go b/overlay/network_category_windows.go new file mode 100644 index 00000000..cbf87f00 --- /dev/null +++ b/overlay/network_category_windows.go @@ -0,0 +1,358 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "errors" + "fmt" + "log/slog" + "runtime" + "strings" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +// networkCategory mirrors NLM_NETWORK_CATEGORY from netlistmgr.h. +type networkCategory int32 + +const ( + networkCategoryPublic networkCategory = 0 + networkCategoryPrivate networkCategory = 1 + networkCategoryDomainAuthenticated networkCategory = 2 +) + +func (c networkCategory) String() string { + switch c { + case networkCategoryPublic: + return "public" + case networkCategoryPrivate: + return "private" + case networkCategoryDomainAuthenticated: + return "domain" + } + return fmt.Sprintf("unknown(%d)", c) +} + +// parseNetworkCategory accepts the user-supplied tun.network_category. A +// second return of false means "leave the category alone". +func parseNetworkCategory(s string) (networkCategory, bool, error) { + switch strings.ToLower(strings.TrimSpace(s)) { + case "", "unset": + return 0, false, nil + case "public": + return networkCategoryPublic, true, nil + case "private": + return networkCategoryPrivate, true, nil + case "domain", "domainauthenticated": + return networkCategoryDomainAuthenticated, true, nil + } + return 0, false, fmt.Errorf("unknown tun.network_category %q (expected public, private, domain, or unset)", s) +} + +// CLSID_NetworkListManager {DCB00C01-570F-4A9B-8D69-199FDBA5723B} +var clsidNetworkListManager = windows.GUID{ + Data1: 0xDCB00C01, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// IID_INetworkListManager {DCB00000-570F-4A9B-8D69-199FDBA5723B} +var iidINetworkListManager = windows.GUID{ + Data1: 0xDCB00000, Data2: 0x570F, Data3: 0x4A9B, + Data4: [8]byte{0x8D, 0x69, 0x19, 0x9F, 0xDB, 0xA5, 0x72, 0x3B}, +} + +// x/sys/windows doesn't expose CoCreateInstance, so we bind it ourselves. +var procCoCreateInstance = windows.NewLazySystemDLL("ole32.dll").NewProc("CoCreateInstance") + +const clsCtxAll = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER | + windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER + +const ( + hrSFALSE = 0x00000001 + hrRPCEChangedMode = 0x80010106 +) + +type hresult uint32 + +func (h hresult) failed() bool { return int32(h) < 0 } +func (h hresult) String() string { + return fmt.Sprintf("HRESULT 0x%08x", uint32(h)) +} + +var errAdapterNotFound = errors.New("adapter not present in network connections enumeration") + +// Vtable layouts. Slot order must match the declaration order in netlistmgr.h. +// All NLM interfaces here derive from IDispatch, which derives from IUnknown. + +type iUnknownVtbl struct { + QueryInterface uintptr + AddRef uintptr + Release uintptr +} + +type iDispatchVtbl struct { + iUnknownVtbl + GetTypeInfoCount uintptr + GetTypeInfo uintptr + GetIDsOfNames uintptr + Invoke uintptr +} + +type iNetworkListManagerVtbl struct { + iDispatchVtbl + GetNetworks uintptr + GetNetwork uintptr + GetNetworkConnections uintptr + GetNetworkConnection uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr +} + +type iNetworkListManager struct{ Vtbl *iNetworkListManagerVtbl } + +func (n *iNetworkListManager) Release() { + syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n))) +} + +func (n *iNetworkListManager) GetNetworkConnections() (*iEnumNetworkConnections, error) { + var enum *iEnumNetworkConnections + r1, _, _ := syscall.SyscallN(n.Vtbl.GetNetworkConnections, + uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&enum)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("INetworkListManager.GetNetworkConnections: %s", hr) + } + return enum, nil +} + +type iEnumNetworkConnectionsVtbl struct { + iDispatchVtbl + NewEnum uintptr + Next uintptr + Skip uintptr + Reset uintptr + Clone uintptr +} + +type iEnumNetworkConnections struct{ Vtbl *iEnumNetworkConnectionsVtbl } + +func (e *iEnumNetworkConnections) Release() { + syscall.SyscallN(e.Vtbl.Release, uintptr(unsafe.Pointer(e))) +} + +// Next returns the next connection, or (nil, nil) at the end of the enumeration. +func (e *iEnumNetworkConnections) Next() (*iNetworkConnection, error) { + var conn *iNetworkConnection + var fetched uint32 + r1, _, _ := syscall.SyscallN(e.Vtbl.Next, + uintptr(unsafe.Pointer(e)), 1, + uintptr(unsafe.Pointer(&conn)), uintptr(unsafe.Pointer(&fetched)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("IEnumNetworkConnections.Next: %s", hr) + } + if fetched == 0 { + return nil, nil + } + return conn, nil +} + +type iNetworkConnectionVtbl struct { + iDispatchVtbl + GetNetwork uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetConnectionId uintptr + GetAdapterId uintptr + GetDomainType uintptr +} + +type iNetworkConnection struct{ Vtbl *iNetworkConnectionVtbl } + +func (c *iNetworkConnection) Release() { + syscall.SyscallN(c.Vtbl.Release, uintptr(unsafe.Pointer(c))) +} + +func (c *iNetworkConnection) GetAdapterId() (windows.GUID, error) { + var g windows.GUID + r1, _, _ := syscall.SyscallN(c.Vtbl.GetAdapterId, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&g)), + ) + if hr := hresult(r1); hr.failed() { + return windows.GUID{}, fmt.Errorf("INetworkConnection.GetAdapterId: %s", hr) + } + return g, nil +} + +func (c *iNetworkConnection) GetNetwork() (*iNetwork, error) { + var net *iNetwork + r1, _, _ := syscall.SyscallN(c.Vtbl.GetNetwork, + uintptr(unsafe.Pointer(c)), uintptr(unsafe.Pointer(&net)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("INetworkConnection.GetNetwork: %s", hr) + } + return net, nil +} + +type iNetworkVtbl struct { + iDispatchVtbl + GetName uintptr + SetName uintptr + GetDescription uintptr + SetDescription uintptr + GetNetworkId uintptr + GetDomainType uintptr + GetNetworkConnections uintptr + GetTimeCreatedAndConnected uintptr + IsConnectedToInternet uintptr + IsConnected uintptr + GetConnectivity uintptr + GetCategory uintptr + SetCategory uintptr +} + +type iNetwork struct{ Vtbl *iNetworkVtbl } + +func (n *iNetwork) Release() { + syscall.SyscallN(n.Vtbl.Release, uintptr(unsafe.Pointer(n))) +} + +func (n *iNetwork) GetCategory() (networkCategory, error) { + var c networkCategory + r1, _, _ := syscall.SyscallN(n.Vtbl.GetCategory, + uintptr(unsafe.Pointer(n)), uintptr(unsafe.Pointer(&c)), + ) + if hr := hresult(r1); hr.failed() { + return 0, fmt.Errorf("INetwork.GetCategory: %s", hr) + } + return c, nil +} + +func (n *iNetwork) SetCategory(c networkCategory) error { + r1, _, _ := syscall.SyscallN(n.Vtbl.SetCategory, + uintptr(unsafe.Pointer(n)), uintptr(int32(c)), + ) + if hr := hresult(r1); hr.failed() { + return fmt.Errorf("INetwork.SetCategory: %s", hr) + } + return nil +} + +// coInit initializes COM for the current OS thread. The returned function must +// be deferred to balance a successful init. RPC_E_CHANGED_MODE means COM is +// already initialized in a different mode on this thread, which is still fine +// for our calls but we must not Uninitialize in that case. +func coInit() (func(), error) { + err := windows.CoInitializeEx(0, windows.COINIT_MULTITHREADED) + if err == nil { + return windows.CoUninitialize, nil + } + if e, ok := err.(syscall.Errno); ok { + switch uint32(e) { + case hrSFALSE: + return windows.CoUninitialize, nil + case hrRPCEChangedMode: + return func() {}, nil + } + } + return nil, fmt.Errorf("CoInitializeEx: %w", err) +} + +func createNetworkListManager() (*iNetworkListManager, error) { + var nlm *iNetworkListManager + r1, _, _ := procCoCreateInstance.Call( + uintptr(unsafe.Pointer(&clsidNetworkListManager)), + 0, + uintptr(clsCtxAll), + uintptr(unsafe.Pointer(&iidINetworkListManager)), + uintptr(unsafe.Pointer(&nlm)), + ) + if hr := hresult(r1); hr.failed() { + return nil, fmt.Errorf("CoCreateInstance(NetworkListManager): %s", hr) + } + return nlm, nil +} + +// setNetworkCategory locates the network connection bound to adapterGUID and +// sets the category of its parent network. Returns errAdapterNotFound if the +// adapter is not yet visible in the NLM enumeration. +func setNetworkCategory(adapterGUID windows.GUID, cat networkCategory) error { + deinit, err := coInit() + if err != nil { + return err + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + return err + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + return err + } + defer enum.Release() + + for { + conn, err := enum.Next() + if err != nil { + return err + } + if conn == nil { + return errAdapterNotFound + } + + guid, err := conn.GetAdapterId() + if err != nil || guid != adapterGUID { + conn.Release() + continue + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + return err + } + err = net.SetCategory(cat) + net.Release() + return err + } +} + +// applyNetworkCategory polls until the wintun adapter shows up in the NLM +// enumeration, then sets the category. Intended to run in its own goroutine. +func applyNetworkCategory(l *slog.Logger, adapterGUID windows.GUID, cat networkCategory) { + // COM Init/Uninit must be paired on the same OS thread. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + const ( + attempts = 30 + interval = 500 * time.Millisecond + ) + for i := 0; i < attempts; i++ { + err := setNetworkCategory(adapterGUID, cat) + if err == nil { + l.Info("Set Windows network category", "category", cat.String()) + return + } + if !errors.Is(err, errAdapterNotFound) { + l.Warn("Failed to set Windows network category", "error", err, "category", cat.String()) + return + } + time.Sleep(interval) + } + l.Warn("Gave up waiting for adapter to appear in NLM enumeration; network category not set", + "category", cat.String(), + "waited", time.Duration(attempts)*interval, + ) +} diff --git a/overlay/network_category_windows_test.go b/overlay/network_category_windows_test.go new file mode 100644 index 00000000..c679f8c4 --- /dev/null +++ b/overlay/network_category_windows_test.go @@ -0,0 +1,109 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import ( + "testing" +) + +func Test_parseNetworkCategory(t *testing.T) { + cases := []struct { + in string + wantCat networkCategory + wantApply bool + wantErr bool + }{ + {"", 0, false, false}, + {"unset", 0, false, false}, + {" UNSET ", 0, false, false}, + {"private", networkCategoryPrivate, true, false}, + {"Private", networkCategoryPrivate, true, false}, + {" PRIVATE ", networkCategoryPrivate, true, false}, + {"public", networkCategoryPublic, true, false}, + {"PUBLIC", networkCategoryPublic, true, false}, + {"domain", networkCategoryDomainAuthenticated, true, false}, + {"DomainAuthenticated", networkCategoryDomainAuthenticated, true, false}, + {"garbage", 0, false, true}, + {"privates", 0, false, true}, + } + for _, tc := range cases { + cat, apply, err := parseNetworkCategory(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("parseNetworkCategory(%q) err=%v, wantErr=%v", tc.in, err, tc.wantErr) + continue + } + if cat != tc.wantCat || apply != tc.wantApply { + t.Errorf("parseNetworkCategory(%q) = (%v, %v), want (%v, %v)", tc.in, cat, apply, tc.wantCat, tc.wantApply) + } + } +} + +// Test_NLM_round_trip exercises every COM call path used by setNetworkCategory +// without mutating the host's network state. It validates the CLSID/IID +// constants and every vtable index by enumerating connections, fetching the +// adapter id and parent network, reading the current category, and writing it +// back unchanged. +// +// Requires Windows but does not require admin or the wintun driver. Skips if +// no network connections are available (unlikely outside of an isolated +// container). +func Test_NLM_round_trip(t *testing.T) { + deinit, err := coInit() + if err != nil { + t.Fatalf("coInit: %v", err) + } + defer deinit() + + nlm, err := createNetworkListManager() + if err != nil { + t.Fatalf("createNetworkListManager: %v", err) + } + defer nlm.Release() + + enum, err := nlm.GetNetworkConnections() + if err != nil { + t.Fatalf("GetNetworkConnections: %v", err) + } + defer enum.Release() + + saw := 0 + for { + conn, err := enum.Next() + if err != nil { + t.Fatalf("EnumNetworkConnections.Next: %v", err) + } + if conn == nil { + break + } + saw++ + + if _, err := conn.GetAdapterId(); err != nil { + conn.Release() + t.Fatalf("INetworkConnection.GetAdapterId: %v", err) + } + + net, err := conn.GetNetwork() + conn.Release() + if err != nil { + t.Fatalf("INetworkConnection.GetNetwork: %v", err) + } + + cat, err := net.GetCategory() + if err != nil { + net.Release() + t.Fatalf("INetwork.GetCategory: %v", err) + } + // Set to the current value so the host's NLM state is unchanged but + // SetCategory's vtable slot is still validated end-to-end. + if err := net.SetCategory(cat); err != nil { + net.Release() + t.Fatalf("INetwork.SetCategory(%v): %v", cat, err) + } + net.Release() + } + + if saw == 0 { + t.Skip("no NLM network connections available; skipping round-trip") + } +} diff --git a/overlay/tun_bypass_windows.go b/overlay/tun_bypass_windows.go new file mode 100644 index 00000000..1f62373c --- /dev/null +++ b/overlay/tun_bypass_windows.go @@ -0,0 +1,23 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +package overlay + +import ( + "log/slog" + + "github.com/slackhq/nebula/wfp" +) + +// installInterfaceBypass installs a WFP PERMIT filter scoped to the wintun interface LUID so inbound traffic on the +// nebula adapter bypasses Windows Defender Firewall. +func installInterfaceBypass(l *slog.Logger, luid uint64) closer { + s, err := wfp.PermitInterface(luid) + if err != nil { + l.Warn("Failed to install WFP bypass filters on nebula interface", "error", err) + return nil + } + l.Info("Installed WFP filters bypassing Windows Defender Firewall on nebula interface") + return s +} diff --git a/overlay/tun_bypass_windows_386.go b/overlay/tun_bypass_windows_386.go new file mode 100644 index 00000000..366430b0 --- /dev/null +++ b/overlay/tun_bypass_windows_386.go @@ -0,0 +1,11 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package overlay + +import "log/slog" + +// installInterfaceBypass is a no-op on windows-386 because we don't currently build for it. +func installInterfaceBypass(_ *slog.Logger, _ uint64) closer { + return nil +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 14c8d499..cf01615f 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -25,15 +25,24 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) +type closer interface { + Close() +} + const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - vpnNetworks []netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[routing.Gateways]] - l *slog.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[routing.Gateways]] + guid windows.GUID + networkCategory networkCategory + setCategory bool + bypassWDF bool + wdfBypass closer + l *slog.Logger tun *wintun.NativeTun } @@ -54,11 +63,20 @@ func newTun(c *config.C, l *slog.Logger, vpnNetworks []netip.Prefix, _ bool) (*w return nil, fmt.Errorf("generate GUID failed: %w", err) } + cat, setCat, err := parseNetworkCategory(c.GetString("tun.network_category", "private")) + if err != nil { + return nil, err + } + t := &winTun{ - Device: deviceName, - vpnNetworks: vpnNetworks, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + guid: *guid, + networkCategory: cat, + setCategory: setCat, + bypassWDF: c.GetBool("tun.windows_bypass_wdf", true), + l: l, } err = t.reload(c, true) @@ -142,6 +160,17 @@ func (t *winTun) Activate() error { return err } + if t.setCategory { + // The wintun adapter takes a moment to register with the Network List + // Manager, so we apply the category in the background and retry until + // it shows up. + go applyNetworkCategory(t.l, t.guid, t.networkCategory) + } + + if t.bypassWDF { + t.wdfBypass = installInterfaceBypass(t.l, uint64(t.tun.LUID())) + } + return nil } @@ -255,6 +284,11 @@ func (t *winTun) Close() error { _ = luid.FlushDNS(windows.AF_INET) _ = luid.FlushDNS(windows.AF_INET6) + if t.wdfBypass != nil { + t.wdfBypass.Close() + t.wdfBypass = nil + } + return t.tun.Close() } diff --git a/udp/udp_android.go b/udp/udp_android.go index 3fc68003..213ab422 100644 --- a/udp/udp_android.go +++ b/udp/udp_android.go @@ -5,12 +5,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_bsd.go b/udp/udp_bsd.go index c42a3c18..31ae9c5a 100644 --- a/udp/udp_bsd.go +++ b/udp/udp_bsd.go @@ -8,12 +8,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_bypass_windows.go b/udp/udp_bypass_windows.go new file mode 100644 index 00000000..b8b06b1e --- /dev/null +++ b/udp/udp_bypass_windows.go @@ -0,0 +1,57 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +package udp + +import ( + "log/slog" + "sync" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/wfp" +) + +// wrapWithWDFBypass wraps a Conn so that the first ReloadConfig consults listen.windows_bypass_wdf +// and installs a WFP PERMIT filter for the listener's bound UDP port. The session is released when Close runs. +func wrapWithWDFBypass(l *slog.Logger, conn Conn) Conn { + return &bypassConn{Conn: conn, l: l} +} + +type bypassConn struct { + Conn + + l *slog.Logger + installOnce sync.Once + session *wfp.Session +} + +func (b *bypassConn) ReloadConfig(c *config.C) { + b.installOnce.Do(func() { + if !c.GetBool("listen.windows_bypass_wdf", true) { + return + } + addr, err := b.Conn.LocalAddr() + if err != nil { + b.l.Warn("Failed to query listener port for WFP bypass", "error", err) + return + } + s, err := wfp.PermitUDPPort(addr.Port()) + if err != nil { + b.l.Warn("Failed to install WFP bypass filters for listener", "error", err) + return + } + b.l.Info("Installed WFP filters bypassing Windows Defender Firewall on UDP listener port", + "port", addr.Port()) + b.session = s + }) + b.Conn.ReloadConfig(c) +} + +func (b *bypassConn) Close() error { + if b.session != nil { + b.session.Close() + b.session = nil + } + return b.Conn.Close() +} diff --git a/udp/udp_bypass_windows_386.go b/udp/udp_bypass_windows_386.go new file mode 100644 index 00000000..fa5a6eec --- /dev/null +++ b/udp/udp_bypass_windows_386.go @@ -0,0 +1,11 @@ +//go:build !e2e_testing +// +build !e2e_testing + +package udp + +import "log/slog" + +// wrapWithWDFBypass is a no-op on windows-386 since we don't currently build for it. +func wrapWithWDFBypass(_ *slog.Logger, conn Conn) Conn { + return conn +} diff --git a/udp/udp_netbsd.go b/udp/udp_netbsd.go index 4b2de75a..b0c81393 100644 --- a/udp/udp_netbsd.go +++ b/udp/udp_netbsd.go @@ -7,12 +7,11 @@ package udp import ( "fmt" + "log/slog" "net" "net/netip" "syscall" - "log/slog" - "golang.org/x/sys/unix" ) diff --git a/udp/udp_windows.go b/udp/udp_windows.go index 7969f7e8..1f34f0bc 100644 --- a/udp/udp_windows.go +++ b/udp/udp_windows.go @@ -19,13 +19,18 @@ func NewListener(l *slog.Logger, ip netip.Addr, port int, multi bool, batch int) return nil, fmt.Errorf("multiple udp listeners not supported on windows") } + var conn Conn rc, err := NewRIOListener(l, ip, port) if err == nil { - return rc, nil + conn = rc + } else { + l.Error("Falling back to standard udp sockets", "error", err) + conn, err = NewGenericListener(l, ip, port, multi, batch) + if err != nil { + return nil, err + } } - - l.Error("Falling back to standard udp sockets", "error", err) - return NewGenericListener(l, ip, port, multi, batch) + return wrapWithWDFBypass(l, conn), nil } func NewListenConfig(multi bool) net.ListenConfig { diff --git a/wfp/wfp_windows.go b/wfp/wfp_windows.go new file mode 100644 index 00000000..22aa0565 --- /dev/null +++ b/wfp/wfp_windows.go @@ -0,0 +1,377 @@ +//go:build (amd64 || arm64) && !e2e_testing +// +build amd64 arm64 +// +build !e2e_testing + +// Package wfp installs Windows Filtering Platform (WFP) PERMIT filters in a dynamic, session-scoped sublayer. +// Because WFP sits below Windows Defender Firewall, a high-weight permit at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4/V6 lets +// the matching inbound traffic through regardless of WDF rules. +// +// Each Session owns its own engine handle. When the handle closes, every dynamic object added during the session +// is auto-deleted by Windows, so there are no orphaned filters. +// +// Type definitions and constants are derived from the wireguard-windows firewall package (MIT). +// Only the subset we exercise is reproduced. +package wfp + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/windows" +) + +// FWPM layer GUIDs (fwpmu.h). +// +// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = e1cd9fe7-f4b5-4273-96c0-592e487b8650 +// FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = a3b42c97-9f04-4672-b87e-cee9c483257f +var ( + fwpmLayerAleAuthRecvAcceptV4 = windows.GUID{ + Data1: 0xe1cd9fe7, Data2: 0xf4b5, Data3: 0x4273, + Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50}, + } + fwpmLayerAleAuthRecvAcceptV6 = windows.GUID{ + Data1: 0xa3b42c97, Data2: 0x9f04, Data3: 0x4672, + Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f}, + } +) + +// FWPM_CONDITION_IP_LOCAL_INTERFACE = 4cd62a49-59c3-4969-b7f3-bda5d32890a4 +var fwpmConditionIPLocalInterface = windows.GUID{ + Data1: 0x4cd62a49, Data2: 0x59c3, Data3: 0x4969, + Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4}, +} + +// FWPM_CONDITION_IP_PROTOCOL = 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7 +var fwpmConditionIPProtocol = windows.GUID{ + Data1: 0x3971ef2b, Data2: 0x623e, Data3: 0x4f9a, + Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7}, +} + +// FWPM_CONDITION_IP_LOCAL_PORT = 0c1ba1af-5765-453f-af22-a8f791ac775b +var fwpmConditionIPLocalPort = windows.GUID{ + Data1: 0x0c1ba1af, Data2: 0x5765, Data3: 0x453f, + Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b}, +} + +// IPPROTO_UDP from in.h. +const ipprotoUDP uint8 = 17 + +// FWP_ACTION_TYPE values (fwptypes.h). PERMIT is terminating. +const fwpActionPermit uint32 = 0x00001002 // 0x2 | FWP_ACTION_FLAG_TERMINATING(0x1000) + +// FWP_DATA_TYPE values we use. +const ( + fwpEmpty uint32 = 0 + fwpUint8 uint32 = 1 + fwpUint16 uint32 = 2 + fwpUint64 uint32 = 4 +) + +// FWP_MATCH_TYPE values. +const fwpMatchEqual uint32 = 0 + +// FWPM_SESSION flags. +const fwpmSessionFlagDynamic uint32 = 0x1 + +// FWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT prevents lower-priority filters in other sublayers, +// notably Windows Defender Firewall's MPSSVC_WF sublayer, which shares our 0xFFFF weight from overriding this PERMIT. +// Without it, a default WDF block at the same sublayer weight can still win arbitration. +const fwpmFilterFlagClearActionRight uint32 = 0x8 + +// RPC authentication. +// RPC_C_AUTHN_WINNT works on workgroup machines with no domain context +// RPC_C_AUTHN_DEFAULT falls back through a chain that can land on something WFP doesn't accept on a fresh box. +const rpcCAuthnWinNT uint32 = 10 + +// fwpByteBlob (FWP_BYTE_BLOB). 16 bytes on 64-bit. +type fwpByteBlob struct { + size uint32 + _ uint32 // padding + data *uint8 +} + +// fwpValue0 / FWP_CONDITION_VALUE0 layout. 16 bytes on 64-bit. +// The union is pointer-sized; types <= 32 bits (UINT8/16/32, INT8/16/32, float) live inline in the low bytes +// of `value`, while UINT64/INT64/double and aggregate types are stored *by pointer*, even on 64-bit, where the +// union member is declared as UINT64*. So when populating an FWP_UINT64 condition, pass +// uintptr(unsafe.Pointer(&luidVar)) instead of the LUID inline. +type fwpValue0 struct { + type_ uint32 + _ uint32 // padding before union to 8-byte alignment + value uintptr +} + +// fwpmDisplayData0 / FWPM_DISPLAY_DATA0. 16 bytes on 64-bit. +type fwpmDisplayData0 struct { + name *uint16 + description *uint16 +} + +// fwpmAction0 / FWPM_ACTION0. 20 bytes; no leading padding because actionType +// is uint32 and GUID's first field is uint32. +type fwpmAction0 struct { + actionType uint32 + filterType windows.GUID +} + +// fwpmFilterCondition0. 40 bytes on 64-bit. +type fwpmFilterCondition0 struct { + fieldKey windows.GUID // 16 + matchType uint32 // 4 + _ uint32 // 4 padding + conditionValue fwpValue0 // 16 +} + +// fwpmFilter0. 200 bytes on 64-bit. +type fwpmFilter0 struct { + filterKey windows.GUID + displayData fwpmDisplayData0 + flags uint32 + _ uint32 // padding before *GUID + providerKey *windows.GUID + providerData fwpByteBlob + layerKey windows.GUID + subLayerKey windows.GUID + weight fwpValue0 + numFilterConditions uint32 + _ uint32 // padding before pointer + filterCondition *fwpmFilterCondition0 + action fwpmAction0 + _ [4]byte // layout correction + providerContextKey windows.GUID + reserved *windows.GUID + filterID uint64 + effectiveWeight fwpValue0 +} + +// fwpmSublayer0. 72 bytes on 64-bit. +type fwpmSublayer0 struct { + subLayerKey windows.GUID + displayData fwpmDisplayData0 + flags uint32 + _ uint32 // padding before *GUID + providerKey *windows.GUID + providerData fwpByteBlob + weight uint16 + _ [6]byte // padding to 72 bytes +} + +// fwpmSession0. 72 bytes on 64-bit. +type fwpmSession0 struct { + sessionKey windows.GUID + displayData fwpmDisplayData0 + flags uint32 + txnWaitTimeoutInMSec uint32 + processId uint32 + _ uint32 // padding before *SID + sid *windows.SID + username *uint16 + kernelMode uint8 + _ [7]byte // tail padding +} + +// fwpuclnt.dll bindings. Only the calls we use. +var ( + modFwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll") + procFwpmEngineOpen0 = modFwpuclnt.NewProc("FwpmEngineOpen0") + procFwpmEngineClose0 = modFwpuclnt.NewProc("FwpmEngineClose0") + procFwpmSubLayerAdd0 = modFwpuclnt.NewProc("FwpmSubLayerAdd0") + procFwpmFilterAdd0 = modFwpuclnt.NewProc("FwpmFilterAdd0") +) + +// Session holds the WFP engine handle for a single bypass operation. The handle owns a dynamic session: +// when it is closed, every WFP object added during the session (sublayer + filters) is automatically deleted by +// Windows. That gives us correct cleanup even if the host process is killed hard between Permit* and Close. +type Session struct { + engine uintptr +} + +// Close releases the engine handle. Windows deletes every dynamic object (sublayer + filters) the session installed. +// Safe to call on a nil receiver. +func (s *Session) Close() { + if s == nil || s.engine == 0 { + return + } + procFwpmEngineClose0.Call(s.engine) + s.engine = 0 +} + +// PermitInterface installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to the given network +// interface LUID. Inbound traffic on that interface bypasses Windows Defender Firewall. +func PermitInterface(luid uint64) (*Session, error) { + s, sublayerKey, err := newSession() + if err != nil { + return nil, err + } + + if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, luid); err != nil { + s.Close() + return nil, fmt.Errorf("add v4 filter: %w", err) + } + if err := addInterfaceFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, luid); err != nil { + s.Close() + return nil, fmt.Errorf("add v6 filter: %w", err) + } + return s, nil +} + +// PermitUDPPort installs PERMIT filters at FWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 and _V6 scoped to UDP traffic with the +// given local port. Inbound UDP to that port on any interface bypasses Windows Defender Firewall. +func PermitUDPPort(port uint16) (*Session, error) { + s, sublayerKey, err := newSession() + if err != nil { + return nil, err + } + + if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV4, port); err != nil { + s.Close() + return nil, fmt.Errorf("add v4 filter: %w", err) + } + if err := addUDPPortFilter(s.engine, sublayerKey, fwpmLayerAleAuthRecvAcceptV6, port); err != nil { + s.Close() + return nil, fmt.Errorf("add v6 filter: %w", err) + } + return s, nil +} + +func newSession() (*Session, windows.GUID, error) { + engine, err := openDynamicEngine() + if err != nil { + return nil, windows.GUID{}, err + } + sublayerKey, err := registerSublayer(engine) + if err != nil { + procFwpmEngineClose0.Call(engine) + return nil, windows.GUID{}, err + } + return &Session{engine: engine}, sublayerKey, nil +} + +func openDynamicEngine() (uintptr, error) { + session := fwpmSession0{flags: fwpmSessionFlagDynamic} + var engine uintptr + r1, _, _ := procFwpmEngineOpen0.Call( + 0, // serverName == NULL (local) + uintptr(rpcCAuthnWinNT), + 0, // authIdentity == NULL + uintptr(unsafe.Pointer(&session)), + uintptr(unsafe.Pointer(&engine)), + ) + if r1 != 0 { + return 0, fmt.Errorf("FwpmEngineOpen0: 0x%x", r1) + } + return engine, nil +} + +// registerSublayer adds a session-scoped sublayer with a freshly generated GUID, weight 0xFFFF so its filters arbitrate +// above WDF's default sublayer. The sublayer is dynamic (no PERSISTENT flag) and goes away when the engine handle closes. +func registerSublayer(engine uintptr) (windows.GUID, error) { + key, err := windows.GenerateGUID() + if err != nil { + return windows.GUID{}, fmt.Errorf("GenerateGUID for sublayer: %w", err) + } + + name, _ := windows.UTF16PtrFromString("Nebula WDF bypass sublayer") + desc, _ := windows.UTF16PtrFromString("Permit filters bypassing Windows Defender Firewall") + sl := fwpmSublayer0{ + subLayerKey: key, + displayData: fwpmDisplayData0{name: name, description: desc}, + weight: 0xFFFF, + } + r1, _, _ := procFwpmSubLayerAdd0.Call( + engine, + uintptr(unsafe.Pointer(&sl)), + 0, // sd == NULL + ) + if r1 != 0 { + return windows.GUID{}, fmt.Errorf("FwpmSubLayerAdd0: 0x%x", r1) + } + return key, nil +} + +func addInterfaceFilter(engine uintptr, sublayerKey, layer windows.GUID, luid uint64) error { + name, _ := windows.UTF16PtrFromString("Nebula allow interface inbound") + desc, _ := windows.UTF16PtrFromString("Permits inbound traffic on a nebula interface") + + // luid must remain addressable through the syscall -- FWP_UINT64 is stored + // by pointer in the FWP_VALUE0 union. + cond := fwpmFilterCondition0{ + fieldKey: fwpmConditionIPLocalInterface, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint64, + value: uintptr(unsafe.Pointer(&luid)), + }, + } + + filter := fwpmFilter0{ + // filterKey left zero: WFP assigns one when the filter is added. + displayData: fwpmDisplayData0{name: name, description: desc}, + flags: fwpmFilterFlagClearActionRight, + layerKey: layer, + subLayerKey: sublayerKey, + weight: fwpValue0{type_: fwpUint8, value: uintptr(15)}, + numFilterConditions: 1, + filterCondition: &cond, + action: fwpmAction0{actionType: fwpActionPermit}, + } + + r1, _, _ := procFwpmFilterAdd0.Call( + engine, + uintptr(unsafe.Pointer(&filter)), + 0, // sd == NULL + 0, // id == NULL + ) + if r1 != 0 { + return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1) + } + return nil +} + +// addUDPPortFilter installs a PERMIT filter that matches (IP_PROTOCOL == UDP) AND (IP_LOCAL_PORT == port). +// FWP_UINT8 and FWP_UINT16 are <= 32 bits so they live inline in the FWP_VALUE0 union. +func addUDPPortFilter(engine uintptr, sublayerKey, layer windows.GUID, port uint16) error { + name, _ := windows.UTF16PtrFromString("Nebula allow UDP port inbound") + desc, _ := windows.UTF16PtrFromString("Permits inbound UDP to a nebula listener port") + + conds := [2]fwpmFilterCondition0{ + { + fieldKey: fwpmConditionIPProtocol, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint8, + value: uintptr(ipprotoUDP), + }, + }, + { + fieldKey: fwpmConditionIPLocalPort, + matchType: fwpMatchEqual, + conditionValue: fwpValue0{ + type_: fwpUint16, + value: uintptr(port), + }, + }, + } + + filter := fwpmFilter0{ + displayData: fwpmDisplayData0{name: name, description: desc}, + flags: fwpmFilterFlagClearActionRight, + layerKey: layer, + subLayerKey: sublayerKey, + weight: fwpValue0{type_: fwpUint8, value: uintptr(15)}, + numFilterConditions: 2, + filterCondition: &conds[0], + action: fwpmAction0{actionType: fwpActionPermit}, + } + + r1, _, _ := procFwpmFilterAdd0.Call( + engine, + uintptr(unsafe.Pointer(&filter)), + 0, + 0, + ) + if r1 != 0 { + return fmt.Errorf("FwpmFilterAdd0: 0x%x", r1) + } + return nil +} From 398d67e2da34573801545c9403d86d8460a2c8a5 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 8 May 2026 14:43:19 -0500 Subject: [PATCH 18/31] Windows code signing (#1718) --- .github/actions/code-sign/action.yml | 113 +++++++++++++++++++++++++++ .github/workflows/release.yml | 10 +++ 2 files changed, 123 insertions(+) create mode 100644 .github/actions/code-sign/action.yml diff --git a/.github/actions/code-sign/action.yml b/.github/actions/code-sign/action.yml new file mode 100644 index 00000000..bfa1a9ec --- /dev/null +++ b/.github/actions/code-sign/action.yml @@ -0,0 +1,113 @@ +name: Code-sign Windows binaries +description: > + Sign every .exe under a given path in place via the DefinedNet code-signer + Lambda. If `role` or `bucket` is empty, logs a notice and skips signing so + forks and dev branches without AWS access still produce usable builds. + +inputs: + path: + description: "Directory whose .exe files should be signed in place" + required: true + role: + description: "IAM role ARN to assume via OIDC; empty disables signing" + required: false + default: "" + bucket: + description: "S3 staging bucket the code-signer Lambda reads from; empty disables signing" + required: false + default: "" + region: + description: "AWS region for the role and Lambda" + required: false + default: "us-east-2" + function-name: + description: "Code-signer Lambda function name" + required: false + default: "code-signer" + key-prefix: + description: "S3 key prefix the caller is authorized to write under" + required: false + default: "code-signing/slackhq/nebula" + +runs: + using: composite + steps: + - name: Skip notice + if: inputs.role == '' || inputs.bucket == '' + shell: sh + run: echo "::notice::code-signer role or bucket not set; skipping code signing." + + - name: Configure AWS credentials + if: inputs.role != '' && inputs.bucket != '' + uses: aws-actions/configure-aws-credentials@v6 + with: + role-to-assume: ${{ inputs.role }} + aws-region: ${{ inputs.region }} + # Default is 12 retries to ride out IAM trust-policy propagation; once + # the role is stable we want a real misconfiguration to fail fast. + retry-max-attempts: 5 + + - name: Sign .exe files + if: inputs.role != '' && inputs.bucket != '' + shell: sh + env: + SIGN_PATH: ${{ inputs.path }} + BUCKET: ${{ inputs.bucket }} + FUNCTION_NAME: ${{ inputs.function-name }} + KEY_PREFIX: ${{ inputs.key-prefix }} + run: | + set -eu + RUN="${GITHUB_RUN_ID}-${GITHUB_RUN_ATTEMPT}" + + find "$SIGN_PATH" -name '*.exe' -print | while read -r path + do + rel=${path#"$SIGN_PATH"/} + file=$(basename "$path") + name=${file%.exe} + prefix="${KEY_PREFIX}/${RUN}" + src="${prefix}/unsigned/${rel}" + dst="${prefix}/signed/${rel}" + + echo "::group::Sign ${rel}" + echo "Uploading unsigned to s3://${BUCKET}/${src}" + aws s3 cp --no-progress "$path" "s3://${BUCKET}/${src}" >/dev/null + + echo "Invoking ${FUNCTION_NAME} Lambda" + payload=$(jq -nc \ + --arg s "$src" \ + --arg d "$dst" \ + --arg p "$name" \ + '{source_key: $s, dest_key: $d, program_name: $p}') + meta=$(aws lambda invoke \ + --function-name "$FUNCTION_NAME" \ + --cli-binary-format raw-in-base64-out \ + --payload "$payload" \ + --output json \ + /tmp/sign-resp.json) + if echo "$meta" | jq -e '.FunctionError != null' >/dev/null + then + echo "::endgroup::" + echo "::error::code-signer Lambda failed for ${rel}" + cat /tmp/sign-resp.json >&2 + exit 1 + fi + + echo "Downloading signed back to ${path}" + aws s3 cp --no-progress "s3://${BUCKET}/${dst}" "$path" >/dev/null + + aws s3 rm "s3://${BUCKET}/${src}" >/dev/null 2>&1 || true + aws s3 rm "s3://${BUCKET}/${dst}" >/dev/null 2>&1 || true + + # Sanity-check the bytes we got back actually carry an Authenticode + # signature that this machine can validate end to end. + status=$(powershell -NoProfile -Command "(Get-AuthenticodeSignature -FilePath '$path').Status" | tr -d '\r') + if [ "$status" != "Valid" ] + then + echo "::endgroup::" + echo "::error::${rel} signature status: ${status} (expected Valid)" + exit 1 + fi + + echo "Signed ${rel} (sha256=$(jq -r '.sha256' /tmp/sign-resp.json), status=${status})" + echo "::endgroup::" + done diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 356ae363..e4ca2933 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -32,6 +32,9 @@ jobs: build-windows: name: Build Windows runs-on: windows-latest + permissions: + id-token: write + contents: read steps: - uses: actions/checkout@v6 @@ -54,6 +57,13 @@ jobs: mkdir build\dist\windows mv dist\windows\wintun build\dist\windows\ + - name: Code-sign + uses: ./.github/actions/code-sign + with: + path: build + role: ${{ secrets.DEFINED_CODE_SIGNER_ROLE }} + bucket: ${{ secrets.DEFINED_CODE_SIGNER_BUCKET }} + - name: Upload artifacts uses: actions/upload-artifact@v7 with: From 110ea8f45c11d50ef2954db6f4fa9cadf1331f69 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 14:14:32 -0400 Subject: [PATCH 19/31] Bump the golang-x-dependencies group with 4 updates (#1721) Bumps the golang-x-dependencies group with 4 updates: [golang.org/x/crypto](https://github.com/golang/crypto), [golang.org/x/net](https://github.com/golang/net), [golang.org/x/sys](https://github.com/golang/sys) and [golang.org/x/term](https://github.com/golang/term). Updates `golang.org/x/crypto` from 0.50.0 to 0.51.0 - [Commits](https://github.com/golang/crypto/compare/v0.50.0...v0.51.0) Updates `golang.org/x/net` from 0.53.0 to 0.54.0 - [Commits](https://github.com/golang/net/compare/v0.53.0...v0.54.0) Updates `golang.org/x/sys` from 0.43.0 to 0.44.0 - [Commits](https://github.com/golang/sys/compare/v0.43.0...v0.44.0) Updates `golang.org/x/term` from 0.42.0 to 0.43.0 - [Commits](https://github.com/golang/term/compare/v0.42.0...v0.43.0) --- updated-dependencies: - dependency-name: golang.org/x/crypto dependency-version: 0.51.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/net dependency-version: 0.54.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/sys dependency-version: 0.44.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies - dependency-name: golang.org/x/term dependency-version: 0.43.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: golang-x-dependencies ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index 84728201..ee51151f 100644 --- a/go.mod +++ b/go.mod @@ -24,12 +24,12 @@ require ( github.com/vishvananda/netlink v1.3.1 go.uber.org/goleak v1.3.0 go.yaml.in/yaml/v3 v3.0.4 - golang.org/x/crypto v0.50.0 + golang.org/x/crypto v0.51.0 golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 - golang.org/x/net v0.53.0 + golang.org/x/net v0.54.0 golang.org/x/sync v0.20.0 - golang.org/x/sys v0.43.0 - golang.org/x/term v0.42.0 + golang.org/x/sys v0.44.0 + golang.org/x/term v0.43.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b golang.zx2c4.com/wireguard/windows v0.6.1 diff --git a/go.sum b/go.sum index 3b0b87df..5640bd46 100644 --- a/go.sum +++ b/go.sum @@ -162,8 +162,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= +golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI= +golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -182,8 +182,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w= +golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -208,11 +208,11 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ= +golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= -golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= +golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4= +golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= From 6c7ebb08759ddfea0c628bcc7c3069d379edee1b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:35:49 -0500 Subject: [PATCH 20/31] Reset static host list addresses on change (#1713) --- lighthouse.go | 25 +++++++-- lighthouse_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ remote_list.go | 25 +++++++++ remote_list_test.go | 89 +++++++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 4 deletions(-) diff --git a/lighthouse.go b/lighthouse.go index 1a136a1b..d23e84b8 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -272,16 +272,18 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //NOTE: many things will get much simpler when we combine static_host_map and lighthouse.hosts in config if initial || c.HasChanged("static_host_map") || c.HasChanged("static_map.cadence") || c.HasChanged("static_map.network") || c.HasChanged("static_map.lookup_timeout") { // Clean up. Entries still in the static_host_map will be re-built. - // Entries no longer present must have their (possible) background DNS goroutines stopped. - if existingStaticList := lh.staticList.Load(); existingStaticList != nil { + ourselves := lh.myVpnNetworks[0].Addr() + oldStaticList := lh.staticList.Load() + if oldStaticList != nil { lh.RLock() - for staticVpnAddr := range *existingStaticList { + for staticVpnAddr := range *oldStaticList { if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { - am.hr.Cancel() + am.ResetForOwner(ourselves) } } lh.RUnlock() } + // Build a new list based on current config. staticList := make(map[netip.Addr]struct{}) err := lh.loadStaticMap(c, staticList) @@ -289,6 +291,21 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return err } + // For entries removed from static_host_map, stop the DNS goroutine and drop the cached addrs. + // All addrs must come from the lighthouses now that it's no longer a static host. + if oldStaticList != nil { + lh.RLock() + for staticVpnAddr := range *oldStaticList { + if _, stillStatic := staticList[staticVpnAddr]; stillStatic { + continue + } + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { + am.ClearHostnameResults() + } + } + lh.RUnlock() + } + lh.staticList.Store(&staticList) if !initial { if c.HasChanged("static_host_map") { diff --git a/lighthouse_test.go b/lighthouse_test.go index c57c44ec..81c883ff 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -303,6 +303,132 @@ func TestLighthouse_reload(t *testing.T) { require.NoError(t, err) } +// TestLighthouse_reloadStaticHostMap verifies that reloading static_host_map applies the new +// config rather than appending to it. See issue #718. +func TestLighthouse_reloadStaticHostMap(t *testing.T) { + l := test.NewLogger() + c := config.NewC(l) + c.Settings["lighthouse"] = map[string]any{"am_lighthouse": true} + c.Settings["listen"] = map[string]any{"port": 4242} + c.Settings["static_host_map"] = map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, + } + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Lite) + nt.Insert(myVpnNet) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(t.Context(), l, c, cs, nil, nil) + require.NoError(t, err) + + staticHost := netip.MustParseAddr("10.128.0.2") + otherHost := netip.MustParseAddr("10.128.0.3") + + // Capture the RemoteList pointer up front; an in-flight handshake would hold the same one + // on hostinfo.remotes, so it must reflect every reload below. + pinned := lh.Query(staticHost) + require.NotNil(t, pinned) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, pinned.CopyAddrs([]netip.Prefix{})) + + // Replace the remote address. The new address should be the only entry. + nc := map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + }, + } + rc, err := yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl := lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl, "RemoteList pointer must stay stable so in-flight handshakes pick up the change") + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Reload back to the original IP. Mirrors the round-trip in issue #718 step 6-8 where + // the buggy reload produced [1.1.1.1, 2.2.2.2, 1.1.1.1] instead of [1.1.1.1]. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"1.1.1.1:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Reload with the same config. An unchanged entry must not duplicate. + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Switch back to 2.2.2.2 so the rest of the test continues against a known address. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + // Add a second host alongside the first. Both should be present, neither duplicated. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.2": []any{"2.2.2.2:4242"}, + "10.128.0.3": []any{"3.3.3.3:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl, "adding a sibling entry must not displace the existing RemoteList") + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("2.2.2.2:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + rl = lh.Query(otherHost) + require.NotNil(t, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{})) + + // Drop the first host entirely. The vpnAddr is no longer marked static, our owner + // contribution is cleared, but the addrMap entry stays in place so non-static cache + // data (from lighthouse queries) on the same RemoteList isn't lost. In-flight handshakes + // that already had the pointer see an empty address list rather than retrying stale ones. + nc = map[string]any{ + "static_host_map": map[string]any{ + "10.128.0.3": []any{"3.3.3.3:4242"}, + }, + } + rc, err = yaml.Marshal(nc) + require.NoError(t, err) + require.NoError(t, c.ReloadConfigString(string(rc))) + + _, isStatic := lh.GetStaticHostList()[staticHost] + assert.False(t, isStatic) + + rl = lh.Query(staticHost) + require.NotNil(t, rl) + assert.Same(t, pinned, rl) + assert.Empty(t, rl.CopyAddrs([]netip.Prefix{})) + + rl = lh.Query(otherHost) + require.NotNil(t, rl) + assert.Equal(t, []netip.AddrPort{netip.MustParseAddrPort("3.3.3.3:4242")}, rl.CopyAddrs([]netip.Prefix{})) +} + func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, diff --git a/remote_list.go b/remote_list.go index 7b95de87..ef6eb794 100644 --- a/remote_list.go +++ b/remote_list.go @@ -239,6 +239,31 @@ func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { r.hr = hr } +// ResetForOwner zeros the reported address slices for the given owner and marks the addrs list dirty. +// Any pending hostname resolution will be canceled. +func (r *RemoteList) ResetForOwner(ownerVpnAddr netip.Addr) { + r.Lock() + defer r.Unlock() + r.hr.Cancel() + if c, ok := r.cache[ownerVpnAddr]; ok { + if c.v4 != nil { + c.v4.reported = c.v4.reported[:0] + } + if c.v6 != nil { + c.v6.reported = c.v6.reported[:0] + } + } + r.shouldRebuild = true +} + +// ClearHostnameResults cancels the in-flight DNS resolver goroutine (if any) and drops the resolved IP cache. +func (r *RemoteList) ClearHostnameResults() { + r.Lock() + defer r.Unlock() + r.unlockedSetHostnamesResults(nil) + r.shouldRebuild = true +} + // Len locks and reports the size of the deduplicated address list // The deduplication work may need to occur here, so you must pass preferredRanges func (r *RemoteList) Len(preferredRanges []netip.Prefix) int { diff --git a/remote_list_test.go b/remote_list_test.go index 0caf86a4..0b5b7d5d 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -6,8 +6,22 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +// trackedHostnameResults builds a *hostnamesResults with a known cancel function and a +// pre-populated ips map so tests can assert cancellation and verify previously-resolved +// IPs survive a cancel without spinning up a real DNS resolver. +func trackedHostnameResults(cancelFn func(), addrs ...string) *hostnamesResults { + hr := &hostnamesResults{cancelFn: cancelFn} + ips := map[netip.AddrPort]struct{}{} + for _, a := range addrs { + ips[netip.MustParseAddrPort(a)] = struct{}{} + } + hr.ips.Store(&ips) + return hr +} + func TestRemoteList_Rebuild(t *testing.T) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( @@ -112,6 +126,81 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String()) } +func TestRemoteList_ResetForOwner(t *testing.T) { + ourselves := netip.MustParseAddr("10.0.0.1") + otherOwner := netip.MustParseAddr("10.0.0.2") + vpnAddr := netip.MustParseAddr("10.0.0.99") + + rl := NewRemoteList([]netip.Addr{vpnAddr}, nil) + rl.unlockedSetV4(ourselves, vpnAddr, + []*V4AddrPort{newIp4AndPortFromString("1.1.1.1:4242")}, + func(netip.Addr, *V4AddrPort) bool { return true }, + ) + rl.unlockedSetV6(ourselves, vpnAddr, + []*V6AddrPort{newIp6AndPortFromString("[1::1]:4242")}, + func(netip.Addr, *V6AddrPort) bool { return true }, + ) + rl.unlockedSetV4(otherOwner, vpnAddr, + []*V4AddrPort{newIp4AndPortFromString("2.2.2.2:4242")}, + func(netip.Addr, *V4AddrPort) bool { return true }, + ) + + canceled := 0 + hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242") + rl.Lock() + rl.unlockedSetHostnamesResults(hr) + rl.Unlock() + + rl.ResetForOwner(ourselves) + + rl.RLock() + defer rl.RUnlock() + assert.Empty(t, rl.cache[ourselves].v4.reported, "our v4 reported should be cleared") + assert.Empty(t, rl.cache[ourselves].v6.reported, "our v6 reported should be cleared") + assert.Len(t, rl.cache[otherOwner].v4.reported, 1, "other owner's contribution must be preserved") + assert.Equal(t, "2.2.2.2:4242", protoV4AddrPortToNetAddrPort(rl.cache[otherOwner].v4.reported[0]).String()) + assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled") + assert.Same(t, hr, rl.hr, "hostnamesResults must be preserved so DNS-resolved IPs keep feeding addrs until replaced") + assert.NotEmpty(t, rl.hr.GetAddrs(), "previously-resolved IPs should still be readable after cancel") + assert.True(t, rl.shouldRebuild, "shouldRebuild must be set so the next Rebuild recomputes addrs") +} + +func TestRemoteList_ResetForOwner_NoEntry(t *testing.T) { + // An owner with no cache entry must not panic; shouldRebuild is still set and any + // existing hostnamesResults is canceled. + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil) + canceled := 0 + rl.Lock() + rl.unlockedSetHostnamesResults(trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242")) + rl.Unlock() + + rl.ResetForOwner(netip.MustParseAddr("10.0.0.1")) + + rl.RLock() + defer rl.RUnlock() + assert.Equal(t, 1, canceled) + assert.True(t, rl.shouldRebuild) +} + +func TestRemoteList_ClearHostnameResults(t *testing.T) { + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("10.0.0.99")}, nil) + + canceled := 0 + hr := trackedHostnameResults(func() { canceled++ }, "3.3.3.3:4242") + rl.Lock() + rl.unlockedSetHostnamesResults(hr) + rl.Unlock() + require.NotEmpty(t, hr.GetAddrs(), "hostnamesResults should have its fastrack IPs populated") + + rl.ClearHostnameResults() + + rl.RLock() + defer rl.RUnlock() + assert.Equal(t, 1, canceled, "DNS resolution goroutine should be canceled") + assert.Nil(t, rl.hr, "hostnamesResults should be dropped") + assert.True(t, rl.shouldRebuild) +} + func BenchmarkFullRebuild(b *testing.B) { rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( From 3c121e7ab1b9f0369d72a78c60edacd8a7cf6b2f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:08 -0500 Subject: [PATCH 21/31] Allow for `-` to stand in for stdin/out (#1714) --- cmd/nebula-cert/ca.go | 39 ++++++-- cmd/nebula-cert/ca_test.go | 79 ++++++++++++++-- cmd/nebula-cert/keygen.go | 15 ++- cmd/nebula-cert/keygen_test.go | 41 ++++++++ cmd/nebula-cert/passwords.go | 4 +- cmd/nebula-cert/print.go | 31 ++++-- cmd/nebula-cert/print_test.go | 39 ++++++++ cmd/nebula-cert/sign.go | 62 ++++++++---- cmd/nebula-cert/sign_test.go | 120 +++++++++++++++++++++-- cmd/nebula-cert/stdio.go | 117 +++++++++++++++++++++++ cmd/nebula-cert/stdio_test.go | 167 +++++++++++++++++++++++++++++++++ cmd/nebula-cert/verify.go | 17 +++- cmd/nebula-cert/verify_test.go | 44 +++++++++ 13 files changed, 718 insertions(+), 57 deletions(-) create mode 100644 cmd/nebula-cert/stdio.go create mode 100644 cmd/nebula-cert/stdio_test.go diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index cd9b82f9..3145f445 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -97,6 +97,19 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } + } else { + // out-key is meaningless under PKCS#11 because the private key never + // leaves the HSM; reject it so we never silently accept or claim a + // stdout slot for it. + outKeySet := false + cf.set.Visit(func(f *flag.Flag) { + if f.Name == "out-key" { + outKeySet = true + } + }) + if outKeySet { + return newHelpErrorf("cannot set -out-key with -pkcs11") + } } if err := mustFlagString("out-crt", cf.outCertPath); err != nil { return err @@ -171,12 +184,21 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } + var claims ioClaims + if err := reserveOutputs(&claims, + "out-key", *cf.outKeyPath, + "out-crt", *cf.outCertPath, + "out-qr", *cf.outQRPath, + ); err != nil { + return err + } + var passphrase []byte if !isP11 && *cf.encryption { passphrase = []byte(os.Getenv("NEBULA_CA_PASSPHRASE")) if len(passphrase) == 0 { for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) + errOut.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if err == ErrNoTerminal { @@ -261,14 +283,16 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error Curve: curve, } - if !isP11 { + if !isP11 && !isStdio(*cf.outKeyPath) { if _, err := os.Stat(*cf.outKeyPath); err == nil { return fmt.Errorf("refusing to overwrite existing CA key: %s", *cf.outKeyPath) } } - if _, err := os.Stat(*cf.outCertPath); err == nil { - return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + if !isStdio(*cf.outCertPath) { + if _, err := os.Stat(*cf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) + } } var c cert.Certificate @@ -294,7 +318,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error b = cert.MarshalSigningPrivateKeyToPEM(curve, rawPriv) } - err = os.WriteFile(*cf.outKeyPath, b, 0600) + err = writeOutput(*cf.outKeyPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -305,7 +329,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while marshalling certificate: %s", err) } - err = os.WriteFile(*cf.outCertPath, b, 0600) + err = writeOutput(*cf.outCertPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -316,7 +340,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*cf.outQRPath, b, 0600) + err = writeOutput(*cf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -332,6 +356,7 @@ func caSummary() string { func caHelp(out io.Writer) { cf := newCaFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + caSummary() + "\n")) + out.Write([]byte(stdioHelpText)) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 779d3a2d..ce0113b6 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -27,6 +27,7 @@ func Test_caHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" ca : create a self signed certificate authority\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -argon-iterations uint\n"+ " \tOptional: Argon2 iterations parameter used for encrypted private key passphrase (default 1)\n"+ " -argon-memory uint\n"+ @@ -84,7 +85,7 @@ func Test_ca(t *testing.T) { err: nil, } - pwPromptOb := "Enter passphrase: " + pwPromptEB := "Enter passphrase: " // required args assertHelpError(t, ca( @@ -168,8 +169,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.NoError(t, ca(args, ob, eb, testpw)) - assert.Equal(t, pwPromptOb, ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, pwPromptEB, eb.String()) // test encrypted key with passphrase environment variable os.Remove(keyF.Name()) @@ -207,8 +208,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.Error(t, ca(args, ob, eb, errpw)) - assert.Equal(t, pwPromptOb, ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, pwPromptEB, eb.String()) // test when user fails to enter a password os.Remove(keyF.Name()) @@ -217,8 +218,8 @@ func Test_ca(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") - assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, strings.Repeat(pwPromptEB, 5), eb.String()) // prompts 5 times before giving up // create valid cert/key for overwrite tests os.Remove(keyF.Name()) @@ -247,3 +248,67 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) } + +func Test_ca_stdio(t *testing.T) { + nopw := &StubPasswordReader{} + + keyF, err := os.CreateTemp("", "ca.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + crtF, err := os.CreateTemp("", "ca.crt") + require.NoError(t, err) + os.Remove(crtF.Name()) + defer os.Remove(crtF.Name()) + + // out-crt on stdout, out-key on disk + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", keyF.Name()}, ob, eb, nopw)) + assert.Empty(t, eb.String()) + c, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.True(t, c.IsCA()) + assert.Equal(t, "test-ca", c.Name()) + + // out-key on stdout, out-crt on disk + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + require.NoError(t, ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", crtF.Name(), "-out-key", "-"}, ob, eb, nopw)) + assert.Empty(t, eb.String()) + _, _, curve, err := cert.UnmarshalSigningPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + + // dual stdout is rejected up front + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + require.EqualError(t, + ca([]string{"-name", "test-ca", "-duration", "1h", "-out-crt", "-", "-out-key", "-"}, ob, eb, nopw), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + + // an output conflict combined with -encrypt must error BEFORE prompting + // for a passphrase; pr would record any read attempt + tracker := &trackingPasswordReader{} + ob.Reset() + eb.Reset() + require.EqualError(t, + ca([]string{"-name", "test-ca", "-duration", "1h", "-encrypt", "-out-crt", "-", "-out-key", "-"}, ob, eb, tracker), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + assert.Zero(t, tracker.calls, "passphrase prompt should not have been called") +} + +type trackingPasswordReader struct { + calls int +} + +func (pr *trackingPasswordReader) ReadPassword() ([]byte, error) { + pr.calls++ + return []byte(""), nil +} diff --git a/cmd/nebula-cert/keygen.go b/cmd/nebula-cert/keygen.go index 496f84c2..dea6c4af 100644 --- a/cmd/nebula-cert/keygen.go +++ b/cmd/nebula-cert/keygen.go @@ -42,6 +42,8 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { if err = mustFlagString("out-key", cf.outKeyPath); err != nil { return err } + } else if *cf.outKeyPath != "" { + return newHelpErrorf("cannot set -out-key with -pkcs11") } if err = mustFlagString("out-pub", cf.outPubPath); err != nil { return err @@ -69,6 +71,14 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { } } + var claims ioClaims + if err := reserveOutputs(&claims, + "out-key", *cf.outKeyPath, + "out-pub", *cf.outPubPath, + ); err != nil { + return err + } + if isP11 { p11Client, err := pkclient.FromUrl(*cf.p11url) if err != nil { @@ -82,12 +92,12 @@ func keygen(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while getting public key: %w", err) } } else { - err = os.WriteFile(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + err = writeOutput(*cf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } } - err = os.WriteFile(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600) + err = writeOutput(*cf.outPubPath, cert.MarshalPublicKeyToPEM(curve, pub), 0600, out) if err != nil { return fmt.Errorf("error while writing out-pub: %s", err) } @@ -102,6 +112,7 @@ func keygenSummary() string { func keygenHelp(out io.Writer) { cf := newKeygenFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + keygenSummary() + "\n")) + _, _ = out.Write([]byte(stdioHelpText)) cf.set.SetOutput(out) cf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 95d9893e..98c4c456 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -20,6 +20,7 @@ func Test_keygenHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" keygen : create a public/private key pair. the public key can be passed to `nebula-cert sign`\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -curve string\n"+ " \tECDH Curve (25519, P256) (default \"25519\")\n"+ " -out-key string\n"+ @@ -93,3 +94,43 @@ func Test_keygen(t *testing.T) { require.NoError(t, err) assert.Len(t, lPub, 32) } + +func Test_keygen_stdio(t *testing.T) { + keyF, err := os.CreateTemp("", "test.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + pubF, err := os.CreateTemp("", "test.pub") + require.NoError(t, err) + os.Remove(pubF.Name()) + defer os.Remove(pubF.Name()) + + // out-pub on stdout, out-key on disk + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + require.NoError(t, keygen([]string{"-out-pub", "-", "-out-key", keyF.Name()}, ob, eb)) + assert.Empty(t, eb.String()) + lPub, _, curve, err := cert.UnmarshalPublicKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Len(t, lPub, 32) + + // out-key on stdout, out-pub on disk + os.Remove(keyF.Name()) + ob.Reset() + eb.Reset() + require.NoError(t, keygen([]string{"-out-pub", pubF.Name(), "-out-key", "-"}, ob, eb)) + assert.Empty(t, eb.String()) + lKey, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + assert.Len(t, lKey, 32) + + // both on stdout is a conflict caught up front + ob.Reset() + eb.Reset() + require.EqualError(t, keygen([]string{"-out-pub", "-", "-out-key", "-"}, ob, eb), + `-out-key and -out-pub both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) +} diff --git a/cmd/nebula-cert/passwords.go b/cmd/nebula-cert/passwords.go index 8129560e..0aa2115d 100644 --- a/cmd/nebula-cert/passwords.go +++ b/cmd/nebula-cert/passwords.go @@ -22,7 +22,9 @@ func (pr StdinPasswordReader) ReadPassword() ([]byte, error) { } password, err := term.ReadPassword(int(os.Stdin.Fd())) - fmt.Println() + // Terminal echo is off while reading, so the user's Enter key does not + // produce a visible newline. Emit one on stderr to match the prompt. + fmt.Fprintln(os.Stderr) return password, err } diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 30e0965b..3ba0571e 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -40,11 +40,23 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return err } - rawCert, err := os.ReadFile(*pf.path) + var claims ioClaims + if err := reserveInputs(&claims, "path", *pf.path); err != nil { + return err + } + if err := reserveOutputs(&claims, "out-qr", *pf.outQRPath); err != nil { + return err + } + + rawCert, err := readInput("path", *pf.path, &claims) if err != nil { return fmt.Errorf("unable to read cert; %s", err) } + // When the QR is going to stdout, suppress the human-readable text/json + // output so the binary stream is not contaminated. + qrToStdout := isStdio(*pf.outQRPath) + var c cert.Certificate var qrBytes []byte part := 0 @@ -57,11 +69,13 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while unmarshaling cert: %s", err) } - if *pf.json { - jsonCerts = append(jsonCerts, c) - } else { - _, _ = out.Write([]byte(c.String())) - _, _ = out.Write([]byte("\n")) + if !qrToStdout { + if *pf.json { + jsonCerts = append(jsonCerts, c) + } else { + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) + } } if *pf.outQRPath != "" { @@ -79,7 +93,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } - if *pf.json { + if *pf.json && !qrToStdout { b, _ := json.Marshal(jsonCerts) _, _ = out.Write(b) _, _ = out.Write([]byte("\n")) @@ -91,7 +105,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*pf.outQRPath, b, 0600) + err = writeOutput(*pf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -107,6 +121,7 @@ func printSummary() string { func printHelp(out io.Writer) { pf := newPrintFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + printSummary() + "\n")) + out.Write([]byte(stdioHelpText)) pf.set.SetOutput(out) pf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 221ab778..8d5d31be 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -25,6 +25,7 @@ func Test_printHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" print : prints details about a certificate\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -json\n"+ " \tOptional: outputs certificates in json format\n"+ " -out-qr string\n"+ @@ -178,6 +179,44 @@ func Test_printCert(t *testing.T) { ob.String(), ) assert.Empty(t, eb.String()) + + // read cert from stdin + ob.Reset() + eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-json", "-path", "-"}, ob, eb) + require.NoError(t, err) + assert.Equal( + t, + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, + ob.String(), + ) + assert.Empty(t, eb.String()) + + // -out-qr - sends only the PNG to stdout, suppressing the cert dump + ob.Reset() + eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-path", "-", "-out-qr", "-"}, ob, eb) + require.NoError(t, err) + assert.Empty(t, eb.String()) + stdout := ob.Bytes() + require.NotEmpty(t, stdout) + // PNG magic, no PEM/JSON noise prepended + assert.Equal(t, []byte{0x89, 'P', 'N', 'G', 0x0d, 0x0a, 0x1a, 0x0a}, stdout[:8]) + assert.NotContains(t, string(stdout), "NebulaCertificate") + assert.NotContains(t, string(stdout), `"details"`) + + // json + out-qr - still suppresses json + ob.Reset() + eb.Reset() + withStdin(t, bytes.NewReader(p)) + err = printCert([]string{"-json", "-path", "-", "-out-qr", "-"}, ob, eb) + require.NoError(t, err) + assert.Empty(t, eb.String()) + assert.Equal(t, []byte{0x89, 'P', 'N', 'G'}, ob.Bytes()[:4]) + assert.NotContains(t, ob.String(), `"details"`) } // NewTestCaCert will generate a CA cert diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 561138ca..9b57c4fe 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -85,6 +85,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } + if isP11 && *sf.outKeyPath != "" { + return newHelpErrorf("cannot set -out-key with -pkcs11") + } var v4Networks []netip.Prefix var v6Networks []netip.Prefix @@ -102,13 +105,35 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } + if *sf.outKeyPath == "" { + *sf.outKeyPath = *sf.name + ".key" + } + if *sf.outCertPath == "" { + *sf.outCertPath = *sf.name + ".crt" + } + + var claims ioClaims + if err := reserveInputs(&claims, + "ca-key", *sf.caKeyPath, + "ca-crt", *sf.caCertPath, + "in-pub", *sf.inPubPath, + ); err != nil { + return err + } + if err := reserveOutputs(&claims, + "out-key", *sf.outKeyPath, + "out-crt", *sf.outCertPath, + "out-qr", *sf.outQRPath, + ); err != nil { + return err + } + var curve cert.Curve var caKey []byte if !isP11 { var rawCAKey []byte - rawCAKey, err := os.ReadFile(*sf.caKeyPath) - + rawCAKey, err = readInput("ca-key", *sf.caKeyPath, &claims) if err != nil { return fmt.Errorf("error while reading ca-key: %s", err) } @@ -121,7 +146,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if len(passphrase) == 0 { // ask for a passphrase until we get one for i := 0; i < 5; i++ { - out.Write([]byte("Enter passphrase: ")) + errOut.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() if errors.Is(err, ErrNoTerminal) { @@ -147,7 +172,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - rawCACert, err := os.ReadFile(*sf.caCertPath) + rawCACert, err := readInput("ca-crt", *sf.caCertPath, &claims) if err != nil { return fmt.Errorf("error while reading ca-crt: %s", err) } @@ -245,7 +270,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if *sf.inPubPath != "" { var pubCurve cert.Curve - rawPub, err := os.ReadFile(*sf.inPubPath) + rawPub, err := readInput("in-pub", *sf.inPubPath, &claims) if err != nil { return fmt.Errorf("error while reading in-pub: %s", err) } @@ -266,16 +291,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - if *sf.outKeyPath == "" { - *sf.outKeyPath = *sf.name + ".key" - } - - if *sf.outCertPath == "" { - *sf.outCertPath = *sf.name + ".crt" - } - - if _, err := os.Stat(*sf.outCertPath); err == nil { - return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) + if !isStdio(*sf.outCertPath) { + if _, err := os.Stat(*sf.outCertPath); err == nil { + return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) + } } var crts []cert.Certificate @@ -360,11 +379,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } if !isP11 && *sf.inPubPath == "" { - if _, err := os.Stat(*sf.outKeyPath); err == nil { - return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) + if !isStdio(*sf.outKeyPath) { + if _, err := os.Stat(*sf.outKeyPath); err == nil { + return fmt.Errorf("refusing to overwrite existing key: %s", *sf.outKeyPath) + } } - err = os.WriteFile(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600) + err = writeOutput(*sf.outKeyPath, cert.MarshalPrivateKeyToPEM(curve, rawPriv), 0600, out) if err != nil { return fmt.Errorf("error while writing out-key: %s", err) } @@ -379,7 +400,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) b = append(b, sb...) } - err = os.WriteFile(*sf.outCertPath, b, 0600) + err = writeOutput(*sf.outCertPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-crt: %s", err) } @@ -390,7 +411,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("error while generating qr code: %s", err) } - err = os.WriteFile(*sf.outQRPath, b, 0600) + err = writeOutput(*sf.outQRPath, b, 0600, out) if err != nil { return fmt.Errorf("error while writing out-qr: %s", err) } @@ -440,6 +461,7 @@ func signSummary() string { func signHelp(out io.Writer) { sf := newSignFlags() out.Write([]byte("Usage of " + os.Args[0] + " " + signSummary() + "\n")) + out.Write([]byte(stdioHelpText)) sf.set.SetOutput(out) sf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index f5f8cbb0..64d5c7d9 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -27,6 +27,7 @@ func Test_signHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" sign : create and sign a certificate\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -ca-crt string\n"+ " \tOptional: path to the signing CA cert (default \"ca.crt\")\n"+ " -ca-key string\n"+ @@ -376,15 +377,18 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.NoError(t, signCert(args, ob, eb, testpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) // test with the proper password in the environment os.Remove(crtF.Name()) os.Remove(keyF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} os.Setenv("NEBULA_CA_PASSPHRASE", string(passphrase)) + ob.Reset() + eb.Reset() require.NoError(t, signCert(args, ob, eb, testpw)) + assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Setenv("NEBULA_CA_PASSPHRASE", "") @@ -395,8 +399,8 @@ func Test_signCert(t *testing.T) { testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, testpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) // test with the wrong password in environment ob.Reset() @@ -416,8 +420,8 @@ func Test_signCert(t *testing.T) { args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these - assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", eb.String()) // test an error condition ob.Reset() @@ -425,6 +429,106 @@ func Test_signCert(t *testing.T) { args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} require.Error(t, signCert(args, ob, eb, errpw)) - assert.Equal(t, "Enter passphrase: ", ob.String()) - assert.Empty(t, eb.String()) + assert.Empty(t, ob.String()) + assert.Equal(t, "Enter passphrase: ", eb.String()) +} + +func Test_signCert_stdio(t *testing.T) { + nopw := &StubPasswordReader{ + password: []byte(""), + err: nil, + } + + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + rawCAKey := cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv) + + ca, _ := NewTestCaCert("ca", caPub, caPriv, time.Now(), time.Now().Add(time.Minute*200), nil, nil, nil) + rawCACrt, _ := ca.MarshalPEM() + + caCrtF, err := os.CreateTemp("", "sign-cert.crt") + require.NoError(t, err) + defer os.Remove(caCrtF.Name()) + caCrtF.Write(rawCACrt) + + caKeyF, err := os.CreateTemp("", "sign-cert.key") + require.NoError(t, err) + defer os.Remove(caKeyF.Name()) + caKeyF.Write(rawCAKey) + + keyF, err := os.CreateTemp("", "sign.key") + require.NoError(t, err) + os.Remove(keyF.Name()) + defer os.Remove(keyF.Name()) + + // ca-key on stdin, cert to stdout + withStdin(t, bytes.NewReader(rawCAKey)) + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + args := []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", keyF.Name(), "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + + lCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, "stdin-test", lCrt.Name()) + assert.True(t, lCrt.CheckSignature(caPub)) + + // two flags reading from stdin should error before any read attempt; + // otherwise an interactive shell would hang on io.ReadAll + stdinIn := bytes.NewReader(rawCAKey) + withStdin(t, stdinIn) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", "-", "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), + `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) + assert.Equal(t, len(rawCAKey), stdinIn.Len(), "stdin should be untouched when conflict is caught up front") + + // two flags writing to stdout should error before any output is written + // AND before stdin is consumed + stdinR := bytes.NewReader(rawCAKey) + withStdin(t, stdinR) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", "-", "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", "-", "-out-key", "-", "-duration", "100m"} + require.EqualError(t, signCert(args, ob, eb, nopw), + `-out-key and -out-crt both set to "-", only one output may write to stdout`) + assert.Empty(t, ob.String()) + // stdin should be untouched because the conflict was caught up front + assert.Equal(t, len(rawCAKey), stdinR.Len()) + + // out-key on stdout, cert on disk + keyF2, err := os.CreateTemp("", "sign.key") + require.NoError(t, err) + os.Remove(keyF2.Name()) + defer os.Remove(keyF2.Name()) + crtF, err := os.CreateTemp("", "sign.crt") + require.NoError(t, err) + os.Remove(crtF.Name()) + defer os.Remove(crtF.Name()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "stdin-test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", "-", "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + _, _, curve, err := cert.UnmarshalPrivateKeyFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, cert.Curve_CURVE25519, curve) + + // in-pub on stdin (caller already has a keypair, only the cert is generated) + inPub, _ := x25519Keypair() + rawInPub := cert.MarshalPublicKeyToPEM(cert.Curve_CURVE25519, inPub) + + withStdin(t, bytes.NewReader(rawInPub)) + os.Remove(crtF.Name()) + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "in-pub-test", "-ip", "1.1.1.1/24", "-in-pub", "-", "-out-crt", "-", "-duration", "100m"} + require.NoError(t, signCert(args, ob, eb, nopw)) + assert.Empty(t, eb.String()) + stdinCrt, _, err := cert.UnmarshalCertificateFromPEM(ob.Bytes()) + require.NoError(t, err) + assert.Equal(t, "in-pub-test", stdinCrt.Name()) + assert.Equal(t, inPub, stdinCrt.PublicKey()) } diff --git a/cmd/nebula-cert/stdio.go b/cmd/nebula-cert/stdio.go new file mode 100644 index 00000000..3f71d52f --- /dev/null +++ b/cmd/nebula-cert/stdio.go @@ -0,0 +1,117 @@ +package main + +import ( + "fmt" + "io" + "os" +) + +// stdioPath is the special path value that selects stdin (for inputs) or +// stdout (for outputs) instead of a file on disk. +const stdioPath = "-" + +// stdioHelpText is rendered just under the Usage line of each subcommand +// help so the - convention is documented once instead of on every flag. +const stdioHelpText = " Pass \"-\" to any path flag to read from stdin or write to stdout.\n" + +// stdinReader is the source used when an input flag is set to "-". +// It is a package level var so tests can swap in a deterministic reader. +// Tests that mutate stdinReader cannot run with t.Parallel(). +var stdinReader io.Reader = os.Stdin + +// ioClaims tracks which flags have claimed stdin and stdout during a single +// command invocation so we can refuse a second flag asking for the same +// stream. +type ioClaims struct { + in string + out string +} + +func (c *ioClaims) claimIn(flagName string) error { + if c.in != "" && c.in != flagName { + return fmt.Errorf("-%s and -%s both set to %q, only one input may read from stdin", c.in, flagName, stdioPath) + } + c.in = flagName + return nil +} + +func (c *ioClaims) claimOut(flagName string) error { + if c.out != "" && c.out != flagName { + return fmt.Errorf("-%s and -%s both set to %q, only one output may write to stdout", c.out, flagName, stdioPath) + } + c.out = flagName + return nil +} + +// reserveInputs walks alternating (flagName, path) pairs and claims stdin +// for any path equal to stdioPath. It must be called before any input is +// read so a conflict can be reported immediately instead of blocking on +// io.ReadAll while waiting for input that will never arrive. +func reserveInputs(claims *ioClaims, pairs ...string) error { + return reserveStdio(claims, "reserveInputs", (*ioClaims).claimIn, pairs) +} + +// reserveOutputs walks alternating (flagName, path) pairs and claims stdout +// for any path equal to stdioPath. It must be called before any output is +// written so a conflict cannot leave one stream half written before the +// second flag fails. +func reserveOutputs(claims *ioClaims, pairs ...string) error { + return reserveStdio(claims, "reserveOutputs", (*ioClaims).claimOut, pairs) +} + +func reserveStdio(claims *ioClaims, who string, claim func(*ioClaims, string) error, pairs []string) error { + if len(pairs)%2 != 0 { + panic(who + " requires alternating name, path pairs") + } + for i := 0; i < len(pairs); i += 2 { + name, path := pairs[i], pairs[i+1] + if path != stdioPath { + continue + } + if err := claim(claims, name); err != nil { + return err + } + } + return nil +} + +// readInput returns the bytes referenced by path, reading from stdin when +// path is stdioPath. +func readInput(flagName, path string, claims *ioClaims) ([]byte, error) { + if path == stdioPath { + if err := claims.claimIn(flagName); err != nil { + return nil, err + } + return io.ReadAll(stdinReader) + } + return os.ReadFile(path) +} + +// openInput returns a reader for path. When path is stdioPath the returned +// reader wraps stdin and Close is a no-op. +func openInput(flagName, path string, claims *ioClaims) (io.ReadCloser, error) { + if path == stdioPath { + if err := claims.claimIn(flagName); err != nil { + return nil, err + } + return io.NopCloser(stdinReader), nil + } + return os.Open(path) +} + +// writeOutput writes data to path, or to stdout when path is stdioPath. perm +// is only used for file output. The caller must have already claimed stdout +// via reserveOutputs before invoking with stdioPath. +func writeOutput(path string, data []byte, perm os.FileMode, stdout io.Writer) error { + if path == stdioPath { + _, err := stdout.Write(data) + return err + } + return os.WriteFile(path, data, perm) +} + +// isStdio reports whether path is the stdio sentinel and so should skip +// existence checks like "refuse to overwrite". +func isStdio(path string) bool { + return path == stdioPath +} diff --git a/cmd/nebula-cert/stdio_test.go b/cmd/nebula-cert/stdio_test.go new file mode 100644 index 00000000..dc87a597 --- /dev/null +++ b/cmd/nebula-cert/stdio_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "bytes" + "io" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// withStdin temporarily replaces stdinReader for the duration of t. +func withStdin(t *testing.T, r io.Reader) { + t.Helper() + prev := stdinReader + stdinReader = r + t.Cleanup(func() { stdinReader = prev }) +} + +func Test_readInput_stdin(t *testing.T) { + withStdin(t, bytes.NewBufferString("hello")) + var claims ioClaims + + got, err := readInput("path", "-", &claims) + require.NoError(t, err) + assert.Equal(t, []byte("hello"), got) + assert.Equal(t, "path", claims.in) +} + +func Test_readInput_file(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + require.NoError(t, os.WriteFile(p, []byte("file"), 0600)) + var claims ioClaims + + got, err := readInput("path", p, &claims) + require.NoError(t, err) + assert.Equal(t, []byte("file"), got) + assert.Empty(t, claims.in) +} + +func Test_readInput_doubleStdinErrors(t *testing.T) { + withStdin(t, bytes.NewBufferString("hello")) + var claims ioClaims + + _, err := readInput("ca-key", "-", &claims) + require.NoError(t, err) + + _, err = readInput("ca-crt", "-", &claims) + require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) +} + +func Test_openInput_stdin(t *testing.T) { + withStdin(t, bytes.NewBufferString("hi")) + var claims ioClaims + + r, err := openInput("ca", "-", &claims) + require.NoError(t, err) + defer r.Close() + b, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, []byte("hi"), b) +} + +func Test_openInput_doubleStdinErrors(t *testing.T) { + withStdin(t, bytes.NewBufferString("hi")) + var claims ioClaims + + r, err := openInput("ca", "-", &claims) + require.NoError(t, err) + r.Close() + + _, err = openInput("crt", "-", &claims) + require.EqualError(t, err, `-ca and -crt both set to "-", only one input may read from stdin`) +} + +func Test_writeOutput_stdout(t *testing.T) { + out := &bytes.Buffer{} + + err := writeOutput("-", []byte("payload"), 0600, out) + require.NoError(t, err) + assert.Equal(t, "payload", out.String()) +} + +func Test_writeOutput_file(t *testing.T) { + dir := t.TempDir() + p := filepath.Join(dir, "f") + out := &bytes.Buffer{} + + err := writeOutput(p, []byte("payload"), 0600, out) + require.NoError(t, err) + assert.Empty(t, out.String()) + got, err := os.ReadFile(p) + require.NoError(t, err) + assert.Equal(t, []byte("payload"), got) +} + +func Test_reserveOutputs_noConflict(t *testing.T) { + var claims ioClaims + require.NoError(t, reserveOutputs(&claims, + "out-key", "/tmp/key", + "out-crt", "-", + "out-qr", "", + )) + assert.Equal(t, "out-crt", claims.out) +} + +func Test_reserveOutputs_conflict(t *testing.T) { + var claims ioClaims + err := reserveOutputs(&claims, + "out-key", "-", + "out-crt", "-", + ) + require.EqualError(t, err, `-out-key and -out-crt both set to "-", only one output may write to stdout`) +} + +func Test_reserveOutputs_panicsOnOddPairs(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + }() + var claims ioClaims + _ = reserveOutputs(&claims, "out-key") +} + +func Test_reserveInputs_noConflict(t *testing.T) { + var claims ioClaims + require.NoError(t, reserveInputs(&claims, + "ca-key", "/tmp/ca.key", + "ca-crt", "-", + "in-pub", "", + )) + assert.Equal(t, "ca-crt", claims.in) +} + +func Test_reserveInputs_conflict(t *testing.T) { + var claims ioClaims + err := reserveInputs(&claims, + "ca-key", "-", + "ca-crt", "-", + ) + require.EqualError(t, err, `-ca-key and -ca-crt both set to "-", only one input may read from stdin`) +} + +func Test_claimIn_idempotent(t *testing.T) { + // pre-claim then a lazy re-claim of the same flag should be a no-op + var claims ioClaims + require.NoError(t, claims.claimIn("ca-key")) + require.NoError(t, claims.claimIn("ca-key")) + assert.Equal(t, "ca-key", claims.in) +} + +func Test_claimOut_idempotent(t *testing.T) { + var claims ioClaims + require.NoError(t, claims.claimOut("out-crt")) + require.NoError(t, claims.claimOut("out-crt")) + assert.Equal(t, "out-crt", claims.out) +} + +func Test_isStdio(t *testing.T) { + assert.True(t, isStdio("-")) + assert.False(t, isStdio("")) + assert.False(t, isStdio("./-")) + assert.False(t, isStdio("foo")) +} diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 36258dd8..76d3dbe6 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -39,18 +39,26 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return err } - caFile, err := os.Open(*vf.caPath) + var claims ioClaims + if err := reserveInputs(&claims, + "ca", *vf.caPath, + "crt", *vf.certPath, + ); err != nil { + return err + } + + caReader, err := openInput("ca", *vf.caPath, &claims) if err != nil { return fmt.Errorf("error while reading ca: %w", err) } - defer caFile.Close() + defer caReader.Close() - caPool, err := cert.NewCAPoolFromPEMReader(caFile) + caPool, err := cert.NewCAPoolFromPEMReader(caReader) if err != nil && !errors.Is(err, cert.ErrExpired) { return fmt.Errorf("error while adding ca cert to pool: %w", err) } - rawCert, err := os.ReadFile(*vf.certPath) + rawCert, err := readInput("crt", *vf.certPath, &claims) if err != nil { return fmt.Errorf("unable to read crt: %w", err) } @@ -85,6 +93,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte(stdioHelpText)) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 1aa5e8e6..aa089d0e 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -23,6 +23,7 @@ func Test_verifyHelp(t *testing.T) { assert.Equal( t, "Usage of "+os.Args[0]+" verify : verifies a certificate isn't expired and was signed by a trusted authority.\n"+ + " Pass \"-\" to any path flag to read from stdin or write to stdout.\n"+ " -ca string\n"+ " \tRequired: path to a file containing one or more ca certificates\n"+ " -crt string\n"+ @@ -122,3 +123,46 @@ func Test_verify(t *testing.T) { assert.Empty(t, eb.String()) require.NoError(t, err) } + +func Test_verify_stdio(t *testing.T) { + ob := &bytes.Buffer{} + eb := &bytes.Buffer{} + + caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) + ca, _ := NewTestCaCert("test-ca", caPub, caPriv, time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour*2), nil, nil, nil) + caPEM, _ := ca.MarshalPEM() + + crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) + crtPEM, _ := crt.MarshalPEM() + + caFile, err := os.CreateTemp("", "verify-ca") + require.NoError(t, err) + defer os.Remove(caFile.Name()) + caFile.Write(caPEM) + + // crt on stdin, ca on disk + withStdin(t, bytes.NewReader(crtPEM)) + require.NoError(t, verify([]string{"-ca", caFile.Name(), "-crt", "-"}, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // ca on stdin, crt on disk + certFile, err := os.CreateTemp("", "verify-cert") + require.NoError(t, err) + defer os.Remove(certFile.Name()) + certFile.Write(crtPEM) + + withStdin(t, bytes.NewReader(caPEM)) + ob.Reset() + eb.Reset() + require.NoError(t, verify([]string{"-ca", "-", "-crt", certFile.Name()}, ob, eb)) + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + // both flags on stdin should error + withStdin(t, bytes.NewReader(caPEM)) + ob.Reset() + eb.Reset() + require.EqualError(t, verify([]string{"-ca", "-", "-crt", "-"}, ob, eb), + `-ca and -crt both set to "-", only one input may read from stdin`) +} From 99c5854e5c87525a1ebfca3be233e9e25ef2b573 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:26 -0500 Subject: [PATCH 22/31] Prime some critical stats before the first scrape (#1715) --- interface.go | 38 ++++++++++++++---------- interface_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 15 deletions(-) create mode 100644 interface_test.go diff --git a/interface.go b/interface.go index 5fedcdd3..32f5c2a6 100644 --- a/interface.go +++ b/interface.go @@ -491,26 +491,34 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { certInitiatingVersion := metrics.GetOrRegisterGauge("certificate.initiating_version", nil) certMaxVersion := metrics.GetOrRegisterGauge("certificate.max_version", nil) + emit := func() { + f.firewall.EmitStats() + f.handshakeManager.EmitStats() + udpStats() + + certState := f.pki.getCertState() + defaultCrt := certState.GetDefaultCertificate() + certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) + certInitiatingVersion.Update(int64(defaultCrt.Version())) + + // Report the max certificate version we are capable of using + if certState.v2Cert != nil { + certMaxVersion.Update(int64(certState.v2Cert.Version())) + } else { + certMaxVersion.Update(int64(certState.v1Cert.Version())) + } + } + + // Prime gauges so a Prometheus scrape that lands before the first tick + // sees real values instead of the zero defaults (issue #907). + emit() + for { select { case <-ctx.Done(): return case <-ticker.C: - f.firewall.EmitStats() - f.handshakeManager.EmitStats() - udpStats() - - certState := f.pki.getCertState() - defaultCrt := certState.GetDefaultCertificate() - certExpirationGauge.Update(int64(defaultCrt.NotAfter().Sub(time.Now()) / time.Second)) - certInitiatingVersion.Update(int64(defaultCrt.Version())) - - // Report the max certificate version we are capable of using - if certState.v2Cert != nil { - certMaxVersion.Update(int64(certState.v2Cert.Version())) - } else { - certMaxVersion.Update(int64(certState.v1Cert.Version())) - } + emit() } } } diff --git a/interface_test.go b/interface_test.go new file mode 100644 index 00000000..b0a9d025 --- /dev/null +++ b/interface_test.go @@ -0,0 +1,73 @@ +//go:build linux || darwin + +package nebula + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that +// landed before the first ticker fire used to read 0 for the cert gauges. +// emitStats now primes the gauges before entering the ticker loop. We assert +// the gauge is zero before the first call and non-zero after. +func Test_emitStats_primesGauges(t *testing.T) { + defer metrics.DefaultRegistry.UnregisterAll() + + l := test.NewLogger() + hostMap := newHostMap(l) + preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + hostMap.preferredRanges.Store(&preferredRanges) + + notAfter := time.Now().Add(time.Hour) + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, + v1Credential: nil, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &overlaytest.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to + // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, + // returns an error, and the emitter falls through to a no-op. + writers: []udp.Conn{&udp.StdConn{}}, + } + ifce.pki.cs.Store(cs) + + ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + + // Pre-cancel the context so emitStats returns after priming the gauges + // without ever reading from ticker.C. The one hour interval is just a + // belt-and-suspenders, the test does not expect the ticker to fire. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ifce.emitStats(ctx, time.Hour) + + ttl := ttlGauge.Value() + assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") + assert.LessOrEqual(t, ttl, int64(3600)) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) +} From 625f58b84adc778895b20a3dd74b2e2190c83132 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:36:44 -0500 Subject: [PATCH 23/31] Record my local details in the dns server if enabled (#1716) --- dns_server.go | 118 +++++++++++++++++++++++++++++++++++---------- dns_server_test.go | 89 ++++++++++++++++++++++++++++++++++ main.go | 2 +- 3 files changed, 182 insertions(+), 27 deletions(-) diff --git a/dns_server.go b/dns_server.go index ff1369ab..a80630b5 100644 --- a/dns_server.go +++ b/dns_server.go @@ -11,19 +11,21 @@ import ( "sync" "sync/atomic" - "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/slackhq/nebula/config" ) type dnsServer struct { sync.RWMutex - l *slog.Logger - ctx context.Context - dnsMap4 map[string]netip.Addr - dnsMap6 map[string]netip.Addr - hostMap *HostMap - myVpnAddrsTable *bart.Lite + l *slog.Logger + ctx context.Context + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + pki *PKI + + // selfHost is the cached FQDN we last seeded for ourselves + selfHost string mux *dns.ServeMux @@ -55,14 +57,14 @@ type dnsServer struct { // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // watcher that tears the listener down on nebula shutdown. The returned // pointer is always non-nil, even on error. -func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { +func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) { ds := &dnsServer{ - l: l, - ctx: ctx, - dnsMap4: make(map[string]netip.Addr), - dnsMap6: make(map[string]netip.Addr), - hostMap: hostMap, - myVpnAddrsTable: cs.myVpnAddrsTable, + l: l, + ctx: ctx, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + pki: pki, } ds.mux = dns.NewServeMux() ds.mux.HandleFunc(".", ds.handleDnsRequest) @@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, if err := ds.reload(c, true); err != nil { return ds, err } + ds.seedSelf() return ds, nil } @@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { d.Stop() } // Drop any records that accumulated while enabled; a later re-enable - // will repopulate from fresh handshakes. + // will repopulate from fresh handshakes and a fresh seedSelf. d.clearRecords() return nil } @@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { if running == nil { // Was disabled (or never started); bring it up now. go d.Start() - return nil + } else if !sameAddr { + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the new address. + go d.Start() } - if sameAddr { - return nil - } - - d.shutdownServer(running, runningStarted, "reload") - // Old Start goroutine has now exited; bring up a fresh listener on the - // new address. - go d.Start() + // Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up. + d.seedSelf() return nil } @@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string { return "" } + // The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves. + // Answer self lookups straight from the local cert state. + if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { + c := cs.GetDefaultCertificate() + if c == nil { + return "" + } + b, err := c.MarshalJSON() + if err != nil { + return "" + } + return string(b) + } + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" @@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string { return string(b) } -// clearRecords drops all DNS records. +// clearRecords drops all DNS records, including the self entry. func (d *dnsServer) clearRecords() { d.Lock() defer d.Unlock() clear(d.dnsMap4) clear(d.dnsMap6) + d.selfHost = "" +} + +// seedSelf inserts (or refreshes) a record for our own cert name pointing at our VPN addresses, +// so a single-lighthouse network can resolve the lighthouse's own hostname without the two-process workaround. +func (d *dnsServer) seedSelf() { + if !d.enabled.Load() { + return + } + cs := d.certState() + if cs == nil { + return + } + c := cs.GetDefaultCertificate() + if c == nil { + return + } + newHost := strings.ToLower(c.Name()) + "." + + d.Lock() + defer d.Unlock() + if d.selfHost != "" && d.selfHost != newHost { + delete(d.dnsMap4, d.selfHost) + delete(d.dnsMap6, d.selfHost) + } + d.selfHost = newHost + delete(d.dnsMap4, newHost) + delete(d.dnsMap6, newHost) + haveV4, haveV6 := false, false + for _, addr := range cs.myVpnAddrs { + if addr.Is4() && !haveV4 { + d.dnsMap4[newHost] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[newHost] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } +} + +func (d *dnsServer) certState() *CertState { + if d.pki == nil { + return nil + } + return d.pki.getCertState() } // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` @@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { return true } + cs := d.certState() + if cs == nil || cs.myVpnAddrsTable == nil { + return false + } //if we found it in this table, it's good - return d.myVpnAddrsTable.Contains(b) + return cs.myVpnAddrsTable.Contains(b) } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { diff --git a/dns_server_test.go b/dns_server_test.go index dcea046c..58646937 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -9,7 +9,10 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/miekg/dns" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { } } +// newTestPKI builds a minimal *PKI with a single v1 cert whose name and +// VPN addresses are caller-provided, suitable for exercising seedSelf and +// QueryCert self handling. +func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI { + t.Helper() + networks := make([]netip.Prefix, 0, len(addrs)) + for _, a := range addrs { + bits := 32 + if a.Is6() { + bits = 128 + } + networks = append(networks, netip.PrefixFrom(a, bits)) + } + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil) + + addrsTable := new(bart.Lite) + for _, a := range addrs { + addrsTable.Insert(netip.PrefixFrom(a, a.BitLen())) + } + + cs := &CertState{ + v2Cert: c, + initiatingVersion: cert.Version2, + myVpnAddrs: addrs, + myVpnAddrsTable: addrsTable, + } + pki := &PKI{} + pki.cs.Store(cs) + return pki +} + +func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) { + ds, c := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + myV6 := netip.MustParseAddr("fd00::1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4, myV6}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + got4, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV4, got4) + got6, exists := ds.Query(dns.TypeAAAA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV6, got6) +} + +func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, false) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + require.NoError(t, ds.reload(c, true)) + ds.seedSelf() + require.NotEmpty(t, ds.selfHost) + + ds.clearRecords() + assert.Empty(t, ds.selfHost) + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) { + ds, _ := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4}) + + got := ds.QueryCert(myV4.String() + ".") + assert.NotEmpty(t, got, "TXT lookup of our own VPN address should return our cert") + + other := netip.MustParseAddr("10.0.0.99") + assert.Empty(t, ds.QueryCert(other.String()+"."), "unknown peer IP should return nothing") +} + func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { port := freeUDPPort(t) ds, c := newTestDnsServer(t) diff --git a/main.go b/main.go index 37aa24d1..7d7a0f72 100644 --- a/main.go +++ b/main.go @@ -194,7 +194,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c) if err != nil { l.Warn("Failed to start DNS responder", "error", err) } From ffd5249cf522a1dd582c707888776f5f54264d32 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 15 May 2026 15:37:01 -0500 Subject: [PATCH 24/31] Search for config.yaml/yml in both service and cli mode (#1717) --- cmd/nebula-service/main.go | 9 +++-- cmd/nebula-service/service.go | 17 ++------- cmd/nebula/main.go | 9 +++-- config/default.go | 29 +++++++++++++++ config/default_test.go | 67 +++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 21 deletions(-) create mode 100644 config/default.go create mode 100644 config/default_test.go diff --git a/cmd/nebula-service/main.go b/cmd/nebula-service/main.go index 19fb3a9f..724c0c6a 100644 --- a/cmd/nebula-service/main.go +++ b/cmd/nebula-service/main.go @@ -61,9 +61,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } c := config.NewC(l) diff --git a/cmd/nebula-service/service.go b/cmd/nebula-service/service.go index 6551ceb4..7c2b39c8 100644 --- a/cmd/nebula-service/service.go +++ b/cmd/nebula-service/service.go @@ -3,8 +3,6 @@ package main import ( "fmt" "log" - "os" - "path/filepath" "github.com/kardianos/service" "github.com/slackhq/nebula" @@ -57,24 +55,13 @@ func (p *program) Stop(s service.Service) error { return nil } -func fileExists(filename string) bool { - _, err := os.Stat(filename) - if os.IsNotExist(err) { - return false - } - return true -} - func doService(configPath *string, configTest *bool, build string, serviceFlag *string) error { if *configPath == "" { - ex, err := os.Executable() + p, err := config.DefaultPath() if err != nil { return err } - *configPath = filepath.Dir(ex) + "/config.yaml" - if !fileExists(*configPath) { - *configPath = filepath.Dir(ex) + "/config.yml" - } + *configPath = p } svcConfig := &service.Config{ diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index d7f0de93..219519c2 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -50,9 +50,12 @@ func main() { } if *configPath == "" { - fmt.Println("-config flag must be set") - flag.Usage() - os.Exit(1) + p, err := config.DefaultPath() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + *configPath = p } l := logging.NewLogger(os.Stdout) diff --git a/config/default.go b/config/default.go new file mode 100644 index 00000000..9494c655 --- /dev/null +++ b/config/default.go @@ -0,0 +1,29 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" +) + +// DefaultPath returns a path to a config file alongside the running executable, preferring config.yaml over config.yml. +// If neither file exists an error is returned that names both paths checked. +func DefaultPath() (string, error) { + ex, err := os.Executable() + if err != nil { + return "", err + } + return defaultPathInDir(filepath.Dir(ex)) +} + +func defaultPathInDir(dir string) (string, error) { + yamlPath := filepath.Join(dir, "config.yaml") + if _, err := os.Stat(yamlPath); err == nil { + return yamlPath, nil + } + ymlPath := filepath.Join(dir, "config.yml") + if _, err := os.Stat(ymlPath); err == nil { + return ymlPath, nil + } + return "", fmt.Errorf("no default config found at %s or %s", yamlPath, ymlPath) +} diff --git a/config/default_test.go b/config/default_test.go new file mode 100644 index 00000000..a4d56f59 --- /dev/null +++ b/config/default_test.go @@ -0,0 +1,67 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultPathInDir(t *testing.T) { + t.Run("prefers config.yaml when both exist", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yaml") + other := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + require.NoError(t, os.WriteFile(other, []byte("a: 2"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("returns config.yaml when only it exists", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("falls back to config.yml when only it exists", func(t *testing.T) { + dir := t.TempDir() + want := filepath.Join(dir, "config.yml") + require.NoError(t, os.WriteFile(want, []byte("a: 1"), 0644)) + + got, err := defaultPathInDir(dir) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("errors when neither exists and names both paths", func(t *testing.T) { + dir := t.TempDir() + got, err := defaultPathInDir(dir) + assert.Empty(t, got) + require.Error(t, err) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yaml")) + assert.Contains(t, err.Error(), filepath.Join(dir, "config.yml")) + }) +} + +func TestDefaultPath(t *testing.T) { + got, err := DefaultPath() + if err != nil { + ex, exErr := os.Executable() + require.NoError(t, exErr) + assert.Contains(t, err.Error(), filepath.Dir(ex)) + return + } + ex, err := os.Executable() + require.NoError(t, err) + assert.Equal(t, filepath.Dir(ex), filepath.Dir(got)) + assert.Contains(t, []string{"config.yaml", "config.yml"}, filepath.Base(got)) +} From 0d23377c6575bd716448920269f8142a789097ca Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 18 May 2026 11:10:30 -0500 Subject: [PATCH 25/31] Fix flakey cert tests (#1728) --- cert/helper_test.go | 14 ++++++++++---- cert_test/cert.go | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/cert/helper_test.go b/cert/helper_test.go index 1b72a0ff..9becfa5c 100644 --- a/cert/helper_test.go +++ b/cert/helper_test.go @@ -13,6 +13,12 @@ import ( "golang.org/x/crypto/ed25519" ) +// testCertNow is the reference "now" used to derive default before/after times +// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the +// test binary keeps CA and leaf defaults aligned at the same second, so a leaf +// signed with default times can never expire after its CA on a rounding race. +var testCertNow = time.Now().Round(time.Second) + // NewTestCaCert will create a new ca certificate func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { var err error @@ -34,10 +40,10 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ } if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } t := &TBSCertificate{ @@ -70,11 +76,11 @@ func NewTestCaCert(version Version, curve Curve, before, after time.Time, networ // Expiry times are defaulted if you do not pass them in func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } if len(networks) == 0 { diff --git a/cert_test/cert.go b/cert_test/cert.go index c3759f12..4c440aff 100644 --- a/cert_test/cert.go +++ b/cert_test/cert.go @@ -14,6 +14,12 @@ import ( "golang.org/x/crypto/ed25519" ) +// testCertNow is the reference "now" used to derive default before/after times +// in NewTestCaCert and NewTestCert. Holding it fixed for the lifetime of the +// test binary keeps CA and leaf defaults aligned at the same second, so a leaf +// signed with default times can never expire after its CA on a rounding race. +var testCertNow = time.Now().Round(time.Second) + // NewTestCaCert will create a new ca certificate func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { var err error @@ -35,10 +41,10 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti } if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } t := &cert.TBSCertificate{ @@ -71,11 +77,11 @@ func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Ti // Expiry times are defaulted if you do not pass them in func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) + before = testCertNow.Add(time.Second * -60) } if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) + after = testCertNow.Add(time.Second * 60) } var pub, priv []byte From 04dea41f7495d09c9ee3d7c03b1bae00adb25ba4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 18 May 2026 11:25:34 -0500 Subject: [PATCH 26/31] Make firewall reload when unsafe networks in the cert changes (#1719) --- firewall.go | 30 ++++---- interface.go | 17 ++++- interface_emit_test.go | 73 ++++++++++++++++++++ interface_test.go | 151 +++++++++++++++++++++++++++-------------- 4 files changed, 201 insertions(+), 70 deletions(-) create mode 100644 interface_emit_test.go diff --git a/firewall.go b/firewall.go index adecbe81..904c71b2 100644 --- a/firewall.go +++ b/firewall.go @@ -58,8 +58,9 @@ type Firewall struct { routableNetworks *bart.Lite // assignedNetworks is a list of vpn networks assigned to us in the certificate. - assignedNetworks []netip.Prefix - hasUnsafeNetworks bool + assignedNetworks []netip.Prefix + // unsafeNetworks is the list of unsafe networks issued to us in the certificate + unsafeNetworks []netip.Prefix rules string rulesVersion uint16 @@ -158,10 +159,9 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur assignedNetworks = append(assignedNetworks, network) } - hasUnsafeNetworks := false - for _, n := range c.UnsafeNetworks() { + unsafeNetworks := c.UnsafeNetworks() + for _, n := range unsafeNetworks { routableNetworks.Insert(n) - hasUnsafeNetworks = true } return &Firewall{ @@ -169,15 +169,15 @@ func NewFirewall(l *slog.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Dur Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - routableNetworks: routableNetworks, - assignedNetworks: assignedNetworks, - hasUnsafeNetworks: hasUnsafeNetworks, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, + unsafeNetworks: unsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), @@ -897,7 +897,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localCidr string) error { } if localCidr == "" { - if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { + if len(f.unsafeNetworks) == 0 || f.defaultLocalCIDRAny { flc.Any = true return nil } diff --git a/interface.go b/interface.go index 32f5c2a6..f96e431a 100644 --- a/interface.go +++ b/interface.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/netip" + "slices" "sync" "sync/atomic" "time" @@ -14,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -375,13 +377,22 @@ func (f *Interface) reloadDisconnectInvalid(c *config.C) { } func (f *Interface) reloadFirewall(c *config.C) { - //TODO: need to trigger/detect if the certificate changed too - if c.HasChanged("firewall") == false { + cs := f.pki.getCertState() + curCert := cs.getCertificate(cert.Version2) + if curCert == nil { + curCert = cs.getCertificate(cert.Version1) + } + + // The firewall builds its routableNetworks set from the certificate's UnsafeNetworks at construction. + // Check to see if that set has changed, and if so, rebuild the firewall. + certUnsafeChanged := curCert != nil && !slices.Equal(curCert.UnsafeNetworks(), f.firewall.unsafeNetworks) + + if !c.HasChanged("firewall") && !certUnsafeChanged { f.l.Debug("No firewall config change detected") return } - fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) + fw, err := NewFirewallFromConfig(f.l, cs, c) if err != nil { f.l.Error("Error while creating firewall during reload", "error", err) return diff --git a/interface_emit_test.go b/interface_emit_test.go new file mode 100644 index 00000000..b0a9d025 --- /dev/null +++ b/interface_emit_test.go @@ -0,0 +1,73 @@ +//go:build linux || darwin + +package nebula + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/rcrowley/go-metrics" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/firewall" + "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/test" + "github.com/slackhq/nebula/udp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that +// landed before the first ticker fire used to read 0 for the cert gauges. +// emitStats now primes the gauges before entering the ticker loop. We assert +// the gauge is zero before the first call and non-zero after. +func Test_emitStats_primesGauges(t *testing.T) { + defer metrics.DefaultRegistry.UnregisterAll() + + l := test.NewLogger() + hostMap := newHostMap(l) + preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + hostMap.preferredRanges.Store(&preferredRanges) + + notAfter := time.Now().Add(time.Hour) + cs := &CertState{ + initiatingVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, + v1Credential: nil, + } + + lh := newTestLighthouse() + ifce := &Interface{ + hostMap: hostMap, + inside: &overlaytest.NoopTun{}, + outside: &udp.NoopConn{}, + firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, + lightHouse: lh, + pki: &PKI{}, + handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), + l: l, + // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to + // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, + // returns an error, and the emitter falls through to a no-op. + writers: []udp.Conn{&udp.StdConn{}}, + } + ifce.pki.cs.Store(cs) + + ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) + require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + + // Pre-cancel the context so emitStats returns after priming the gauges + // without ever reading from ticker.C. The one hour interval is just a + // belt-and-suspenders, the test does not expect the ticker to fire. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + ifce.emitStats(ctx, time.Hour) + + ttl := ttlGauge.Value() + assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") + assert.LessOrEqual(t, ttl, int64(3600)) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) + assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) +} diff --git a/interface_test.go b/interface_test.go index b0a9d025..1b912bbb 100644 --- a/interface_test.go +++ b/interface_test.go @@ -1,73 +1,120 @@ -//go:build linux || darwin - package nebula import ( - "context" "net/netip" "testing" - "time" - "github.com/rcrowley/go-metrics" "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/overlay/overlaytest" + "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" - "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Test_emitStats_primesGauges covers issue #907: a Prometheus scrape that -// landed before the first ticker fire used to read 0 for the cert gauges. -// emitStats now primes the gauges before entering the ticker loop. We assert -// the gauge is zero before the first call and non-zero after. -func Test_emitStats_primesGauges(t *testing.T) { - defer metrics.DefaultRegistry.UnregisterAll() - +// TestReloadFirewall_CertUnsafeNetworksChanged verifies that reloadFirewall +// rebuilds the firewall when only the certificate's UnsafeNetworks have changed, +// even if the firewall section of the YAML has not. +func TestReloadFirewall_CertUnsafeNetworksChanged(t *testing.T) { l := test.NewLogger() - hostMap := newHostMap(l) - preferredRanges := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} - hostMap.preferredRanges.Store(&preferredRanges) - notAfter := time.Now().Add(time.Hour) - cs := &CertState{ - initiatingVersion: cert.Version1, - privateKey: []byte{}, - v1Cert: &dummyCert{version: cert.Version1, notAfter: notAfter}, - v1Credential: nil, + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + initialUnsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + // dummyCert avoids dragging the real signing pipeline into a unit test. + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: initialUnsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + require.Equal(t, initialUnsafe, fw.unsafeNetworks) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, } - lh := newTestLighthouse() - ifce := &Interface{ - hostMap: hostMap, - inside: &overlaytest.NoopTun{}, - outside: &udp.NoopConn{}, - firewall: &Firewall{Conntrack: &FirewallConntrack{Conns: map[firewall.Packet]*conn{}}}, - lightHouse: lh, - pki: &PKI{}, - handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig), - l: l, - // On linux, udp.NewUDPStatsEmitter indexes writers[0] and asserts to - // *udp.StdConn. A zero value works: getMemInfo sees a nil rawConn, - // returns an error, and the emitter falls through to a no-op. - writers: []udp.Conn{&udp.StdConn{}}, + // Swap the cert with a different UnsafeNetworks set. + newUnsafe := []netip.Prefix{ + netip.MustParsePrefix("198.51.100.0/24"), + netip.MustParsePrefix("203.0.113.0/24"), } - ifce.pki.cs.Store(cs) + c2 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: newUnsafe, + } + pki.cs.Store(&CertState{v2Cert: c2, initiatingVersion: cert.Version2}) - ttlGauge := metrics.GetOrRegisterGauge("certificate.ttl_seconds", nil) - require.Zero(t, ttlGauge.Value(), "gauge should be zero before emitStats runs") + // Reload with the same YAML so HasChanged("firewall") reports false. + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + require.False(t, cfg.HasChanged("firewall")) - // Pre-cancel the context so emitStats returns after priming the gauges - // without ever reading from ticker.C. The one hour interval is just a - // belt-and-suspenders, the test does not expect the ticker to fire. - ctx, cancel := context.WithCancel(context.Background()) - cancel() - ifce.emitStats(ctx, time.Hour) + f.reloadFirewall(cfg) - ttl := ttlGauge.Value() - assert.Positive(t, ttl, "ttl gauge should be primed by emitStats before its first tick") - assert.LessOrEqual(t, ttl, int64(3600)) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.initiating_version", nil).Value()) - assert.Equal(t, int64(cert.Version1), metrics.GetOrRegisterGauge("certificate.max_version", nil).Value()) + assert.NotSame(t, fw, f.firewall, "firewall pointer should have been replaced") + assert.Equal(t, newUnsafe, f.firewall.unsafeNetworks) + assert.True(t, f.firewall.routableNetworks.Contains(netip.MustParseAddr("203.0.113.5"))) +} + +// TestReloadFirewall_NoChange verifies that reloadFirewall is a no-op when +// neither the firewall config nor the cert's UnsafeNetworks have changed. +func TestReloadFirewall_NoChange(t *testing.T) { + l := test.NewLogger() + + vpnNet := netip.MustParsePrefix("10.0.0.1/24") + unsafe := []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")} + + c1 := &dummyCert{ + version: cert.Version2, + networks: []netip.Prefix{vpnNet}, + unsafeNetworks: unsafe, + } + pki := &PKI{} + pki.cs.Store(&CertState{v2Cert: c1, initiatingVersion: cert.Version2}) + + rawYAML := `firewall: + outbound: + - port: any + proto: any + host: any + inbound: + - port: any + proto: any + host: any +` + cfg := config.NewC(l) + require.NoError(t, cfg.LoadString(rawYAML)) + + fw, err := NewFirewallFromConfig(l, pki.getCertState(), cfg) + require.NoError(t, err) + + f := &Interface{ + pki: pki, + firewall: fw, + l: l, + } + + require.NoError(t, cfg.ReloadConfigString(rawYAML)) + f.reloadFirewall(cfg) + + assert.Same(t, fw, f.firewall, "firewall should not have been replaced") } From 074a123a4bb51e6dba649f309c713eaab0af96c2 Mon Sep 17 00:00:00 2001 From: randomizedcoder <64496590+randomizedcoder@users.noreply.github.com> Date: Mon, 18 May 2026 10:23:10 -0700 Subject: [PATCH 27/31] Reject port numbers outside [0, 65535] in firewall rule parsing (#1724) --- firewall.go | 39 +++++++++++++++++++-------- firewall_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 11 deletions(-) diff --git a/firewall.go b/firewall.go index 904c71b2..eb120fa6 100644 --- a/firewall.go +++ b/firewall.go @@ -1055,7 +1055,6 @@ func (r *rule) sanity() error { } func parsePort(s string) (int32, int32, error) { - var err error const notAPort int32 = -2 if s == "any" { return firewall.PortAny, firewall.PortAny, nil @@ -1064,11 +1063,11 @@ func parsePort(s string) (int32, int32, error) { return firewall.PortFragment, firewall.PortFragment, nil } if !strings.Contains(s, `-`) { - rPort, err := strconv.Atoi(s) + rPort, err := parsePortValue("", s) if err != nil { - return notAPort, notAPort, fmt.Errorf("was not a number; `%s`", s) + return notAPort, notAPort, err } - return int32(rPort), int32(rPort), nil + return rPort, rPort, nil } sPorts := strings.SplitN(s, `-`, 2) @@ -1079,22 +1078,40 @@ func parsePort(s string) (int32, int32, error) { return notAPort, notAPort, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s) } - rStartPort, err := strconv.Atoi(sPorts[0]) + startPort, err := parsePortValue("beginning range ", sPorts[0]) if err != nil { - return notAPort, notAPort, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0]) + return notAPort, notAPort, err } - rEndPort, err := strconv.Atoi(sPorts[1]) + endPort, err := parsePortValue("ending range ", sPorts[1]) if err != nil { - return notAPort, notAPort, fmt.Errorf("ending range was not a number; `%s`", sPorts[1]) + return notAPort, notAPort, err } - startPort := int32(rStartPort) - endPort := int32(rEndPort) - if startPort == firewall.PortAny { endPort = firewall.PortAny } return startPort, endPort, nil } + +// parsePortValue accepts a base-10 decimal in [0, 65535] and returns it +// widened to int32. Using strconv.ParseUint with bitSize 16 rejects +// negative input, out-of-range input (>65535), and any non-decimal byte +// by construction, so the int32 widening that follows is provably safe +// and cannot collide with firewall.PortAny (0) or firewall.PortFragment +// (-1) via integer truncation. +// +// prefix is prepended to both error messages so callers can disambiguate +// the single-port path (prefix="") from the range bounds (prefix="beginning +// range " / "ending range "), preserving the historical error strings. +func parsePortValue(prefix, s string) (int32, error) { + n, err := strconv.ParseUint(s, 10, 16) + if err == nil { + return int32(n), nil + } + if errors.Is(err, strconv.ErrRange) { + return 0, fmt.Errorf("%sout of range [0,65535]; `%s`", prefix, s) + } + return 0, fmt.Errorf("%swas not a number; `%s`", prefix, s) +} diff --git a/firewall_test.go b/firewall_test.go index 40b57477..9373f1fd 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -1029,6 +1029,75 @@ func Test_parsePort(t *testing.T) { require.NoError(t, err) } +// Test_parsePort_invalid covers inputs that must error. The named bug is +// that int32(strconv.Atoi("4294967296")) truncates to 0 == firewall.PortAny, +// silently turning a typo into a match-all-ports rule; the rest are +// representative syntax/range probes. +func Test_parsePort_invalid(t *testing.T) { + tests := []struct { + name string + input string + wantErrContains string + }{ + // Numeric overflow (the named bug + boundary). + {"named bug: 2^32 truncates to PortAny", "4294967296", "out of range"}, + {"just above max real port", "65536", "out of range"}, + + // Negatives route through the range branch and hit the empty-half + // guard; included as defense in depth so a future refactor cannot + // accidentally reach the int32 cast. + {"negative", "-1", "could not be parsed"}, + + // Syntax probes. + {"NUL between digits", "4\x002", "was not a number"}, + {"hex notation", "0x10", "was not a number"}, + {"scientific notation", "1e3", "was not a number"}, + {"leading whitespace", " 42", "was not a number"}, + {"fullwidth digits", "42", "was not a number"}, + + // Range branch. + {"range upper out of range", "1-65536", "ending range out of range"}, + {"range lower out of range", "65536-65537", "beginning range out of range"}, + {"range with negative upper", "1--1", "ending range was not a number"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, err := parsePort(tc.input) + require.Error(t, err, "input %q must error", tc.input) + require.ErrorContains(t, err, tc.wantErrContains) + }) + } +} + +// Test_parsePort_valid_boundaries locks in success cases at 0, 1, and 65535 +// so a future refactor cannot regress the boundaries. +func Test_parsePort_valid_boundaries(t *testing.T) { + tests := []struct { + name string + input string + wantStart int32 + wantEnd int32 + }{ + {"zero is PortAny", "0", 0, 0}, + {"min real port", "1", 1, 1}, + {"max real port", "65535", 65535, 65535}, + {"range zero to max forces end to zero", "0-65535", 0, 0}, + {"range max to max", "65535-65535", 65535, 65535}, + {"range one to max", "1-65535", 1, 65535}, + {"range with whitespace inside", " 1 - 2 ", 1, 2}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, e, err := parsePort(tc.input) + require.NoError(t, err) + assert.Equal(t, tc.wantStart, s, "start port") + assert.Equal(t, tc.wantEnd, e, "end port") + }) + } +} + func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition From 0c1ad9bb48e8e1c289d92299b75ee3e7ebeb5805 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 19 May 2026 08:35:04 -0500 Subject: [PATCH 28/31] Parallelize the tests a bit more (#1730) --- .github/workflows/gofmt.yml | 34 -------- .github/workflows/test.yml | 170 ++++++++++++++++++++---------------- Makefile | 43 ++++++++- 3 files changed, 136 insertions(+), 111 deletions(-) delete mode 100644 .github/workflows/gofmt.yml diff --git a/.github/workflows/gofmt.yml b/.github/workflows/gofmt.yml deleted file mode 100644 index 4d57c7b2..00000000 --- a/.github/workflows/gofmt.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: gofmt -on: - push: - branches: - - master - pull_request: - paths: - - '.github/workflows/gofmt.yml' - - '**.go' -jobs: - - gofmt: - name: Run gofmt - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Install goimports - run: | - go install golang.org/x/tools/cmd/goimports@latest - - - name: gofmt - run: | - if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] - then - find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d - exit 1 - fi diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 009c22a9..2abb3740 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,8 +13,8 @@ on: - 'go.sum' jobs: - test-linux: - name: Build all and test on ubuntu-linux + static: + name: Static checks runs-on: ubuntu-latest steps: @@ -25,8 +25,16 @@ jobs: go-version: '1.25' check-latest: true - - name: Build - run: make all + - name: Install goimports + run: go install golang.org/x/tools/cmd/goimports@latest + + - name: gofmt + run: | + if [ "$(find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -l)" ] + then + find . -iname '*.go' | grep -v '\.pb\.go$' | xargs goimports -d + exit 1 + fi - name: Vet run: make vet @@ -36,66 +44,38 @@ jobs: with: version: v2.5 - - name: Test - run: make test - - - name: End 2 end - run: make e2evv - - - name: Build test mobile - run: make build-test-mobile - - - uses: actions/upload-artifact@v7 - with: - name: e2e packet flow linux-latest - path: e2e/mermaid/linux-latest - if-no-files-found: warn - - test-linux-boringcrypto: - name: Build and test on linux with boringcrypto - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Build - run: make bin-boringcrypto - - - name: Test - run: make test-boringcrypto - - - name: End 2 end - run: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" - - test-linux-pkcs11: - name: Build and test on linux with pkcs11 - runs-on: ubuntu-latest - steps: - - - uses: actions/checkout@v6 - - - uses: actions/setup-go@v6 - with: - go-version: '1.25' - check-latest: true - - - name: Build - run: make bin-pkcs11 - - - name: Test - run: make test-pkcs11 - test: - name: Build and test on ${{ matrix.os }} + name: Test ${{ matrix.name }} runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: - os: [windows-latest, macos-latest] + include: + - name: linux + os: ubuntu-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv + - name: linux-boringcrypto + os: ubuntu-latest + build-cmd: make bin-boringcrypto + test-cmd: make test-boringcrypto + e2e-cmd: make e2e GOEXPERIMENT=boringcrypto CGO_ENABLED=1 TEST_ENV="TEST_LOGS=1" TEST_FLAGS="-v -ldflags -checklinkname=0" + - name: linux-pkcs11 + os: ubuntu-latest + build-cmd: make bin-pkcs11 + test-cmd: make test-pkcs11 + e2e-cmd: '' + - name: macos + os: macos-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv + - name: windows + os: windows-latest + build-cmd: go build ./cmd/nebula ./cmd/nebula-cert + test-cmd: make test + e2e-cmd: make e2evv steps: - uses: actions/checkout@v6 @@ -105,28 +85,66 @@ jobs: go-version: '1.25' check-latest: true - - name: Build nebula - run: go build ./cmd/nebula + - name: Build + run: ${{ matrix.build-cmd }} - - name: Build nebula-cert - run: go build ./cmd/nebula-cert - - - name: Vet - run: make vet - - - name: golangci-lint - uses: golangci/golangci-lint-action@v9 - with: - version: v2.5 + - name: Cross-build darwin-amd64 + if: matrix.name == 'macos' + run: GOARCH=amd64 go build -o /tmp/nebula-amd64 ./cmd/nebula && GOARCH=amd64 go build -o /tmp/nebula-cert-amd64 ./cmd/nebula-cert - name: Test - run: make test + run: ${{ matrix.test-cmd }} - name: End 2 end - run: make e2evv + if: matrix.e2e-cmd != '' + run: ${{ matrix.e2e-cmd }} - uses: actions/upload-artifact@v7 + if: matrix.e2e-cmd != '' && always() with: - name: e2e packet flow ${{ matrix.os }} - path: e2e/mermaid/${{ matrix.os }} + name: e2e packet flow ${{ matrix.name }} + path: e2e/mermaid/ if-no-files-found: warn + + cross-build: + name: Cross-build ${{ matrix.name }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - {name: linux-arm, make-target: all-cross-linux-arm} + - {name: linux-mips, make-target: all-cross-linux-mips} + - {name: linux-other, make-target: all-cross-linux-other} + - {name: freebsd, make-target: all-freebsd} + - {name: openbsd, make-target: all-openbsd} + - {name: netbsd, make-target: all-netbsd} + - {name: windows, make-target: all-cross-windows} + - {name: mobile, make-target: build-test-mobile} + steps: + + - uses: actions/checkout@v6 + + - uses: actions/setup-go@v6 + with: + go-version: '1.25' + check-latest: true + + - name: Build ${{ matrix.name }} + run: make -j"$(nproc)" ${{ matrix.make-target }} + + finish: + name: CI status + if: always() + needs: [static, test, cross-build] + runs-on: ubuntu-latest + steps: + + - name: Fail if any upstream job failed + if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: | + echo "upstream results: ${{ toJSON(needs) }}" + exit 1 + + - name: All upstream jobs passed + run: echo "ok" diff --git a/Makefile b/Makefile index 0b199a5a..892c8eb0 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,18 @@ ALL = $(ALL_LINUX) \ windows-amd64 \ windows-arm64 +# Cross-build shards used by .github/workflows/test.yml — same as ALL_* +# but with the arch that has a native CI runner removed, so the cross-build +# job is not duplicating coverage the native test jobs already give. +ALL_CROSS_LINUX = $(filter-out linux-amd64,$(ALL_LINUX)) + +# ALL_CROSS_LINUX further split into family sub-shards so each can run on +# its own CI runner in parallel. Union of the three must equal +# ALL_CROSS_LINUX; adding a new linux arch goes into the matching family. +ALL_CROSS_LINUX_ARM = linux-arm-5 linux-arm-6 linux-arm-7 linux-arm64 +ALL_CROSS_LINUX_MIPS = linux-mips linux-mipsle linux-mips64 linux-mips64le linux-mips-softfloat +ALL_CROSS_LINUX_OTHER = linux-386 linux-ppc64le linux-riscv64 linux-loong64 + e2e: $(TEST_ENV) go test -tags=e2e_testing -count=1 $(TEST_FLAGS) ./e2e @@ -82,6 +94,35 @@ DOCKER_BIN = build/linux-amd64/nebula build/linux-amd64/nebula-cert all: $(ALL:%=build/%/nebula) $(ALL:%=build/%/nebula-cert) +all-linux: $(ALL_LINUX:%=build/%/nebula) $(ALL_LINUX:%=build/%/nebula-cert) + +all-freebsd: $(ALL_FREEBSD:%=build/%/nebula) $(ALL_FREEBSD:%=build/%/nebula-cert) + +all-openbsd: $(ALL_OPENBSD:%=build/%/nebula) $(ALL_OPENBSD:%=build/%/nebula-cert) + +all-netbsd: $(ALL_NETBSD:%=build/%/nebula) $(ALL_NETBSD:%=build/%/nebula-cert) + +all-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert build/darwin-arm64/nebula build/darwin-arm64/nebula-cert + +all-windows: build/windows-amd64/nebula.exe build/windows-amd64/nebula-cert.exe build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe + +# CI cross-build shards. darwin-arm64 is covered by the native macos-latest +# job; windows-amd64 is covered by the native windows-latest job; both are +# omitted here to avoid building them a second time. darwin-amd64 stays in +# all-cross-darwin because intel mac is only a labeled/master-time native +# job, so PRs still need cross-build coverage for it. +all-cross-linux: $(ALL_CROSS_LINUX:%=build/%/nebula) $(ALL_CROSS_LINUX:%=build/%/nebula-cert) + +all-cross-linux-arm: $(ALL_CROSS_LINUX_ARM:%=build/%/nebula) $(ALL_CROSS_LINUX_ARM:%=build/%/nebula-cert) + +all-cross-linux-mips: $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula) $(ALL_CROSS_LINUX_MIPS:%=build/%/nebula-cert) + +all-cross-linux-other: $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula) $(ALL_CROSS_LINUX_OTHER:%=build/%/nebula-cert) + +all-cross-darwin: build/darwin-amd64/nebula build/darwin-amd64/nebula-cert + +all-cross-windows: build/windows-arm64/nebula.exe build/windows-arm64/nebula-cert.exe + docker: docker/linux-$(shell go env GOARCH) release: $(ALL:%=build/nebula-%.tar.gz) @@ -236,5 +277,5 @@ smoke-vagrant/%: bin-docker build/%/nebula cd .github/workflows/smoke/ && ./smoke-vagrant.sh $* .FORCE: -.PHONY: bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% +.PHONY: all all-linux all-freebsd all-openbsd all-netbsd all-darwin all-windows all-cross-linux all-cross-linux-arm all-cross-linux-mips all-cross-linux-other all-cross-darwin all-cross-windows bench bench-cpu bench-cpu-long bin build-test-mobile e2e e2ev e2evv e2evvv e2evvvv proto release service smoke-docker smoke-docker-race test test-cov-html smoke-vagrant/% .DEFAULT_GOAL := bin From 72bad1603a92373e1ae8da7b8fd95feb1efc9561 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 08:53:50 -0500 Subject: [PATCH 29/31] Bump github.com/gaissmai/bart from 0.26.1 to 0.27.1 (#1732) Bumps [github.com/gaissmai/bart](https://github.com/gaissmai/bart) from 0.26.1 to 0.27.1. - [Release notes](https://github.com/gaissmai/bart/releases) - [Commits](https://github.com/gaissmai/bart/compare/v0.26.1...v0.27.1) --- updated-dependencies: - dependency-name: github.com/gaissmai/bart dependency-version: 0.27.1 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ee51151f..bd1c0c57 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/armon/go-radix v1.0.0 github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/flynn/noise v1.1.0 - github.com/gaissmai/bart v0.26.1 + github.com/gaissmai/bart v0.27.1 github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 github.com/kardianos/service v1.2.4 diff --git a/go.sum b/go.sum index 5640bd46..8ab36d34 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= -github.com/gaissmai/bart v0.26.1 h1:+w4rnLGNlA2GDVn382Tfe3jOsK5vOr5n4KmigJ9lbTo= -github.com/gaissmai/bart v0.26.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= +github.com/gaissmai/bart v0.27.1 h1:FysPzqETMJa8q9rNkLW5peT1hq25nLOz8ksHbSVoiAk= +github.com/gaissmai/bart v0.27.1/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= From 873f94f4655098e3df133ba8b9eb2633bb594fc9 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 22 May 2026 10:19:06 -0500 Subject: [PATCH 30/31] Reduce relay log spam (#1733) --- handshake_manager.go | 3 +- relay_manager.go | 55 ++++++++++++++++-------- relay_manager_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 19 deletions(-) create mode 100644 relay_manager_test.go diff --git a/handshake_manager.go b/handshake_manager.go index 87257028..d03814da 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -83,6 +83,7 @@ type HandshakeHostInfo struct { initiatingVersionOverride cert.Version // Should we use a non-default cert version for this handshake? counter int64 // How many attempts have we made so far lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt + lastRelays []netip.Addr // Relays we attempted to use during the previous attempt packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes hostinfo *HostInfo @@ -323,7 +324,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered ) } - hm.f.relayManager.StartRelays(hm.f, vpnIp, hostinfo, stage0) + hm.f.relayManager.StartRelays(hm.f, vpnIp, hh, stage0) // If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add if !lighthouseTriggered { diff --git a/relay_manager.go b/relay_manager.go index 25e65871..1fd98963 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "net/netip" + "slices" "sync/atomic" "github.com/slackhq/nebula/cert" @@ -57,14 +58,25 @@ func (rm *relayManager) GetUseRelays() bool { // For each candidate relay it either kicks off a handshake to the relay, sends a CreateRelayRequest, retransmits // one that may have been lost, or, once the relay is Established, forwards the in-progress // stage 0 handshake packet for vpnIp through it. -func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *HostInfo, stage0 []byte) { +func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hh *HandshakeHostInfo, stage0 []byte) { + hostinfo := hh.hostinfo if !rm.GetUseRelays() || len(hostinfo.remotes.relays) == 0 { + hh.lastRelays = nil return } - hostinfo.logger(rm.l).Info("Attempt to relay through hosts", "relays", hostinfo.remotes.relays) + relays := hostinfo.remotes.relays + listLevel := slog.LevelDebug + prior := hh.lastRelays + if !slices.Equal(relays, prior) { + listLevel = slog.LevelInfo + hh.lastRelays = slices.Clone(relays) + } + hl := hostinfo.logger(rm.l) + hl.Log(context.Background(), listLevel, "Attempt to relay through hosts", "relays", relays) + // Send a RelayRequest to all known Relay IP's - for _, relay := range hostinfo.remotes.relays { + for _, relay := range relays { // Don't relay through the host I'm trying to connect to if relay == vpnIp { continue @@ -75,12 +87,19 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho continue } + // Each relay's per-attempt log fires at Info on the first time we hit it and Debug after that. + level := slog.LevelInfo + if slices.Contains(prior, relay) { + level = slog.LevelDebug + } + relayHostInfo := rm.hostmap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { - hostinfo.logger(rm.l).Info("Establish tunnel to relay target", "relay", relay.String()) + hl.Log(context.Background(), level, "Establish tunnel to relay target", "relay", relay.String()) f.Handshake(relay) continue } + // Check the relay HostInfo to see if we already established a relay through existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) if !ok { @@ -88,7 +107,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho if relayHostInfo.remote.IsValid() { idx, err := AddRelay(rm.l, relayHostInfo, rm.hostmap, vpnIp, nil, TerminalType, Requested) if err != nil { - hostinfo.logger(rm.l).Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) + hl.Info("Failed to add relay to hostmap", "relay", relay.String(), "error", err) } m := NebulaControl{ @@ -99,12 +118,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !f.myVpnAddrs[0].Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -116,16 +135,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: - hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + hl.Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { - hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + hl.Error("Failed to marshal Control message to create relay", "error", err) } else { f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.Info("send CreateRelayRequest", + rm.l.Log(context.Background(), level, "send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, "initiatorRelayIndex", idx, @@ -138,14 +157,14 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch existingRelay.State { case Established: - hostinfo.logger(rm.l).Info("Send handshake via relay", "relay", relay.String()) + hl.Log(context.Background(), level, "Send handshake via relay", "relay", relay.String()) f.SendVia(relayHostInfo, existingRelay, stage0, make([]byte, 12), make([]byte, mtu), false) case Disestablished: // Mark this relay as 'requested' relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) fallthrough case Requested: - hostinfo.logger(rm.l).Info("Re-send CreateRelay request", "relay", relay.String()) + hl.Log(context.Background(), level, "Re-send CreateRelay request", "relay", relay.String()) // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, @@ -155,12 +174,12 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho switch relayHostInfo.GetCert().Certificate.Version() { case cert.Version1: if !f.myVpnAddrs[0].Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") continue } if !vpnIp.Is4() { - hostinfo.logger(rm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + hl.Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") continue } @@ -172,16 +191,16 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho m.RelayFromAddr = netAddrToProtoAddr(f.myVpnAddrs[0]) m.RelayToAddr = netAddrToProtoAddr(vpnIp) default: - hostinfo.logger(rm.l).Error("Unknown certificate version found while creating relay") + hl.Error("Unknown certificate version found while creating relay") continue } msg, err := m.Marshal() if err != nil { - hostinfo.logger(rm.l).Error("Failed to marshal Control message to create relay", "error", err) + hl.Error("Failed to marshal Control message to create relay", "error", err) } else { // This must send over the hostinfo, not over hm.Hosts[ip] f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.Info("send CreateRelayRequest", + rm.l.Log(context.Background(), level, "send CreateRelayRequest", "relayFrom", f.myVpnAddrs[0], "relayTo", vpnIp, "initiatorRelayIndex", existingRelay.LocalIndex, @@ -192,7 +211,7 @@ func (rm *relayManager) StartRelays(f *Interface, vpnIp netip.Addr, hostinfo *Ho // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. fallthrough default: - hostinfo.logger(rm.l).Error("Relay unexpected state", + hl.Error("Relay unexpected state", "vpnIp", vpnIp, "state", existingRelay.State, "relay", relay, diff --git a/relay_manager_test.go b/relay_manager_test.go new file mode 100644 index 00000000..8da38940 --- /dev/null +++ b/relay_manager_test.go @@ -0,0 +1,97 @@ +package nebula + +import ( + "bytes" + "log/slog" + "net/netip" + "testing" + + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" +) + +// TestStartRelaysLogDedupe verifies that repeated attempts with the same relay set drop the log +// chatter to Debug, mirroring how the normal handshake retry loop quiets down once it's already +// announced its targets. +func TestStartRelaysLogDedupe(t *testing.T) { + vpnIp := netip.MustParseAddr("100.64.99.4") + otherRelay := netip.MustParseAddr("100.64.99.5") + + newHH := func() *HandshakeHostInfo { + // Use the target's own vpnIp as the "relay" so the loop body skips it without + // touching any sender-side state. That isolates the test to the level-selection + // behavior of the top-level "Attempt to relay through hosts" log. + hostinfo := &HostInfo{ + vpnAddrs: []netip.Addr{vpnIp}, + localIndexId: 1, + remotes: NewRemoteList([]netip.Addr{vpnIp}, nil), + } + hostinfo.remotes.relays = []netip.Addr{vpnIp} + return &HandshakeHostInfo{hostinfo: hostinfo} + } + + // Park any extra relay addresses we'll introduce mid-test in myVpnAddrsTable so the loop + // body always skips before touching f.Handshake (which would need a real handshakeManager). + addrTable := new(bart.Lite) + addrTable.Insert(netip.PrefixFrom(otherRelay, otherRelay.BitLen())) + f := &Interface{myVpnAddrsTable: addrTable} + + newRM := func(buf *bytes.Buffer) *relayManager { + l := test.NewLoggerWithOutputAndLevel(buf, slog.LevelDebug) + rm := &relayManager{l: l, hostmap: newHostMap(l)} + rm.useRelays.Store(true) + return rm + } + + const msg = `msg="Attempt to relay through hosts"` + + t.Run("first attempt logs at Info", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, []netip.Addr{vpnIp}, hh.lastRelays, "lastRelays should record the relay set we just attempted") + assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info level on first attempt") + }) + + t.Run("repeat attempt with same relays drops to Debug", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + first := append([]netip.Addr(nil), hh.lastRelays...) + buf.Reset() + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, first, hh.lastRelays) + assert.Contains(t, buf.String(), "level=DEBUG "+msg, "expected Debug level on identical retry") + assert.NotContains(t, buf.String(), "level=INFO "+msg, "Info should not fire on identical retry") + }) + + t.Run("changed relay list bumps back to Info", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + hh := newHH() + rm.StartRelays(f, vpnIp, hh, nil) + buf.Reset() + + // The lighthouse handed us a new set this round. + hh.hostinfo.remotes.relays = []netip.Addr{vpnIp, otherRelay} + + rm.StartRelays(f, vpnIp, hh, nil) + assert.Equal(t, []netip.Addr{vpnIp, otherRelay}, hh.lastRelays) + assert.Contains(t, buf.String(), "level=INFO "+msg, "expected Info when the relay list changes") + }) + + t.Run("disabled relays clears lastRelays and emits no Attempt log", func(t *testing.T) { + var buf bytes.Buffer + rm := newRM(&buf) + rm.useRelays.Store(false) + hh := newHH() + hh.lastRelays = []netip.Addr{vpnIp} + + rm.StartRelays(f, vpnIp, hh, nil) + assert.Nil(t, hh.lastRelays, "with relays disabled lastRelays should be cleared") + assert.NotContains(t, buf.String(), msg, "should not log when we shortcut out") + }) +} From 3a95495c6355dffeb83607eeedcf5a96eb5d484f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 22 May 2026 10:19:53 -0500 Subject: [PATCH 31/31] Fix duplicate log fields which slog duplicates (#1734) --- handshake_manager.go | 3 --- inside.go | 1 - outside.go | 5 ++--- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/handshake_manager.go b/handshake_manager.go index d03814da..e04886b5 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -218,7 +218,6 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered fields := []any{ "udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()), "initiatorIndex", hh.hostinfo.localIndexId, - "remoteIndex", hh.hostinfo.remoteIndexId, "durationNs", time.Since(hh.startTime).Nanoseconds(), } // hh.machine can be nil here if buildStage0Packet never succeeded @@ -466,7 +465,6 @@ func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", - "remoteIndex", hostinfo.remoteIndexId, "collision", existingRemoteIndex.vpnAddrs, ) } @@ -489,7 +487,6 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l).Info("New host shadows existing host remoteIndex", - "remoteIndex", hostinfo.remoteIndexId, "collision", existingRemoteIndex.vpnAddrs, ) } diff --git a/inside.go b/inside.go index 68cb38ec..27a6f758 100644 --- a/inside.go +++ b/inside.go @@ -391,7 +391,6 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType "error", err, "udpAddr", remote, "counter", c, - "attemptedCounter", c, ) return } diff --git a/outside.go b/outside.go index 17013ed3..4c0c935e 100644 --- a/outside.go +++ b/outside.go @@ -194,8 +194,7 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. hostinfo.logger(f.l).Error("HostInfo missing remote relay index", - "vpnAddrs", hostinfo.vpnAddrs, - "remoteIndex", h.RemoteIndex, + "relayRemoteIndex", h.RemoteIndex, ) return } @@ -218,8 +217,8 @@ func (f *Interface) handleOutsideRelayPacket(hostinfo *HostInfo, via ViaSender, if err != nil { hostinfo.logger(f.l).Info("Failed to find target host info by ip", "relayTo", relay.PeerAddr, + "relayFrom", hostinfo.vpnAddrs[0], "error", err, - "hostinfo.vpnAddrs", hostinfo.vpnAddrs, ) return }