From 2f4532f1026f78028c380aeb937d0f5baf41eab4 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 21 Apr 2026 12:41:10 -0500 Subject: [PATCH] No more dns globals, proper cleanup on shutdown (#1667) --- dns_server.go | 255 ++++++++++++++++++++++++++++++++++++--------- dns_server_test.go | 220 +++++++++++++++++++++++++++++++++++++- hostmap.go | 4 +- interface.go | 6 +- main.go | 21 +--- 5 files changed, 432 insertions(+), 74 deletions(-) diff --git a/dns_server.go b/dns_server.go index 73576546..75c56f0f 100644 --- a/dns_server.go +++ b/dns_server.go @@ -1,12 +1,14 @@ package nebula import ( + "context" "fmt" "net" "net/netip" "strconv" "strings" "sync" + "sync/atomic" "github.com/gaissmai/bart" "github.com/miekg/dns" @@ -14,32 +16,207 @@ import ( "github.com/slackhq/nebula/config" ) -// This whole thing should be rewritten to use context - -var dnsR *dnsRecords -var dnsServer *dns.Server -var dnsAddr string - -type dnsRecords struct { +type dnsServer struct { sync.RWMutex l *logrus.Logger + ctx context.Context dnsMap4 map[string]netip.Addr dnsMap6 map[string]netip.Addr hostMap *HostMap myVpnAddrsTable *bart.Lite + + mux *dns.ServeMux + + // enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`. + // Start, Add, and reload consult it so callers don't need to know the + // gating rules. When it toggles off via reload, accumulated records are + // cleared so a later re-enable starts with a fresh map populated from + // new handshakes. + enabled atomic.Bool + + serverMu sync.Mutex + server *dns.Server + // started is closed once `server` has finished binding (or after + // ListenAndServe returns on a bind failure). Stop waits on it before + // calling Shutdown to avoid the miekg/dns "server not started" race + // where a Shutdown that arrives before bind completes is silently + // ignored, leaving the listener running forever. + started chan struct{} + addr string } -func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { - return &dnsRecords{ +// newDnsServerFromConfig builds a dnsServer, applies the initial config, and +// registers a reload callback. The reload callback is registered before the +// initial config is applied, so a SIGHUP can later enable, fix, or disable +// DNS even if the initial application failed. +// +// The dnsServer internally gates on `lighthouse.serve_dns && +// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally, +// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel +// watcher that tears the listener down on nebula shutdown. The returned +// pointer is always non-nil, even on error. +func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { + ds := &dnsServer{ l: l, + ctx: ctx, dnsMap4: make(map[string]netip.Addr), dnsMap6: make(map[string]netip.Addr), hostMap: hostMap, myVpnAddrsTable: cs.myVpnAddrsTable, } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + + c.RegisterReloadCallback(func(c *config.C) { + if err := ds.reload(c, false); err != nil { + l.WithError(err).Error("Failed to reload DNS responder from config") + } + }) + + if err := ds.reload(c, true); err != nil { + return ds, err + } + return ds, nil } -func (d *dnsRecords) Query(q uint16, data string) netip.Addr { +// reload applies the latest config and reconciles the running state with it: +// - enabled toggled on -> spawn a runner +// - enabled toggled off -> stop the runner +// - listen address changed (while running) -> restart on the new address +// - everything else -> no-op +// +// On the initial call it only records configuration; Control.Start is what +// launches the first runner via dnsStart. +func (d *dnsServer) reload(c *config.C, initial bool) error { + wantsDns := c.GetBool("lighthouse.serve_dns", false) + amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) + enabled := wantsDns && amLighthouse + newAddr := getDnsServerAddr(c) + + d.serverMu.Lock() + running := d.server + runningStarted := d.started + sameAddr := d.addr == newAddr + d.addr = newAddr + d.enabled.Store(enabled) + d.serverMu.Unlock() + + if initial { + if wantsDns && !amLighthouse { + d.l.Warn("DNS server refusing to run because this host is not a lighthouse.") + } + return nil + } + + if !enabled { + if running != nil { + d.Stop() + } + // Drop any records that accumulated while enabled; a later re-enable + // will repopulate from fresh handshakes. + d.clearRecords() + return nil + } + + if running == nil { + // Was disabled (or never started); bring it up now. + go d.Start() + return nil + } + + if sameAddr { + return nil + } + + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the + // new address. + go d.Start() + return nil +} + +// shutdownServer waits for the server to finish binding (so Shutdown actually +// stops it rather than no-oping) and then shuts it down. +func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) { + if srv == nil { + return + } + if started != nil { + <-started + } + if err := srv.Shutdown(); err != nil { + d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder") + } +} + +// Start binds and serves the DNS responder. Blocks until Stop is called or +// the listener errors. Safe to call when DNS is disabled (returns +// immediately). This is what Control.dnsStart points at. +// +// Must be invoked after the tun device is active so that lighthouse.dns.host +// may bind to a nebula IP. +func (d *dnsServer) Start() { + if !d.enabled.Load() { + return + } + + started := make(chan struct{}) + d.serverMu.Lock() + if d.ctx.Err() != nil { + d.serverMu.Unlock() + return + } + addr := d.addr + server := &dns.Server{ + Addr: addr, + Net: "udp", + Handler: d.mux, + NotifyStartedFunc: func() { close(started) }, + } + d.server = server + d.started = started + d.serverMu.Unlock() + + // Per-invocation ctx watcher. Exits when Start does, so we don't leak a + // watcher per reload-driven restart. + done := make(chan struct{}) + go func() { + select { + case <-d.ctx.Done(): + d.shutdownServer(server, started, "shutdown") + case <-done: + } + }() + + d.l.WithField("dnsListener", addr).Info("Starting DNS responder") + err := server.ListenAndServe() + close(done) + + // If the listener never bound (bind error) NotifyStartedFunc never fires, + // so close started here to release any Stop caller waiting on it. + select { + case <-started: + default: + close(started) + } + + if err != nil { + d.l.WithError(err).Warn("Failed to run the DNS responder") + } +} + +// Stop shuts down the active server, if any. Idempotent. +func (d *dnsServer) Stop() { + d.serverMu.Lock() + srv := d.server + started := d.started + d.server = nil + d.started = nil + d.serverMu.Unlock() + d.shutdownServer(srv, started, "stop") +} + +func (d *dnsServer) Query(q uint16, data string) netip.Addr { data = strings.ToLower(data) d.RLock() defer d.RUnlock() @@ -57,7 +234,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr { return netip.Addr{} } -func (d *dnsRecords) QueryCert(data string) string { +func (d *dnsServer) QueryCert(data string) string { ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" @@ -80,8 +257,19 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } +// clearRecords drops all DNS records. +func (d *dnsServer) clearRecords() { + d.Lock() + defer d.Unlock() + clear(d.dnsMap4) + clear(d.dnsMap6) +} + // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` -func (d *dnsRecords) Add(host string, addresses []netip.Addr) { +func (d *dnsServer) Add(host string, addresses []netip.Addr) { + if !d.enabled.Load() { + return + } host = strings.ToLower(host) d.Lock() defer d.Unlock() @@ -101,7 +289,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) { } } -func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { +func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { a, _, _ := net.SplitHostPort(addr) b, err := netip.ParseAddr(a) if err != nil { @@ -116,7 +304,7 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { return d.myVpnAddrsTable.Contains(b) } -func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { case dns.TypeA, dns.TypeAAAA: @@ -150,7 +338,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { } } -func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false @@ -163,21 +351,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(l, cs, hostMap) - - // attach request handler func - dns.HandleFunc(".", dnsR.handleDnsRequest) - - c.RegisterReloadCallback(func(c *config.C) { - reloadDns(l, c) - }) - - return func() { - startDns(l, c) - } -} - func getDnsServerAddr(c *config.C) string { dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. @@ -186,25 +359,3 @@ func getDnsServerAddr(c *config.C) string { } return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } - -func startDns(l *logrus.Logger, c *config.C) { - dnsAddr = getDnsServerAddr(c) - dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} - l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") - err := dnsServer.ListenAndServe() - defer dnsServer.Shutdown() - if err != nil { - l.Errorf("Failed to start server: %s\n ", err.Error()) - } -} - -func reloadDns(l *logrus.Logger, c *config.C) { - if dnsAddr == getDnsServerAddr(c) { - l.Debug("No DNS server config change detected") - return - } - - l.Debug("Restarting DNS server") - dnsServer.Shutdown() - go startDns(l, c) -} diff --git a/dns_server_test.go b/dns_server_test.go index 356e5890..c33c0480 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,19 +1,31 @@ package nebula import ( + "context" + "io" + "net" "net/netip" + "strconv" "testing" + "time" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParsequery(t *testing.T) { l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(l, &CertState{}, hostMap) + ds := &dnsServer{ + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + } + ds.enabled.Store(true) addrs := []netip.Addr{ netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5"), @@ -71,3 +83,209 @@ func Test_getDnsServerAddr(t *testing.T) { } assert.Equal(t, "[::]:1", getDnsServerAddr(c)) } + +func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) { + t.Helper() + l := logrus.New() + l.Out = io.Discard + ds := &dnsServer{ + l: l, + ctx: context.Background(), + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: &HostMap{}, + } + ds.mux = dns.NewServeMux() + ds.mux.HandleFunc(".", ds.handleDnsRequest) + return ds, config.NewC(l) +} + +func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) { + c.Settings["lighthouse"] = map[string]any{ + "am_lighthouse": amLighthouse, + "serve_dns": serveDns, + "dns": map[string]any{ + "host": host, + "port": port, + }, + } +} + +func TestDnsServer_reload_initial_disabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, false) + + require.NoError(t, ds.reload(c, true)) + assert.False(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_enabled(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + assert.True(t, ds.enabled.Load()) + assert.Equal(t, "127.0.0.1:0", ds.addr) + // initial never starts a runner; that's Control.Start's job + assert.Nil(t, ds.server) +} + +func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", false, true) + + require.NoError(t, ds.reload(c, true)) + // Wants DNS but isn't a lighthouse: gated off, no runner. + assert.False(t, ds.enabled.Load()) +} + +func TestDnsServer_reload_sameAddr_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", "0", true, true) + + require.NoError(t, ds.reload(c, true)) + // No server running yet, no addr change. Reload should not spawn anything. + require.NoError(t, ds.reload(c, false)) + assert.True(t, ds.enabled.Load()) + assert.Nil(t, ds.server) +} + +func TestDnsServer_StartStop_lifecycle(t *testing.T) { + // Bind to a real (random) UDP port so we exercise the actual + // ListenAndServe + Shutdown plumbing including the started-chan race fix. + port := freeUDPPort(t) + + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) + + ds.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } +} + +func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { + // Stop called immediately after Start should not deadlock even if bind + // hasn't completed yet. This exercises the started-chan close-on-bind-fail + // path: by binding to an obviously bad port (privileged) we get a fast + // bind error before NotifyStartedFunc fires. + ds, c := newTestDnsServer(t) + // Use a port that should fail to bind (negative would be invalid, use a + // host that won't resolve to ensure listenUDP fails quickly). + setDnsConfig(c, "256.256.256.256", "53", true, true) + require.NoError(t, ds.reload(c, true)) + + done := make(chan struct{}) + go func() { + ds.Start() + close(done) + }() + + // Give Start a moment to attempt the bind and fail. + select { + case <-done: + // Bind failed and Start returned; Stop should be a no-op. + case <-time.After(time.Second): + t.Fatal("Start did not return after a bad bind") + } + + stopped := make(chan struct{}) + go func() { + ds.Stop() + close(stopped) + }() + select { + case <-stopped: + case <-time.After(time.Second): + t.Fatal("Stop hung after a failed bind") + } +} + +func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { + port := freeUDPPort(t) + ds, c := newTestDnsServer(t) + setDnsConfig(c, "127.0.0.1", port, true, true) + require.NoError(t, ds.reload(c, true)) + + startReturned := make(chan struct{}) + go func() { + ds.Start() + close(startReturned) + }() + waitForBind(t, ds) + + // Toggle serve_dns off; reload should shut the running server down. + setDnsConfig(c, "127.0.0.1", port, true, false) + require.NoError(t, ds.reload(c, false)) + select { + case <-startReturned: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after reload disabled DNS") + } + assert.False(t, ds.enabled.Load()) +} + +func freeUDPPort(t *testing.T) string { + t.Helper() + conn, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + port := conn.LocalAddr().(*net.UDPAddr).Port + require.NoError(t, conn.Close()) + return strconv.Itoa(port) +} + +func waitForBind(t *testing.T, ds *dnsServer) { + t.Helper() + waitFor(t, func() bool { + ds.serverMu.Lock() + started := ds.started + ds.serverMu.Unlock() + if started == nil { + return false + } + select { + case <-started: + return true + default: + return false + } + }) +} + +func waitFor(t *testing.T, cond func() bool) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if cond() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("timed out waiting for condition") +} diff --git a/hostmap.go b/hostmap.go index 7e2939e0..25181d83 100644 --- a/hostmap.go +++ b/hostmap.go @@ -604,9 +604,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { - if f.serveDns { + if f.dnsServer != nil { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) + f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } for _, addr := range hostinfo.vpnAddrs { hm.unlockedInnerAddHostInfo(addr, hostinfo, f) diff --git a/interface.go b/interface.go index 9e7a98a9..481b1d4d 100644 --- a/interface.go +++ b/interface.go @@ -29,7 +29,7 @@ type InterfaceConfig struct { pki *PKI Cipher string Firewall *Firewall - ServeDns bool + DnsServer *dnsServer HandshakeManager *HandshakeManager lightHouse *LightHouse connectionManager *connectionManager @@ -57,7 +57,7 @@ type Interface struct { firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager - serveDns bool + dnsServer *dnsServer createTime time.Time lightHouse *LightHouse myBroadcastAddrsTable *bart.Lite @@ -175,7 +175,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { outside: c.Outside, inside: c.Inside, firewall: c.Firewall, - serveDns: c.ServeDns, + dnsServer: c.DnsServer, handshakeManager: c.HandshakeManager, createTime: time.Now(), lightHouse: c.lightHouse, diff --git a/main.go b/main.go index 8adc2921..0ac63dfa 100644 --- a/main.go +++ b/main.go @@ -215,13 +215,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - serveDns := false - if c.GetBool("lighthouse.serve_dns", false) { - if c.GetBool("lighthouse.am_lighthouse", false) { - serveDns = true - } else { - l.Warn("DNS server refusing to run because this host is not a lighthouse.") - } + ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + if err != nil { + l.WithError(err).Warn("Failed to start DNS responder") } ifConfig := &InterfaceConfig{ @@ -230,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Outside: udpConns[0], pki: pki, Firewall: fw, - ServeDns: serveDns, + DnsServer: ds, HandshakeManager: handshakeManager, connectionManager: connManager, lightHouse: lightHouse, @@ -280,13 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg attachCommands(l, c, ssh, ifce) - // Start DNS server last to allow using the nebula IP as lighthouse.dns.host - var dnsStart func() - if lightHouse.amLighthouse && serveDns { - l.Debugln("Starting dns server") - dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) - } - return &Control{ state: StateReady, f: ifce, @@ -295,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg cancel: cancel, sshStart: sshStart, statsStart: statsStart, - dnsStart: dnsStart, + dnsStart: ds.Start, lighthouseStart: lightHouse.StartUpdateWorker, connectionManagerStart: connManager.Start, }, nil