diff --git a/dns_server.go b/dns_server.go index 7357654..936996b 100644 --- a/dns_server.go +++ b/dns_server.go @@ -1,6 +1,7 @@ package nebula import ( + "context" "fmt" "net" "net/netip" @@ -39,7 +40,7 @@ func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecord } } -func (d *dnsRecords) Query(q uint16, data string) netip.Addr { +func (d *dnsRecords) query(q uint16, data string) netip.Addr { data = strings.ToLower(data) d.RLock() defer d.RUnlock() @@ -57,7 +58,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr { return netip.Addr{} } -func (d *dnsRecords) QueryCert(data string) string { +func (d *dnsRecords) queryCert(data string) string { ip, err := netip.ParseAddr(data[:len(data)-1]) if err != nil { return "" @@ -122,7 +123,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { case dns.TypeA, dns.TypeAAAA: qType := dns.TypeToString[q.Qtype] d.l.Debugf("Query for %s %s", qType, q.Name) - ip := d.Query(q.Qtype, q.Name) + ip := d.query(q.Qtype, q.Name) if ip.IsValid() { rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { @@ -135,7 +136,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { return } d.l.Debugf("Query for TXT %s", q.Name) - ip := d.QueryCert(q.Name) + ip := d.queryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -163,18 +164,18 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { +func dnsMain(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { - reloadDns(l, c) + reloadDns(ctx, l, c) }) return func() { - startDns(l, c) + startDns(ctx, l, c) } } @@ -187,24 +188,24 @@ func getDnsServerAddr(c *config.C) string { return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) } -func startDns(l *logrus.Logger, c *config.C) { +func startDns(ctx context.Context, l *logrus.Logger, c *config.C) { dnsAddr = getDnsServerAddr(c) dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") err := dnsServer.ListenAndServe() - defer dnsServer.Shutdown() + defer dnsServer.ShutdownContext(ctx) if err != nil { l.Errorf("Failed to start server: %s\n ", err.Error()) } } -func reloadDns(l *logrus.Logger, c *config.C) { +func reloadDns(ctx context.Context, l *logrus.Logger, c *config.C) { if dnsAddr == getDnsServerAddr(c) { l.Debug("No DNS server config change detected") return } l.Debug("Restarting DNS server") - dnsServer.Shutdown() - go startDns(l, c) + dnsServer.ShutdownContext(ctx) + go startDns(ctx, l, c) } diff --git a/main.go b/main.go index b278fa6..52270d9 100644 --- a/main.go +++ b/main.go @@ -284,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) + dnsStart = dnsMain(ctx, l, pki.getCertState(), hostMap, c) } return &Control{