From da99cb8987144e6fa621326d4806e9e3ba787bcd Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 23:24:25 -0500 Subject: [PATCH] Record my local details in the dns server if enabled --- dns_server.go | 118 +++++++++++++++++++++++++++++++++++---------- dns_server_test.go | 89 ++++++++++++++++++++++++++++++++++ main.go | 2 +- 3 files changed, 182 insertions(+), 27 deletions(-) diff --git a/dns_server.go b/dns_server.go index ff1369ab..a80630b5 100644 --- a/dns_server.go +++ b/dns_server.go @@ -11,19 +11,21 @@ import ( "sync" "sync/atomic" - "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/slackhq/nebula/config" ) type dnsServer struct { sync.RWMutex - l *slog.Logger - ctx context.Context - dnsMap4 map[string]netip.Addr - dnsMap6 map[string]netip.Addr - hostMap *HostMap - myVpnAddrsTable *bart.Lite + l *slog.Logger + ctx context.Context + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + pki *PKI + + // selfHost is the cached FQDN we last seeded for ourselves + selfHost string mux *dns.ServeMux @@ -55,14 +57,14 @@ type dnsServer struct { // they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel // watcher that tears the listener down on nebula shutdown. The returned // pointer is always non-nil, even on error. 
-func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) { +func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*dnsServer, error) { ds := &dnsServer{ - l: l, - ctx: ctx, - dnsMap4: make(map[string]netip.Addr), - dnsMap6: make(map[string]netip.Addr), - hostMap: hostMap, - myVpnAddrsTable: cs.myVpnAddrsTable, + l: l, + ctx: ctx, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + pki: pki, } ds.mux = dns.NewServeMux() ds.mux.HandleFunc(".", ds.handleDnsRequest) @@ -76,6 +78,7 @@ func newDnsServerFromConfig(ctx context.Context, l *slog.Logger, cs *CertState, if err := ds.reload(c, true); err != nil { return ds, err } + ds.seedSelf() return ds, nil } @@ -113,7 +116,7 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { d.Stop() } // Drop any records that accumulated while enabled; a later re-enable - // will repopulate from fresh handshakes. + // will repopulate from fresh handshakes and a fresh seedSelf. d.clearRecords() return nil } @@ -121,17 +124,14 @@ func (d *dnsServer) reload(c *config.C, initial bool) error { if running == nil { // Was disabled (or never started); bring it up now. go d.Start() - return nil + } else if !sameAddr { + d.shutdownServer(running, runningStarted, "reload") + // Old Start goroutine has now exited; bring up a fresh listener on the new address. + go d.Start() } - if sameAddr { - return nil - } - - d.shutdownServer(running, runningStarted, "reload") - // Old Start goroutine has now exited; bring up a fresh listener on the - // new address. - go d.Start() + // Refresh the self entry every enabled reload so cert renewals that change our name or VPN addresses are picked up. 
+ d.seedSelf() return nil } @@ -249,6 +249,20 @@ func (d *dnsServer) QueryCert(data string) string { return "" } + // The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves. + // Answer self lookups straight from the local cert state. + if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { + c := cs.GetDefaultCertificate() + if c == nil { + return "" + } + b, err := c.MarshalJSON() + if err != nil { + return "" + } + return string(b) + } + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" @@ -266,12 +280,60 @@ func (d *dnsServer) QueryCert(data string) string { return string(b) } -// clearRecords drops all DNS records. +// clearRecords drops all DNS records, including the self entry. func (d *dnsServer) clearRecords() { d.Lock() defer d.Unlock() clear(d.dnsMap4) clear(d.dnsMap6) + d.selfHost = "" +} + +// seedSelf inserts (or refreshes) a record for our own cert name pointing at our VPN addresses, +// so a single-lighthouse network can resolve the lighthouse's own hostname without the two-process workaround. +func (d *dnsServer) seedSelf() { + if !d.enabled.Load() { + return + } + cs := d.certState() + if cs == nil { + return + } + c := cs.GetDefaultCertificate() + if c == nil { + return + } + newHost := strings.ToLower(c.Name()) + "." 
+ + d.Lock() + defer d.Unlock() + if d.selfHost != "" && d.selfHost != newHost { + delete(d.dnsMap4, d.selfHost) + delete(d.dnsMap6, d.selfHost) + } + d.selfHost = newHost + delete(d.dnsMap4, newHost) + delete(d.dnsMap6, newHost) + haveV4, haveV6 := false, false + for _, addr := range cs.myVpnAddrs { + if addr.Is4() && !haveV4 { + d.dnsMap4[newHost] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[newHost] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } +} + +func (d *dnsServer) certState() *CertState { + if d.pki == nil { + return nil + } + return d.pki.getCertState() } // Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` @@ -309,8 +371,12 @@ func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool { return true } + cs := d.certState() + if cs == nil || cs.myVpnAddrsTable == nil { + return false + } //if we found it in this table, it's good - return d.myVpnAddrsTable.Contains(b) + return cs.myVpnAddrsTable.Contains(b) } func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) { diff --git a/dns_server_test.go b/dns_server_test.go index dcea046c..58646937 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -9,7 +9,10 @@ import ( "testing" "time" + "github.com/gaissmai/bart" "github.com/miekg/dns" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -276,6 +279,92 @@ func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) { } } +// newTestPKI builds a minimal *PKI with a single v2 cert whose name and +// VPN addresses are caller-provided, suitable for exercising seedSelf and +// QueryCert self handling. 
+func newTestPKI(t *testing.T, name string, addrs []netip.Addr) *PKI { + t.Helper() + networks := make([]netip.Prefix, 0, len(addrs)) + for _, a := range addrs { + bits := 32 + if a.Is6() { + bits = 128 + } + networks = append(networks, netip.PrefixFrom(a, bits)) + } + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + c, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, nil, nil) + + addrsTable := new(bart.Lite) + for _, a := range addrs { + addrsTable.Insert(netip.PrefixFrom(a, a.BitLen())) + } + + cs := &CertState{ + v2Cert: c, + initiatingVersion: cert.Version2, + myVpnAddrs: addrs, + myVpnAddrsTable: addrsTable, + } + pki := &PKI{} + pki.cs.Store(cs) + return pki +} + +func TestDnsServer_seedSelf_addsOwnRecord(t *testing.T) { + ds, c := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + myV6 := netip.MustParseAddr("fd00::1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4, myV6}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + got4, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV4, got4) + got6, exists := ds.Query(dns.TypeAAAA, "lighthouse.") + assert.True(t, exists) + assert.Equal(t, myV6, got6) +} + +func TestDnsServer_seedSelf_disabled_noOp(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, false) + require.NoError(t, ds.reload(c, true)) + + ds.seedSelf() + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_clearRecords_dropsSelfHost(t *testing.T) { + ds, c := newTestDnsServer(t) + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + setDnsConfig(c, "127.0.0.1", "0", true, true) + 
require.NoError(t, ds.reload(c, true)) + ds.seedSelf() + require.NotEmpty(t, ds.selfHost) + + ds.clearRecords() + assert.Empty(t, ds.selfHost) + _, exists := ds.Query(dns.TypeA, "lighthouse.") + assert.False(t, exists) +} + +func TestDnsServer_QueryCert_returnsOwnCert(t *testing.T) { + ds, _ := newTestDnsServer(t) + myV4 := netip.MustParseAddr("10.0.0.1") + ds.pki = newTestPKI(t, "lighthouse", []netip.Addr{myV4}) + + got := ds.QueryCert(myV4.String() + ".") + assert.NotEmpty(t, got, "TXT lookup of our own VPN address should return our cert") + + other := netip.MustParseAddr("10.0.0.99") + assert.Empty(t, ds.QueryCert(other.String()+"."), "unknown peer IP should return nothing") +} + func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) { port := freeUDPPort(t) ds, c := newTestDnsServer(t) diff --git a/main.go b/main.go index 37aa24d1..7d7a0f72 100644 --- a/main.go +++ b/main.go @@ -194,7 +194,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig) lightHouse.handshakeTrigger = handshakeManager.trigger - ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c) + ds, err := newDnsServerFromConfig(ctx, l, pki, hostMap, c) if err != nil { l.Warn("Failed to start DNS responder", "error", err) }