mirror of
https://github.com/slackhq/nebula.git
synced 2026-05-16 12:57:38 +02:00
No more dns globals, proper cleanup on shutdown (#1667)
This commit is contained in:
255
dns_server.go
255
dns_server.go
@@ -1,12 +1,14 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/gaissmai/bart"
|
"github.com/gaissmai/bart"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -14,32 +16,207 @@ import (
|
|||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This whole thing should be rewritten to use context
|
type dnsServer struct {
|
||||||
|
|
||||||
var dnsR *dnsRecords
|
|
||||||
var dnsServer *dns.Server
|
|
||||||
var dnsAddr string
|
|
||||||
|
|
||||||
type dnsRecords struct {
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
l *logrus.Logger
|
l *logrus.Logger
|
||||||
|
ctx context.Context
|
||||||
dnsMap4 map[string]netip.Addr
|
dnsMap4 map[string]netip.Addr
|
||||||
dnsMap6 map[string]netip.Addr
|
dnsMap6 map[string]netip.Addr
|
||||||
hostMap *HostMap
|
hostMap *HostMap
|
||||||
myVpnAddrsTable *bart.Lite
|
myVpnAddrsTable *bart.Lite
|
||||||
|
|
||||||
|
mux *dns.ServeMux
|
||||||
|
|
||||||
|
// enabled mirrors `lighthouse.serve_dns && lighthouse.am_lighthouse`.
|
||||||
|
// Start, Add, and reload consult it so callers don't need to know the
|
||||||
|
// gating rules. When it toggles off via reload, accumulated records are
|
||||||
|
// cleared so a later re-enable starts with a fresh map populated from
|
||||||
|
// new handshakes.
|
||||||
|
enabled atomic.Bool
|
||||||
|
|
||||||
|
serverMu sync.Mutex
|
||||||
|
server *dns.Server
|
||||||
|
// started is closed once `server` has finished binding (or after
|
||||||
|
// ListenAndServe returns on a bind failure). Stop waits on it before
|
||||||
|
// calling Shutdown to avoid the miekg/dns "server not started" race
|
||||||
|
// where a Shutdown that arrives before bind completes is silently
|
||||||
|
// ignored, leaving the listener running forever.
|
||||||
|
started chan struct{}
|
||||||
|
addr string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords {
|
// newDnsServerFromConfig builds a dnsServer, applies the initial config, and
|
||||||
return &dnsRecords{
|
// registers a reload callback. The reload callback is registered before the
|
||||||
|
// initial config is applied, so a SIGHUP can later enable, fix, or disable
|
||||||
|
// DNS even if the initial application failed.
|
||||||
|
//
|
||||||
|
// The dnsServer internally gates on `lighthouse.serve_dns &&
|
||||||
|
// lighthouse.am_lighthouse`. Start and Add are safe to call unconditionally,
|
||||||
|
// they no-op when DNS isn't enabled. Each Start invocation owns a ctx-cancel
|
||||||
|
// watcher that tears the listener down on nebula shutdown. The returned
|
||||||
|
// pointer is always non-nil, even on error.
|
||||||
|
func newDnsServerFromConfig(ctx context.Context, l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) (*dnsServer, error) {
|
||||||
|
ds := &dnsServer{
|
||||||
l: l,
|
l: l,
|
||||||
|
ctx: ctx,
|
||||||
dnsMap4: make(map[string]netip.Addr),
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
dnsMap6: make(map[string]netip.Addr),
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
hostMap: hostMap,
|
hostMap: hostMap,
|
||||||
myVpnAddrsTable: cs.myVpnAddrsTable,
|
myVpnAddrsTable: cs.myVpnAddrsTable,
|
||||||
}
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
|
||||||
|
c.RegisterReloadCallback(func(c *config.C) {
|
||||||
|
if err := ds.reload(c, false); err != nil {
|
||||||
|
l.WithError(err).Error("Failed to reload DNS responder from config")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ds.reload(c, true); err != nil {
|
||||||
|
return ds, err
|
||||||
|
}
|
||||||
|
return ds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
// reload applies the latest config and reconciles the running state with it:
|
||||||
|
// - enabled toggled on -> spawn a runner
|
||||||
|
// - enabled toggled off -> stop the runner
|
||||||
|
// - listen address changed (while running) -> restart on the new address
|
||||||
|
// - everything else -> no-op
|
||||||
|
//
|
||||||
|
// On the initial call it only records configuration; Control.Start is what
|
||||||
|
// launches the first runner via dnsStart.
|
||||||
|
func (d *dnsServer) reload(c *config.C, initial bool) error {
|
||||||
|
wantsDns := c.GetBool("lighthouse.serve_dns", false)
|
||||||
|
amLighthouse := c.GetBool("lighthouse.am_lighthouse", false)
|
||||||
|
enabled := wantsDns && amLighthouse
|
||||||
|
newAddr := getDnsServerAddr(c)
|
||||||
|
|
||||||
|
d.serverMu.Lock()
|
||||||
|
running := d.server
|
||||||
|
runningStarted := d.started
|
||||||
|
sameAddr := d.addr == newAddr
|
||||||
|
d.addr = newAddr
|
||||||
|
d.enabled.Store(enabled)
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
if initial {
|
||||||
|
if wantsDns && !amLighthouse {
|
||||||
|
d.l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
if running != nil {
|
||||||
|
d.Stop()
|
||||||
|
}
|
||||||
|
// Drop any records that accumulated while enabled; a later re-enable
|
||||||
|
// will repopulate from fresh handshakes.
|
||||||
|
d.clearRecords()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if running == nil {
|
||||||
|
// Was disabled (or never started); bring it up now.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if sameAddr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d.shutdownServer(running, runningStarted, "reload")
|
||||||
|
// Old Start goroutine has now exited; bring up a fresh listener on the
|
||||||
|
// new address.
|
||||||
|
go d.Start()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// shutdownServer waits for the server to finish binding (so Shutdown actually
|
||||||
|
// stops it rather than no-oping) and then shuts it down.
|
||||||
|
func (d *dnsServer) shutdownServer(srv *dns.Server, started chan struct{}, reason string) {
|
||||||
|
if srv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if started != nil {
|
||||||
|
<-started
|
||||||
|
}
|
||||||
|
if err := srv.Shutdown(); err != nil {
|
||||||
|
d.l.WithError(err).WithField("reason", reason).Warn("Failed to shut down the DNS responder")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start binds and serves the DNS responder. Blocks until Stop is called or
|
||||||
|
// the listener errors. Safe to call when DNS is disabled (returns
|
||||||
|
// immediately). This is what Control.dnsStart points at.
|
||||||
|
//
|
||||||
|
// Must be invoked after the tun device is active so that lighthouse.dns.host
|
||||||
|
// may bind to a nebula IP.
|
||||||
|
func (d *dnsServer) Start() {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
started := make(chan struct{})
|
||||||
|
d.serverMu.Lock()
|
||||||
|
if d.ctx.Err() != nil {
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addr := d.addr
|
||||||
|
server := &dns.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Net: "udp",
|
||||||
|
Handler: d.mux,
|
||||||
|
NotifyStartedFunc: func() { close(started) },
|
||||||
|
}
|
||||||
|
d.server = server
|
||||||
|
d.started = started
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
|
||||||
|
// Per-invocation ctx watcher. Exits when Start does, so we don't leak a
|
||||||
|
// watcher per reload-driven restart.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-d.ctx.Done():
|
||||||
|
d.shutdownServer(server, started, "shutdown")
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
d.l.WithField("dnsListener", addr).Info("Starting DNS responder")
|
||||||
|
err := server.ListenAndServe()
|
||||||
|
close(done)
|
||||||
|
|
||||||
|
// If the listener never bound (bind error) NotifyStartedFunc never fires,
|
||||||
|
// so close started here to release any Stop caller waiting on it.
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
default:
|
||||||
|
close(started)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
d.l.WithError(err).Warn("Failed to run the DNS responder")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the active server, if any. Idempotent.
|
||||||
|
func (d *dnsServer) Stop() {
|
||||||
|
d.serverMu.Lock()
|
||||||
|
srv := d.server
|
||||||
|
started := d.started
|
||||||
|
d.server = nil
|
||||||
|
d.started = nil
|
||||||
|
d.serverMu.Unlock()
|
||||||
|
d.shutdownServer(srv, started, "stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dnsServer) Query(q uint16, data string) netip.Addr {
|
||||||
data = strings.ToLower(data)
|
data = strings.ToLower(data)
|
||||||
d.RLock()
|
d.RLock()
|
||||||
defer d.RUnlock()
|
defer d.RUnlock()
|
||||||
@@ -57,7 +234,7 @@ func (d *dnsRecords) Query(q uint16, data string) netip.Addr {
|
|||||||
return netip.Addr{}
|
return netip.Addr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) QueryCert(data string) string {
|
func (d *dnsServer) QueryCert(data string) string {
|
||||||
ip, err := netip.ParseAddr(data[:len(data)-1])
|
ip, err := netip.ParseAddr(data[:len(data)-1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -80,8 +257,19 @@ func (d *dnsRecords) QueryCert(data string) string {
|
|||||||
return string(b)
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clearRecords drops all DNS records.
|
||||||
|
func (d *dnsServer) clearRecords() {
|
||||||
|
d.Lock()
|
||||||
|
defer d.Unlock()
|
||||||
|
clear(d.dnsMap4)
|
||||||
|
clear(d.dnsMap6)
|
||||||
|
}
|
||||||
|
|
||||||
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host`
|
||||||
func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
func (d *dnsServer) Add(host string, addresses []netip.Addr) {
|
||||||
|
if !d.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
host = strings.ToLower(host)
|
host = strings.ToLower(host)
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
@@ -101,7 +289,7 @@ func (d *dnsRecords) Add(host string, addresses []netip.Addr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
func (d *dnsServer) isSelfNebulaOrLocalhost(addr string) bool {
|
||||||
a, _, _ := net.SplitHostPort(addr)
|
a, _, _ := net.SplitHostPort(addr)
|
||||||
b, err := netip.ParseAddr(a)
|
b, err := netip.ParseAddr(a)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,7 +304,7 @@ func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool {
|
|||||||
return d.myVpnAddrsTable.Contains(b)
|
return d.myVpnAddrsTable.Contains(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
func (d *dnsServer) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
||||||
for _, q := range m.Question {
|
for _, q := range m.Question {
|
||||||
switch q.Qtype {
|
switch q.Qtype {
|
||||||
case dns.TypeA, dns.TypeAAAA:
|
case dns.TypeA, dns.TypeAAAA:
|
||||||
@@ -150,7 +338,7 @@ func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *dnsServer) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
m.Compress = false
|
m.Compress = false
|
||||||
@@ -163,21 +351,6 @@ func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() {
|
|
||||||
dnsR = newDnsRecords(l, cs, hostMap)
|
|
||||||
|
|
||||||
// attach request handler func
|
|
||||||
dns.HandleFunc(".", dnsR.handleDnsRequest)
|
|
||||||
|
|
||||||
c.RegisterReloadCallback(func(c *config.C) {
|
|
||||||
reloadDns(l, c)
|
|
||||||
})
|
|
||||||
|
|
||||||
return func() {
|
|
||||||
startDns(l, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDnsServerAddr(c *config.C) string {
|
func getDnsServerAddr(c *config.C) string {
|
||||||
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
|
||||||
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
|
||||||
@@ -186,25 +359,3 @@ func getDnsServerAddr(c *config.C) string {
|
|||||||
}
|
}
|
||||||
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDns(l *logrus.Logger, c *config.C) {
|
|
||||||
dnsAddr = getDnsServerAddr(c)
|
|
||||||
dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
|
|
||||||
l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
|
|
||||||
err := dnsServer.ListenAndServe()
|
|
||||||
defer dnsServer.Shutdown()
|
|
||||||
if err != nil {
|
|
||||||
l.Errorf("Failed to start server: %s\n ", err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reloadDns(l *logrus.Logger, c *config.C) {
|
|
||||||
if dnsAddr == getDnsServerAddr(c) {
|
|
||||||
l.Debug("No DNS server config change detected")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Restarting DNS server")
|
|
||||||
dnsServer.Shutdown()
|
|
||||||
go startDns(l, c)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,19 +1,31 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/slackhq/nebula/config"
|
"github.com/slackhq/nebula/config"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParsequery(t *testing.T) {
|
func TestParsequery(t *testing.T) {
|
||||||
l := logrus.New()
|
l := logrus.New()
|
||||||
hostMap := &HostMap{}
|
hostMap := &HostMap{}
|
||||||
ds := newDnsRecords(l, &CertState{}, hostMap)
|
ds := &dnsServer{
|
||||||
|
l: l,
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: hostMap,
|
||||||
|
}
|
||||||
|
ds.enabled.Store(true)
|
||||||
addrs := []netip.Addr{
|
addrs := []netip.Addr{
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
netip.MustParseAddr("1.2.3.5"),
|
netip.MustParseAddr("1.2.3.5"),
|
||||||
@@ -71,3 +83,209 @@ func Test_getDnsServerAddr(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
assert.Equal(t, "[::]:1", getDnsServerAddr(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestDnsServer(t *testing.T) (*dnsServer, *config.C) {
|
||||||
|
t.Helper()
|
||||||
|
l := logrus.New()
|
||||||
|
l.Out = io.Discard
|
||||||
|
ds := &dnsServer{
|
||||||
|
l: l,
|
||||||
|
ctx: context.Background(),
|
||||||
|
dnsMap4: make(map[string]netip.Addr),
|
||||||
|
dnsMap6: make(map[string]netip.Addr),
|
||||||
|
hostMap: &HostMap{},
|
||||||
|
}
|
||||||
|
ds.mux = dns.NewServeMux()
|
||||||
|
ds.mux.HandleFunc(".", ds.handleDnsRequest)
|
||||||
|
return ds, config.NewC(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDnsConfig(c *config.C, host string, port string, amLighthouse, serveDns bool) {
|
||||||
|
c.Settings["lighthouse"] = map[string]any{
|
||||||
|
"am_lighthouse": amLighthouse,
|
||||||
|
"serve_dns": serveDns,
|
||||||
|
"dns": map[string]any{
|
||||||
|
"host": host,
|
||||||
|
"port": port,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_disabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, false)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_enabled(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Equal(t, "127.0.0.1:0", ds.addr)
|
||||||
|
// initial never starts a runner; that's Control.Start's job
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_initial_serveDnsWithoutLighthouse(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", false, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// Wants DNS but isn't a lighthouse: gated off, no runner.
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_sameAddr_noOp(t *testing.T) {
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", "0", true, true)
|
||||||
|
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
// No server running yet, no addr change. Reload should not spawn anything.
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
assert.True(t, ds.enabled.Load())
|
||||||
|
assert.Nil(t, ds.server)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_StartStop_lifecycle(t *testing.T) {
|
||||||
|
// Bind to a real (random) UDP port so we exercise the actual
|
||||||
|
// ListenAndServe + Shutdown plumbing including the started-chan race fix.
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
ds.Stop()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after Stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_Stop_beforeBind_doesNotHang(t *testing.T) {
|
||||||
|
// Stop called immediately after Start should not deadlock even if bind
|
||||||
|
// hasn't completed yet. This exercises the started-chan close-on-bind-fail
|
||||||
|
// path: by binding to an obviously bad port (privileged) we get a fast
|
||||||
|
// bind error before NotifyStartedFunc fires.
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
// Use a port that should fail to bind (negative would be invalid, use a
|
||||||
|
// host that won't resolve to ensure listenUDP fails quickly).
|
||||||
|
setDnsConfig(c, "256.256.256.256", "53", true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give Start a moment to attempt the bind and fail.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Bind failed and Start returned; Stop should be a no-op.
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Start did not return after a bad bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
stopped := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Stop()
|
||||||
|
close(stopped)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-stopped:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Stop hung after a failed bind")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsServer_reload_disable_stopsRunningServer(t *testing.T) {
|
||||||
|
port := freeUDPPort(t)
|
||||||
|
ds, c := newTestDnsServer(t)
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, true)
|
||||||
|
require.NoError(t, ds.reload(c, true))
|
||||||
|
|
||||||
|
startReturned := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
ds.Start()
|
||||||
|
close(startReturned)
|
||||||
|
}()
|
||||||
|
waitForBind(t, ds)
|
||||||
|
|
||||||
|
// Toggle serve_dns off; reload should shut the running server down.
|
||||||
|
setDnsConfig(c, "127.0.0.1", port, true, false)
|
||||||
|
require.NoError(t, ds.reload(c, false))
|
||||||
|
select {
|
||||||
|
case <-startReturned:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Start did not return after reload disabled DNS")
|
||||||
|
}
|
||||||
|
assert.False(t, ds.enabled.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func freeUDPPort(t *testing.T) string {
|
||||||
|
t.Helper()
|
||||||
|
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
port := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
return strconv.Itoa(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForBind(t *testing.T, ds *dnsServer) {
|
||||||
|
t.Helper()
|
||||||
|
waitFor(t, func() bool {
|
||||||
|
ds.serverMu.Lock()
|
||||||
|
started := ds.started
|
||||||
|
ds.serverMu.Unlock()
|
||||||
|
if started == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-started:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitFor calls cond repeatedly, sleeping 5ms between attempts, and returns
// as soon as cond reports true. If five seconds pass without that happening
// it fails the test with t.Fatal.
func waitFor(t *testing.T, cond func() bool) {
	t.Helper()

	const (
		timeout  = 5 * time.Second
		interval = 5 * time.Millisecond
	)

	for begin := time.Now(); time.Since(begin) < timeout; time.Sleep(interval) {
		if cond() {
			return
		}
	}
	t.Fatal("timed out waiting for condition")
}
|
||||||
|
|||||||
@@ -604,9 +604,9 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI
|
|||||||
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
|
||||||
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
|
||||||
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
|
||||||
if f.serveDns {
|
if f.dnsServer != nil {
|
||||||
remoteCert := hostinfo.ConnectionState.peerCert
|
remoteCert := hostinfo.ConnectionState.peerCert
|
||||||
dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
f.dnsServer.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs)
|
||||||
}
|
}
|
||||||
for _, addr := range hostinfo.vpnAddrs {
|
for _, addr := range hostinfo.vpnAddrs {
|
||||||
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
hm.unlockedInnerAddHostInfo(addr, hostinfo, f)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ type InterfaceConfig struct {
|
|||||||
pki *PKI
|
pki *PKI
|
||||||
Cipher string
|
Cipher string
|
||||||
Firewall *Firewall
|
Firewall *Firewall
|
||||||
ServeDns bool
|
DnsServer *dnsServer
|
||||||
HandshakeManager *HandshakeManager
|
HandshakeManager *HandshakeManager
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
@@ -57,7 +57,7 @@ type Interface struct {
|
|||||||
firewall *Firewall
|
firewall *Firewall
|
||||||
connectionManager *connectionManager
|
connectionManager *connectionManager
|
||||||
handshakeManager *HandshakeManager
|
handshakeManager *HandshakeManager
|
||||||
serveDns bool
|
dnsServer *dnsServer
|
||||||
createTime time.Time
|
createTime time.Time
|
||||||
lightHouse *LightHouse
|
lightHouse *LightHouse
|
||||||
myBroadcastAddrsTable *bart.Lite
|
myBroadcastAddrsTable *bart.Lite
|
||||||
@@ -175,7 +175,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||||||
outside: c.Outside,
|
outside: c.Outside,
|
||||||
inside: c.Inside,
|
inside: c.Inside,
|
||||||
firewall: c.Firewall,
|
firewall: c.Firewall,
|
||||||
serveDns: c.ServeDns,
|
dnsServer: c.DnsServer,
|
||||||
handshakeManager: c.HandshakeManager,
|
handshakeManager: c.HandshakeManager,
|
||||||
createTime: time.Now(),
|
createTime: time.Now(),
|
||||||
lightHouse: c.lightHouse,
|
lightHouse: c.lightHouse,
|
||||||
|
|||||||
21
main.go
21
main.go
@@ -215,13 +215,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
handshakeManager := NewHandshakeManager(l, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||||
|
|
||||||
serveDns := false
|
ds, err := newDnsServerFromConfig(ctx, l, pki.getCertState(), hostMap, c)
|
||||||
if c.GetBool("lighthouse.serve_dns", false) {
|
if err != nil {
|
||||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
l.WithError(err).Warn("Failed to start DNS responder")
|
||||||
serveDns = true
|
|
||||||
} else {
|
|
||||||
l.Warn("DNS server refusing to run because this host is not a lighthouse.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ifConfig := &InterfaceConfig{
|
ifConfig := &InterfaceConfig{
|
||||||
@@ -230,7 +226,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
Outside: udpConns[0],
|
Outside: udpConns[0],
|
||||||
pki: pki,
|
pki: pki,
|
||||||
Firewall: fw,
|
Firewall: fw,
|
||||||
ServeDns: serveDns,
|
DnsServer: ds,
|
||||||
HandshakeManager: handshakeManager,
|
HandshakeManager: handshakeManager,
|
||||||
connectionManager: connManager,
|
connectionManager: connManager,
|
||||||
lightHouse: lightHouse,
|
lightHouse: lightHouse,
|
||||||
@@ -280,13 +276,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
|
|
||||||
attachCommands(l, c, ssh, ifce)
|
attachCommands(l, c, ssh, ifce)
|
||||||
|
|
||||||
// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
|
|
||||||
var dnsStart func()
|
|
||||||
if lightHouse.amLighthouse && serveDns {
|
|
||||||
l.Debugln("Starting dns server")
|
|
||||||
dnsStart = dnsMain(l, pki.getCertState(), hostMap, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Control{
|
return &Control{
|
||||||
state: StateReady,
|
state: StateReady,
|
||||||
f: ifce,
|
f: ifce,
|
||||||
@@ -295,7 +284,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
sshStart: sshStart,
|
sshStart: sshStart,
|
||||||
statsStart: statsStart,
|
statsStart: statsStart,
|
||||||
dnsStart: dnsStart,
|
dnsStart: ds.Start,
|
||||||
lighthouseStart: lightHouse.StartUpdateWorker,
|
lighthouseStart: lightHouse.StartUpdateWorker,
|
||||||
connectionManagerStart: connManager.Start,
|
connectionManagerStart: connManager.Start,
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
Reference in New Issue
Block a user