add a little context to dns

This commit is contained in:
Jay Wren 2025-04-18 17:09:10 -04:00
parent b8ea55eb90
commit 5ceac2b078
2 changed files with 14 additions and 13 deletions

View File

@ -1,6 +1,7 @@
package nebula
import (
"context"
"fmt"
"net"
"net/netip"
@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord
}
}
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
func (d *dnsRecords) query(q uint16, data string) netip.Addr {
data = strings.ToLower(data)
d.RLock()
defer d.RUnlock()
@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
return netip.Addr{}
}
func (d *dnsRecords) QueryCert(data string) string {
func (d *dnsRecords) queryCert(data string) string {
ip, err := netip.ParseAddr(data[:len(data)-1])
if err != nil {
return ""
@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
case dns.TypeA, dns.TypeAAAA:
qType := dns.TypeToString[q.Qtype]
d.l.Debugf("Query for %s %s", qType, q.Name)
ip := d.Query(q.Qtype, q.Name)
ip := d.query(q.Qtype, q.Name)
if ip.IsValid() {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip))
if err == nil {
@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
return
}
d.l.Debugf("Query for TXT %s", q.Name)
ip := d.QueryCert(q.Name)
ip := d.queryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m)
}
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
dnsR = newDnsRecords(l, cs, hostMap)
// attach request handler func
dns.HandleFunc(".", dnsR.handleDnsRequest)
c.RegisterReloadCallback(func(c *config.C) {
reloadDns(l, c)
reloadDns(ctx, l, c)
})
return func() {
startDns(l, c)
startDns(ctx, l, c)
}
}
@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string {
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
}
func startDns(l *logrus.Logger, c *config.C) {
func startDns(ctx context.Context, l *logrus.Logger, c *config.C) {
dnsAddr = getDnsServerAddr(c)
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
err := dnsServer.ListenAndServe()
defer dnsServer.Shutdown()
defer dnsServer.ShutdownContext(ctx)
if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error())
}
}
func reloadDns(l *logrus.Logger, c *config.C) {
func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) {
if dnsAddr == getDnsServerAddr(c) {
l.Debug("No DNS server config change detected")
return
}
l.Debug("Restarting DNS server")
dnsServer.Shutdown()
go startDns(l, c)
dnsServer.ShutdownContext(ctx)
go startDns(ctx, l, c)
}

View File

@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
var dnsStart func()
if lightHouse.amLighthouse && serveDns {
l.Debugln("Starting dns server")
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c)
}
return &Control{