From e87ccdc9eac2988a2cd8d0a10e82d42e0fb34137 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 01:46:02 +0000 Subject: [PATCH] Add experimental host_query API for local identity lookups Programs running alongside nebula have no simple way to ask "who is this vpn address?" when making authorization decisions, e.g. a nebula-aware webapp that wants to identify an inbound connection by its source address instead of presenting a login form. The existing surfaces are the sshd admin interface (not scriptable from app code) and the lighthouse-only DNS TXT lookup, which returns raw cert JSON over an awkward transport. This adds an opt-in `host_query` config section that serves a small HTTP+JSON API on a unix socket or tcp address, requiring no client library to consume: GET /v1/host?addr= identity of the host owning the address (an established peer, or this node). addr may include a port so a server can pass a connection's RemoteAddr through unparsed. GET /v1/self this node's own identity. Responses carry the certificate-derived identity only: name, vpn addresses, networks, unsafe networks, groups, fingerprint, issuer, validity window, and cert version. The self-vs-peer lookup logic is shared with the DNS TXT handler via a new findCertificateForVpnAddr helper, which also swaps the panicking GetDefaultCertificate call for the nil-returning accessor so a missing certificate yields an empty answer instead of a crash. The listener follows the statsServer lifecycle: the whole section is reloadable via SIGHUP, including moving between socket paths and tcp addresses. Unix sockets default to mode 0600, stale sockets left by an unclean exit are removed at bind time, and a non-socket file at the configured path is never replaced. 
https://claude.ai/code/session_01Nibp24Pgk2JMue8VyWHq7o --- CHANGELOG.md | 8 + control.go | 4 + dns_server.go | 25 +-- examples/config.yml | 32 +++ host_query_server.go | 459 ++++++++++++++++++++++++++++++++++++++ host_query_server_test.go | 450 +++++++++++++++++++++++++++++++++++++ main.go | 6 + 7 files changed, 962 insertions(+), 22 deletions(-) create mode 100644 host_query_server.go create mode 100644 host_query_server_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ef7551f..a4dce9b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- New experimental `host_query` config section exposing a local HTTP+JSON API + (`GET /v1/host?addr=...`, `GET /v1/self`) over a unix socket or loopback tcp + listener, so programs on the same host can resolve a vpn address to its + certificate identity (name, groups, networks) for authorization decisions. + Disabled by default. + ## [1.10.3] - 2026-02-06 ### Security diff --git a/control.go b/control.go index ef58988b..94b620d6 100644 --- a/control.go +++ b/control.go @@ -52,6 +52,7 @@ type Control struct { sshStart func() statsStart func() dnsStart func() + hostQueryStart func() lighthouseStart func() connectionManagerStart func(context.Context) } @@ -104,6 +105,9 @@ func (c *Control) Start() (func() error, error) { if c.dnsStart != nil { go c.dnsStart() } + if c.hostQueryStart != nil { + go c.hostQueryStart() + } if c.connectionManagerStart != nil { go c.connectionManagerStart(c.ctx) } diff --git a/dns_server.go b/dns_server.go index a80630b5..06c4c6fc 100644 --- a/dns_server.go +++ b/dns_server.go @@ -249,31 +249,12 @@ func (d *dnsServer) QueryCert(data string) string { return "" } - // The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves. - // Answer self lookups straight from the local cert state. 
- if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { - c := cs.GetDefaultCertificate() - if c == nil { - return "" - } - b, err := c.MarshalJSON() - if err != nil { - return "" - } - return string(b) - } - - hostinfo := d.hostMap.QueryVpnAddr(ip) - if hostinfo == nil { + crt := findCertificateForVpnAddr(d.certState(), d.hostMap, ip) + if crt == nil { return "" } - q := hostinfo.GetCert() - if q == nil { - return "" - } - - b, err := q.Certificate.MarshalJSON() + b, err := crt.MarshalJSON() if err != nil { return "" } diff --git a/examples/config.yml b/examples/config.yml index 4f7fd1e7..5bd65221 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -223,6 +223,38 @@ punchy: # Overriding this to "" is the same as "/" and will allow overwriting any path on the host. #sandbox_dir: /var/tmp/nebula-debug +# EXPERIMENTAL: this feature may change or disappear in the future. +# host_query exposes a small local HTTP+JSON API that lets other programs on +# this machine resolve a vpn address to its certificate identity (name, vpn +# addresses, groups, fingerprint, validity), e.g. for making authorization +# decisions about an inbound connection: +# GET /v1/host?addr= - identity of the host owning the address: a +# peer with an active tunnel, or this node itself. `addr` may include a +# port (`192.168.100.7:54321`), which is ignored, so a connection's remote +# address can be passed through as is. Returns 404 when the address is +# unknown or has no active tunnel. +# GET /v1/self - this node's own identity. +# Identity answers can be trusted because nebula drops inbound packets whose +# source vpn address is not contained in the sender's certificate, so the +# source address of a connection arriving over the nebula interface is +# guaranteed to map to the certificate reported here. +# There is no authentication in this API; restrict access with unix socket +# file permissions or by listening on a loopback address. 
Do NOT use a +# non-loopback tcp address unless you intend to let everything that can reach +# it query host identities. +# This whole section is reloadable. +#host_query: + # Toggles the feature + #enabled: false + # listen accepts a unix socket path or a tcp host:port: + #listen: unix:///var/run/nebula-host-query.sock + #listen: 127.0.0.1:8085 + # File mode for the unix socket, as an octal string. Ignored for tcp. + # The socket is created by nebula's user; to grant a group of local services + # access, place the socket in a directory with appropriate permissions + # (e.g. a systemd RuntimeDirectory) and relax this to "0660". + #socket_mode: "0600" + # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: # Relays are a list of Nebula IP's that peers can use to relay packets to me. diff --git a/host_query_server.go b/host_query_server.go new file mode 100644 index 00000000..9211d3d9 --- /dev/null +++ b/host_query_server.go @@ -0,0 +1,459 @@ +package nebula + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "log/slog" + "net" + "net/http" + "net/netip" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" +) + +// hostQueryServer owns the local host query API: a small HTTP+JSON listener +// on a unix socket or tcp address that lets other programs on this machine +// resolve a vpn address to its certificate identity (name, groups, networks) +// for making authorization decisions. It mirrors the lifecycle shape of +// statsServer: constructor wires the reload callback, reload records config, +// Start builds and runs the runtime, Stop tears it down. +type hostQueryServer struct { + l *slog.Logger + ctx context.Context + hostMap *HostMap + pki *PKI + + // enabled mirrors `host_query.enabled`. Start consults it so callers + // don't need to know the gating rules. 
+ enabled atomic.Bool + + runMu sync.Mutex + runCfg *hostQueryConfig + run *hostQueryRuntime // non-nil while a runtime is live +} + +// hostQueryRuntime is the live state owned by a single Start invocation. +// Start stashes a pointer under runMu; Stop and Start's own exit path use +// pointer equality to tell "my runtime" apart from one that replaced it +// after a reload. +type hostQueryRuntime struct { + server *http.Server + listener net.Listener +} + +// hostQueryConfig is the snapshot of host_query config that drives the +// runtime. It is comparable with == so reload can detect "no change" cheaply. +type hostQueryConfig struct { + enabled bool + listen string // raw config value, for error messages + network string // "unix" or "tcp" + addr string // socket path or host:port + // socketMode is the file mode applied to the unix socket after bind. + socketMode fs.FileMode +} + +// newHostQueryServerFromConfig builds a hostQueryServer, applies the initial +// config, and registers a reload callback. The reload callback is registered +// before the initial config is applied so a SIGHUP can later enable, fix, or +// disable the listener even if the initial application failed. +// +// Construction never binds the listener; that happens in Start, so config +// tests are side effect free. Start is safe to call unconditionally: it +// no-ops when the host query API is disabled. The returned pointer is always +// non-nil, even on error. +func newHostQueryServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*hostQueryServer, error) { + h := &hostQueryServer{ + l: l, + ctx: ctx, + hostMap: hostMap, + pki: pki, + } + + c.RegisterReloadCallback(func(c *config.C) { + if err := h.reload(c, false); err != nil { + h.l.Error("Failed to reload host query API from config", "error", err) + } + }) + + if err := h.reload(c, true); err != nil { + return h, err + } + return h, nil +} + +// reload records the latest config. 
On the initial call it only records it; +// Control.Start is what launches the first runtime via hostQueryStart. On +// later calls it reconciles the running runtime with the new config: +// +// - newly enabled -> spawn Start +// - newly disabled -> Stop the runtime +// - config changed (still enabled) -> Stop the old, Start the new +// - no change -> no-op +func (h *hostQueryServer) reload(c *config.C, initial bool) error { + newCfg, err := loadHostQueryConfig(c) + if err != nil { + return err + } + + h.runMu.Lock() + sameCfg := h.runCfg != nil && *h.runCfg == newCfg + h.runCfg = &newCfg + running := h.run != nil + h.runMu.Unlock() + + h.enabled.Store(newCfg.enabled) + + if initial || sameCfg { + return nil + } + + if running { + h.Stop() + } + if newCfg.enabled { + go h.Start() + } + return nil +} + +// Start binds the listener from the latest config and serves until Stop is +// called or ctx fires. Safe to call when the host query API is disabled or +// already running (both no-op). +func (h *hostQueryServer) Start() { + if !h.enabled.Load() { + return + } + + h.runMu.Lock() + if h.ctx.Err() != nil || h.run != nil || h.runCfg == nil { + h.runMu.Unlock() + return + } + cfg := *h.runCfg + ln, err := h.listen(cfg) + if err != nil { + // Drop the cached config so a SIGHUP with the same config re-triggers + // Start once the user fixes the underlying problem. 
+ h.runCfg = nil + h.runMu.Unlock() + h.l.Error("Failed to start host query listener", "listen", cfg.listen, "error", err) + return + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/host", h.handleHost) + mux.HandleFunc("GET /v1/self", h.handleSelf) + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 5 * time.Second} + rt := &hostQueryRuntime{server: srv, listener: ln} + h.run = rt + h.runMu.Unlock() + + h.l.Info("Starting host query listener", "network", cfg.network, "addr", ln.Addr()) + cleanExit := h.serve(srv, ln) + + // A Stop that raced our bind shut the server down before Serve could + // adopt the listener; closing it again is harmless and guarantees a unix + // socket file gets unlinked. + _ = ln.Close() + + // Clear our runtime only if nothing has replaced it. Stop races through + // here too but leaves h.run == nil, so the pointer check skips. + h.runMu.Lock() + if h.run == rt { + h.run = nil + // A listener that exited with an error leaves runCfg cached as if it + // were applied. Drop it so a SIGHUP with the same config re-triggers + // Start once the user fixes the underlying problem. + if !cleanExit { + h.runCfg = nil + } + } + h.runMu.Unlock() +} + +// serve runs srv.Serve and ensures ctx cancellation unblocks it. Returns true +// if the listener exited cleanly (Stop, ctx cancellation, or any other +// http.ErrServerClosed path), false on an unexpected error. +func (h *hostQueryServer) serve(srv *http.Server, ln net.Listener) bool { + // Per-invocation watcher: ctx cancellation triggers a server shutdown + // which in turn unblocks Serve. Closing `done` on exit keeps the watcher + // from outliving this call. 
+ done := make(chan struct{}) + go func() { + select { + case <-h.ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + h.l.Warn("Failed to shut down host query listener", "error", err) + } + case <-done: + } + }() + defer close(done) + + err := srv.Serve(ln) + if err == nil || errors.Is(err, http.ErrServerClosed) { + return true + } + h.l.Error("Host query listener exited", "error", err) + return false +} + +// Stop tears down the active runtime, if any. Idempotent. +func (h *hostQueryServer) Stop() { + h.runMu.Lock() + rt := h.run + h.run = nil + h.runMu.Unlock() + if rt == nil { + return + } + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := rt.server.Shutdown(shutdownCtx); err != nil { + h.l.Warn("Failed to shut down host query listener", "error", err) + } +} + +// listen binds the configured address. For unix sockets it also clears a +// stale socket file left by an unclean exit and applies the configured file +// mode. 
+func (h *hostQueryServer) listen(cfg hostQueryConfig) (net.Listener, error) { + if cfg.network == "unix" { + return h.listenUnix(cfg) + } + + if host, _, err := net.SplitHostPort(cfg.addr); err == nil { + ip, ipErr := netip.ParseAddr(host) + if host == "" || (ipErr == nil && !ip.IsLoopback()) { + h.l.Warn("host_query is listening on a non-loopback tcp address; anything that can reach it can query host identities", "addr", cfg.addr) + } + } + return net.Listen("tcp", cfg.addr) +} + +func (h *hostQueryServer) listenUnix(cfg hostQueryConfig) (net.Listener, error) { + if fi, err := os.Stat(cfg.addr); err == nil { + if fi.Mode()&os.ModeSocket == 0 { + return nil, fmt.Errorf("host_query.listen path %s exists and is not a socket, refusing to replace it", cfg.addr) + } + // A normal shutdown unlinks the socket (unlink-on-close), so a file + // here means a previous process exited uncleanly. Remove it so the + // bind below can succeed. + if err = os.Remove(cfg.addr); err != nil { + return nil, fmt.Errorf("failed to remove stale socket %s: %w", cfg.addr, err) + } + } + + ln, err := net.Listen("unix", cfg.addr) + if err != nil { + return nil, err + } + // The socket is briefly live with umask-derived permissions before this + // chmod lands; tolerated because connections accepted in that window + // still only reach this read-only API. + if err = os.Chmod(cfg.addr, cfg.socketMode); err != nil { + _ = ln.Close() + return nil, fmt.Errorf("failed to set mode on socket %s: %w", cfg.addr, err) + } + return ln, nil +} + +func (h *hostQueryServer) certState() *CertState { + if h.pki == nil { + return nil + } + return h.pki.getCertState() +} + +// handleHost serves GET /v1/host?addr=, answering with the identity +// of the host that owns the address: a peer with an active tunnel, or this +// node itself. addr may include a port, which is ignored, so clients can pass +// a connection's remote address through without parsing it. 
+func (h *hostQueryServer) handleHost(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query().Get("addr") + if q == "" { + writeJSONError(w, http.StatusBadRequest, "missing addr parameter") + return + } + ip, err := parseQueryAddrParam(q) + if err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid address") + return + } + + crt := findCertificateForVpnAddr(h.certState(), h.hostMap, ip) + if crt == nil { + writeJSONError(w, http.StatusNotFound, "no active tunnel for address") + return + } + h.writeHostIdentity(w, crt) +} + +// handleSelf serves GET /v1/self, answering with this node's own identity. +func (h *hostQueryServer) handleSelf(w http.ResponseWriter, r *http.Request) { + var crt cert.Certificate + if cs := h.certState(); cs != nil { + crt = cs.getCertificate(cs.initiatingVersion) + } + if crt == nil { + writeJSONError(w, http.StatusInternalServerError, "no certificate available") + return + } + h.writeHostIdentity(w, crt) +} + +func (h *hostQueryServer) writeHostIdentity(w http.ResponseWriter, crt cert.Certificate) { + id, err := newHostIdentity(crt) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "failed to fingerprint certificate") + return + } + w.Header().Set("Content-Type", "application/json") + if err = json.NewEncoder(w).Encode(id); err != nil { + h.l.Debug("Failed to write host query response", "error", err) + } +} + +// findCertificateForVpnAddr answers "who owns this vpn address": ourselves +// (from local cert state, since the hostmap never carries an entry for this +// node) or a peer with an active tunnel. Returns nil when the address is +// unknown or the tunnel is mid-teardown. 
+func findCertificateForVpnAddr(cs *CertState, hostMap *HostMap, ip netip.Addr) cert.Certificate { + if cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { + return cs.getCertificate(cs.initiatingVersion) + } + + hostinfo := hostMap.QueryVpnAddr(ip) + if hostinfo == nil { + return nil + } + cc := hostinfo.GetCert() + if cc == nil { + return nil + } + return cc.Certificate +} + +// hostIdentity is the JSON document served for both /v1/host and /v1/self. +// Every field is derived from the authenticated certificate alone. +type hostIdentity struct { + Name string `json:"name"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` + Networks []netip.Prefix `json:"networks"` + UnsafeNetworks []netip.Prefix `json:"unsafeNetworks"` + Groups []string `json:"groups"` + Fingerprint string `json:"fingerprint"` + Issuer string `json:"issuer"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + CertVersion int `json:"certVersion"` +} + +func newHostIdentity(crt cert.Certificate) (hostIdentity, error) { + fp, err := crt.Fingerprint() + if err != nil { + return hostIdentity{}, err + } + + // Slices are always allocated so they marshal as [] rather than null; + // consumers iterate groups without a presence check. 
+ networks := crt.Networks() + id := hostIdentity{ + Name: crt.Name(), + VpnAddrs: make([]netip.Addr, 0, len(networks)), + Networks: append(make([]netip.Prefix, 0, len(networks)), networks...), + UnsafeNetworks: append(make([]netip.Prefix, 0, len(crt.UnsafeNetworks())), crt.UnsafeNetworks()...), + Groups: append(make([]string, 0, len(crt.Groups())), crt.Groups()...), + Fingerprint: fp, + Issuer: crt.Issuer(), + NotBefore: crt.NotBefore(), + NotAfter: crt.NotAfter(), + CertVersion: int(crt.Version()), + } + for _, n := range networks { + id.VpnAddrs = append(id.VpnAddrs, n.Addr()) + } + return id, nil +} + +func writeJSONError(w http.ResponseWriter, status int, msg string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} + +// parseQueryAddrParam parses the addr query parameter, accepting a bare +// address or an address with a port (`192.168.100.7:54321`, `[fd00::1]:443`) +// so callers can pass a connection's RemoteAddr straight through. The result +// is unmapped: 4in6 addresses (::ffff:a.b.c.d) are normalized to ipv4. +func parseQueryAddrParam(s string) (netip.Addr, error) { + if ip, err := netip.ParseAddr(s); err == nil { + return ip.Unmap(), nil + } + ap, err := netip.ParseAddrPort(s) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} + +func loadHostQueryConfig(c *config.C) (hostQueryConfig, error) { + cfg := hostQueryConfig{ + enabled: c.GetBool("host_query.enabled", false), + listen: c.GetString("host_query.listen", ""), + } + if !cfg.enabled { + return cfg, nil + } + + if cfg.listen == "" { + return cfg, errors.New("host_query.listen can not be empty when host_query is enabled") + } + network, addr, err := parseHostQueryListen(cfg.listen) + if err != nil { + return cfg, err + } + cfg.network = network + cfg.addr = addr + + if network == "unix" { + // Read as a string so YAML can't reinterpret the octal literal. 
+ modeStr := c.GetString("host_query.socket_mode", "0600") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil || fs.FileMode(mode)&^fs.ModePerm != 0 { + return cfg, fmt.Errorf("host_query.socket_mode was not a valid octal file mode: %s", modeStr) + } + cfg.socketMode = fs.FileMode(mode) + } + return cfg, nil +} + +// parseHostQueryListen splits the host_query.listen config value into a +// network and address for net.Listen: `unix:///abs/path.sock` selects a unix +// socket, anything else must be a tcp host:port. +func parseHostQueryListen(listen string) (network string, addr string, err error) { + if path, ok := strings.CutPrefix(listen, "unix://"); ok { + if !filepath.IsAbs(path) { + return "", "", fmt.Errorf("host_query.listen unix socket path must be absolute: %s", listen) + } + return "unix", path, nil + } + + if _, _, err = net.SplitHostPort(listen); err != nil { + return "", "", fmt.Errorf("host_query.listen must be a unix:// socket path or a host:port address: %s", listen) + } + return "tcp", listen, nil +} diff --git a/host_query_server_test.go b/host_query_server_test.go new file mode 100644 index 00000000..e4323950 --- /dev/null +++ b/host_query_server_test.go @@ -0,0 +1,450 @@ +package nebula + +import ( + "context" + "encoding/json" + "io/fs" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_parseHostQueryListen(t *testing.T) { + tests := []struct { + listen string + network string + addr string + wantErr bool + }{ + {listen: "unix:///var/run/nebula.sock", network: "unix", addr: "/var/run/nebula.sock"}, + {listen: "127.0.0.1:8085", network: "tcp", addr: "127.0.0.1:8085"}, + {listen: "[::1]:8085", network: "tcp", addr: "[::1]:8085"}, + {listen: 
"localhost:8085", network: "tcp", addr: "localhost:8085"}, + {listen: "", wantErr: true}, + {listen: "unix://", wantErr: true}, + {listen: "unix://relative/path.sock", wantErr: true}, + {listen: "not an address", wantErr: true}, + {listen: "127.0.0.1", wantErr: true}, + } + + for _, tt := range tests { + network, addr, err := parseHostQueryListen(tt.listen) + if tt.wantErr { + require.Error(t, err, "listen=%q", tt.listen) + continue + } + require.NoError(t, err, "listen=%q", tt.listen) + assert.Equal(t, tt.network, network, "listen=%q", tt.listen) + assert.Equal(t, tt.addr, addr, "listen=%q", tt.listen) + } +} + +func Test_loadHostQueryConfig(t *testing.T) { + c := config.NewC(nil) + + // Absent section: disabled, no error. + cfg, err := loadHostQueryConfig(c) + require.NoError(t, err) + assert.False(t, cfg.enabled) + + // Enabled without a listen address is an error. + setHostQueryConfig(c, true, "", "") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + // Unix socket gets the default mode. + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, "unix", cfg.network) + assert.Equal(t, "/tmp/hq.sock", cfg.addr) + assert.Equal(t, fs.FileMode(0o600), cfg.socketMode) + + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "0660") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, fs.FileMode(0o660), cfg.socketMode) + + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "withers") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + // Mode bits beyond the permission bits are rejected. 
+ setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "10600") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + setHostQueryConfig(c, true, "127.0.0.1:8085", "") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, "tcp", cfg.network) + assert.Equal(t, "127.0.0.1:8085", cfg.addr) +} + +func setHostQueryConfig(c *config.C, enabled bool, listen, socketMode string) { + settings := map[string]any{ + "enabled": enabled, + "listen": listen, + } + if socketMode != "" { + settings["socket_mode"] = socketMode + } + c.Settings["host_query"] = settings +} + +func newTestHostQueryServer(t *testing.T) (*hostQueryServer, *config.C) { + t.Helper() + h := &hostQueryServer{ + l: slog.New(slog.DiscardHandler), + ctx: context.Background(), + hostMap: newHostMap(slog.New(slog.DiscardHandler)), + } + h.hostMap.preferredRanges.Store(&[]netip.Prefix{}) + return h, config.NewC(nil) +} + +// addTestPeer creates a certificate for a peer owning each addr (as a /24 or +// /64) and inserts it into the hostmap as an established tunnel. 
+func addTestPeer(t *testing.T, hm *HostMap, name string, addrs []netip.Addr, unsafeNetworks []netip.Prefix, groups []string) cert.Certificate { + t.Helper() + networks := make([]netip.Prefix, 0, len(addrs)) + for _, a := range addrs { + bits := 24 + if a.Is6() { + bits = 64 + } + networks = append(networks, netip.PrefixFrom(a, bits)) + } + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + crt, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, unsafeNetworks, groups) + fp, err := crt.Fingerprint() + require.NoError(t, err) + + hm.unlockedAddHostInfo(&HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{Certificate: crt, Fingerprint: fp}, + }, + vpnAddrs: addrs, + relayState: RelayState{ + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + }, &Interface{}) + return crt +} + +func getHost(t *testing.T, h *hostQueryServer, addrParam string) (int, map[string]any) { + t.Helper() + r := httptest.NewRequest(http.MethodGet, "/v1/host?addr="+url.QueryEscape(addrParam), nil) + w := httptest.NewRecorder() + h.handleHost(w, r) + return decodeResponse(t, w) +} + +func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) (int, map[string]any) { + t.Helper() + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + return w.Code, body +} + +func TestHostQueryServer_handleHost(t *testing.T) { + h, _ := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + peerV4 := netip.MustParseAddr("10.0.0.99") + peerV6 := netip.MustParseAddr("fd00::99") + addTestPeer(t, h.hostMap, "laptop-alice", []netip.Addr{peerV4, peerV6}, + []netip.Prefix{netip.MustParsePrefix("192.168.50.0/24")}, []string{"eng", "ssh"}) + addTestPeer(t, 
h.hostMap, "groupless", []netip.Addr{netip.MustParseAddr("10.0.0.77")}, nil, nil) + + // An established peer comes back with its full identity. + code, body := getHost(t, h, "10.0.0.99") + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "laptop-alice", body["name"]) + assert.Equal(t, []any{"10.0.0.99", "fd00::99"}, body["vpnAddrs"]) + assert.Equal(t, []any{"10.0.0.99/24", "fd00::99/64"}, body["networks"]) + assert.Equal(t, []any{"192.168.50.0/24"}, body["unsafeNetworks"]) + assert.Equal(t, []any{"eng", "ssh"}, body["groups"]) + assert.NotEmpty(t, body["fingerprint"]) + assert.Equal(t, float64(2), body["certVersion"]) + assert.NotEmpty(t, body["notBefore"]) + assert.NotEmpty(t, body["notAfter"]) + + // Empty cert slices marshal as [] rather than null. + code, body = getHost(t, h, "10.0.0.77") + require.Equal(t, http.StatusOK, code) + require.NotNil(t, body["groups"]) + assert.Empty(t, body["groups"]) + require.NotNil(t, body["unsafeNetworks"]) + assert.Empty(t, body["unsafeNetworks"]) + + // A port in addr is ignored so RemoteAddr can be passed through directly, + // including the bracketed v6 and 4in6 forms. + for _, q := range []string{"10.0.0.99:54321", "[fd00::99]:443", "::ffff:10.0.0.99"} { + code, body = getHost(t, h, q) + require.Equal(t, http.StatusOK, code, "addr=%q", q) + assert.Equal(t, "laptop-alice", body["name"], "addr=%q", q) + } + + // Our own address answers from the local cert state. + code, body = getHost(t, h, "10.0.0.1") + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "self", body["name"]) + + code, body = getHost(t, h, "10.0.0.42") + assert.Equal(t, http.StatusNotFound, code) + assert.NotEmpty(t, body["error"]) + + // A tunnel mid-teardown (no peer cert) is treated as unknown. 
+ h.hostMap.unlockedAddHostInfo(&HostInfo{ + ConnectionState: &ConnectionState{}, + vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.66")}, + relayState: RelayState{ + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + }, &Interface{}) + code, _ = getHost(t, h, "10.0.0.66") + assert.Equal(t, http.StatusNotFound, code) + + code, body = getHost(t, h, "not-an-address") + assert.Equal(t, http.StatusBadRequest, code) + assert.NotEmpty(t, body["error"]) + + r := httptest.NewRequest(http.MethodGet, "/v1/host", nil) + w := httptest.NewRecorder() + h.handleHost(w, r) + code, body = decodeResponse(t, w) + assert.Equal(t, http.StatusBadRequest, code) + assert.NotEmpty(t, body["error"]) +} + +func TestHostQueryServer_handleSelf(t *testing.T) { + h, _ := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + r := httptest.NewRequest(http.MethodGet, "/v1/self", nil) + w := httptest.NewRecorder() + h.handleSelf(w, r) + code, body := decodeResponse(t, w) + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "lighthouse", body["name"]) + assert.Equal(t, []any{"10.0.0.1"}, body["vpnAddrs"]) + + // No cert state available should be an error, not a panic. + h.pki = nil + w = httptest.NewRecorder() + h.handleSelf(w, r) + code, body = decodeResponse(t, w) + assert.Equal(t, http.StatusInternalServerError, code) + assert.NotEmpty(t, body["error"]) +} + +func unixHTTPClient(path string) *http.Client { + return &http.Client{ + Timeout: time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "unix", path) + }, + }, + } +} + +// waitForServe polls until a GET /v1/self through client succeeds. 
+// waitForServe hands waitFor a readiness probe that issues
+// GET http://hostquery/v1/self through client and reports true only when the
+// request succeeds with 200 OK. Request errors and any other status count as
+// "not ready". waitFor is defined elsewhere in this file; presumably it
+// retries the probe until it returns true or fails the test - confirm there.
+func waitForServe(t *testing.T, client *http.Client) {
+	t.Helper()
+	waitFor(t, func() bool {
+		resp, err := client.Get("http://hostquery/v1/self")
+		if err != nil {
+			return false
+		}
+		// Only the status code matters; the body is closed unread.
+		resp.Body.Close()
+		return resp.StatusCode == http.StatusOK
+	})
+}
+
+// skipIfNoUnixSockets skips the calling test when runtime.GOOS is "windows".
+func skipIfNoUnixSockets(t *testing.T) {
+	t.Helper()
+	if runtime.GOOS == "windows" {
+		t.Skip("unix socket tests are not supported on windows CI")
+	}
+}
+
+// TestHostQueryServer_unixLifecycle serves on a unix socket inside a temp
+// directory and checks that the socket file has permission bits 0600, that a
+// /v1/host lookup for 10.0.0.1 answers 200 OK, that Start returns within five
+// seconds of Stop, and that the socket file no longer exists afterwards.
+func TestHostQueryServer_unixLifecycle(t *testing.T) {
+	skipIfNoUnixSockets(t)
+	h, c := newTestHostQueryServer(t)
+	h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
+
+	sock := filepath.Join(t.TempDir(), "hq.sock")
+	setHostQueryConfig(c, true, "unix://"+sock, "")
+	require.NoError(t, h.reload(c, true))
+
+	// Start runs in its own goroutine; done is closed once it returns so the
+	// shutdown check below can wait on it with a timeout.
+	done := make(chan struct{})
+	go func() {
+		h.Start()
+		close(done)
+	}()
+
+	client := unixHTTPClient(sock)
+	waitForServe(t, client)
+
+	fi, err := os.Stat(sock)
+	require.NoError(t, err)
+	assert.Equal(t, fs.FileMode(0o600), fi.Mode().Perm())
+
+	resp, err := client.Get("http://hostquery/v1/host?addr=10.0.0.1")
+	require.NoError(t, err)
+	resp.Body.Close()
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+
+	h.Stop()
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Start did not return after Stop")
+	}
+	_, err = os.Stat(sock)
+	assert.True(t, os.IsNotExist(err), "socket file should be unlinked on shutdown")
+}
+
+// TestHostQueryServer_tcpLifecycle serves on 127.0.0.1 with port 0, reads the
+// bound address back from the running listener, and checks that /v1/self
+// answers 200 OK with a JSON body whose "name" is "self" (the name given to
+// newTestPKI). It then checks that Start returns within five seconds of Stop.
+func TestHostQueryServer_tcpLifecycle(t *testing.T) {
+	h, c := newTestHostQueryServer(t)
+	h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
+
+	setHostQueryConfig(c, true, "127.0.0.1:0", "")
+	require.NoError(t, h.reload(c, true))
+
+	done := make(chan struct{})
+	go func() {
+		h.Start()
+		close(done)
+	}()
+
+	// The configured port is 0, so the real address is only known once the
+	// runtime exists; h.run is read while holding runMu.
+	var addr string
+	waitFor(t, func() bool {
+		h.runMu.Lock()
+		defer h.runMu.Unlock()
+		if h.run == nil {
+			return false
+		}
+		addr = h.run.listener.Addr().String()
+		return true
+	})
+
+	client := &http.Client{Timeout: time.Second}
+	resp, err := client.Get("http://" + addr + "/v1/self")
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+	var body map[string]any
+	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
+	assert.Equal(t, "self", body["name"])
+
+	h.Stop()
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Fatal("Start did not return after Stop")
+	}
+}
+
+// TestHostQueryServer_staleSocket checks that h.listen succeeds when the
+// configured unix path already holds a socket file that nothing is listening
+// on.
+func TestHostQueryServer_staleSocket(t *testing.T) {
+	skipIfNoUnixSockets(t)
+	h, _ := newTestHostQueryServer(t)
+	sock := filepath.Join(t.TempDir(), "hq.sock")
+
+	// Simulate an unclean exit: a leftover socket file with no listener.
+	stale, err := net.ListenUnix("unix", &net.UnixAddr{Name: sock, Net: "unix"})
+	require.NoError(t, err)
+	// Disable unlink-on-close so the socket file survives Close; the Stat
+	// below confirms the file is still there.
+	stale.SetUnlinkOnClose(false)
+	require.NoError(t, stale.Close())
+	_, err = os.Stat(sock)
+	require.NoError(t, err, "stale socket file should exist")
+
+	cfg := hostQueryConfig{network: "unix", addr: sock, socketMode: 0o600}
+	ln, err := h.listen(cfg)
+	require.NoError(t, err, "a stale socket should be removed and rebound")
+	require.NoError(t, ln.Close())
+}
+
+// TestHostQueryServer_existingFileNotReplaced checks that h.listen returns an
+// error when the configured unix path holds a regular file, and that the
+// file's contents are unchanged afterwards.
+func TestHostQueryServer_existingFileNotReplaced(t *testing.T) {
+	skipIfNoUnixSockets(t)
+	h, _ := newTestHostQueryServer(t)
+	path := filepath.Join(t.TempDir(), "hq.sock")
+	require.NoError(t, os.WriteFile(path, []byte("precious"), 0o600))
+
+	cfg := hostQueryConfig{network: "unix", addr: path, socketMode: 0o600}
+	_, err := h.listen(cfg)
+	require.Error(t, err, "a non-socket file at the listen path must not be replaced")
+
+	content, err := os.ReadFile(path)
+	require.NoError(t, err)
+	assert.Equal(t, "precious", string(content))
+}
+
+// TestHostQueryServer_reload walks the server through a sequence of config
+// reloads: disabled, enabled on a first socket, moved to a second socket,
+// reloaded with no change, and disabled again. The first call passes true as
+// reload's second argument and every later call passes false; presumably that
+// flag marks the initial load - confirm against reload's signature.
+func TestHostQueryServer_reload(t *testing.T) {
+	skipIfNoUnixSockets(t)
+	h, c := newTestHostQueryServer(t)
+	h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")})
+	dir := t.TempDir()
+	sock1 := filepath.Join(dir, "hq1.sock")
+	sock2 := filepath.Join(dir, "hq2.sock")
+
+	// Initial reload only records config; Control.Start launches the runtime.
+	setHostQueryConfig(c, false, "unix://"+sock1, "")
+	require.NoError(t, h.reload(c, true))
+	assert.False(t, h.enabled.Load())
+	h.runMu.Lock()
+	assert.Nil(t, h.run)
+	h.runMu.Unlock()
+
+	// Enabling via reload spawns the listener.
+	setHostQueryConfig(c, true, "unix://"+sock1, "")
+	require.NoError(t, h.reload(c, false))
+	waitForServe(t, unixHTTPClient(sock1))
+
+	// Changing the listen path restarts on the new address.
+	setHostQueryConfig(c, true, "unix://"+sock2, "")
+	require.NoError(t, h.reload(c, false))
+	waitForServe(t, unixHTTPClient(sock2))
+	// The old socket file is expected to disappear once the move completes.
+	waitFor(t, func() bool {
+		_, err := os.Stat(sock1)
+		return os.IsNotExist(err)
+	})
+
+	// Reloading an unchanged config does not restart the runtime.
+	// The h.run pointer is captured before the reload and compared for
+	// identity afterwards, both times under runMu.
+	h.runMu.Lock()
+	rt := h.run
+	h.runMu.Unlock()
+	require.NoError(t, h.reload(c, false))
+	h.runMu.Lock()
+	assert.Same(t, rt, h.run)
+	h.runMu.Unlock()
+
+	// Disabling stops the listener.
+	setHostQueryConfig(c, false, "unix://"+sock2, "")
+	require.NoError(t, h.reload(c, false))
+	assert.False(t, h.enabled.Load())
+	waitFor(t, func() bool {
+		h.runMu.Lock()
+		defer h.runMu.Unlock()
+		return h.run == nil
+	})
+}
diff --git a/main.go b/main.go
index 7d7a0f72..ae6288b4 100644
--- a/main.go
+++ b/main.go
@@ -249,6 +249,11 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
 		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
 	}
 
+	hostQuery, err := newHostQueryServerFromConfig(ctx, l, pki, hostMap, c)
+	if err != nil {
+		return nil, util.ContextualizeIfNeeded("Failed to configure the host query API", err)
+	}
+
 	if configTest {
 		return nil, nil
 	}
@@ -266,6 +271,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev
 		sshStart:               sshStart,
 		statsStart:             stats.Start,
 		dnsStart:               ds.Start,
+		hostQueryStart:         hostQuery.Start,
 		lighthouseStart:        lightHouse.StartUpdateWorker,
 		connectionManagerStart: connManager.Start,
 	}, nil