diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ef7551f..a4dce9b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- New experimental `host_query` config section exposing a local HTTP+JSON API + (`GET /v1/host?addr=...`, `GET /v1/self`) over a unix socket or loopback tcp + listener, so programs on the same host can resolve a vpn address to its + certificate identity (name, groups, networks) for authorization decisions. + Disabled by default. + ## [1.10.3] - 2026-02-06 ### Security diff --git a/control.go b/control.go index ef58988b..94b620d6 100644 --- a/control.go +++ b/control.go @@ -52,6 +52,7 @@ type Control struct { sshStart func() statsStart func() dnsStart func() + hostQueryStart func() lighthouseStart func() connectionManagerStart func(context.Context) } @@ -104,6 +105,9 @@ func (c *Control) Start() (func() error, error) { if c.dnsStart != nil { go c.dnsStart() } + if c.hostQueryStart != nil { + go c.hostQueryStart() + } if c.connectionManagerStart != nil { go c.connectionManagerStart(c.ctx) } diff --git a/dns_server.go b/dns_server.go index a80630b5..06c4c6fc 100644 --- a/dns_server.go +++ b/dns_server.go @@ -249,31 +249,12 @@ func (d *dnsServer) QueryCert(data string) string { return "" } - // The hostmap only ever contains peers we have handshaked with, so it never carries an entry for ourselves. - // Answer self lookups straight from the local cert state. - if cs := d.certState(); cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { - c := cs.GetDefaultCertificate() - if c == nil { - return "" - } - b, err := c.MarshalJSON() - if err != nil { - return "" - } - return string(b) - } - - hostinfo := d.hostMap.QueryVpnAddr(ip) - if hostinfo == nil { + crt := findCertificateForVpnAddr(d.certState(), d.hostMap, ip) + if crt == nil { return "" } - q := hostinfo.GetCert() - if q == nil { - return "" - } - - b, err := q.Certificate.MarshalJSON() + b, err := crt.MarshalJSON() if err != nil { return "" } diff --git a/examples/config.yml b/examples/config.yml index 4f7fd1e7..5bd65221 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -223,6 +223,38 @@ punchy: # Overriding this to "" is the same as "/" and will allow overwriting any path on the host. #sandbox_dir: /var/tmp/nebula-debug +# EXPERIMENTAL: this feature may change or disappear in the future. +# host_query exposes a small local HTTP+JSON API that lets other programs on +# this machine resolve a vpn address to its certificate identity (name, vpn +# addresses, groups, fingerprint, validity), e.g. for making authorization +# decisions about an inbound connection: +# GET /v1/host?addr= - identity of the host owning the address: a +# peer with an active tunnel, or this node itself. `addr` may include a +# port (`192.168.100.7:54321`), which is ignored, so a connection's remote +# address can be passed through as is. Returns 404 when the address is +# unknown or has no active tunnel. +# GET /v1/self - this node's own identity. +# Identity answers can be trusted because nebula drops inbound packets whose +# source vpn address is not contained in the sender's certificate, so the +# source address of a connection arriving over the nebula interface is +# guaranteed to map to the certificate reported here. +# There is no authentication in this API; restrict access with unix socket +# file permissions or by listening on a loopback address. Do NOT use a +# non-loopback tcp address unless you intend to let everything that can reach +# it query host identities. +# This whole section is reloadable. +#host_query: + # Toggles the feature + #enabled: false + # listen accepts a unix socket path or a tcp host:port: + #listen: unix:///var/run/nebula-host-query.sock + #listen: 127.0.0.1:8085 + # File mode for the unix socket, as an octal string. Ignored for tcp. + # The socket is created by nebula's user; to grant a group of local services + # access, place the socket in a directory with appropriate permissions + # (e.g. a systemd RuntimeDirectory) and relax this to "0660". + #socket_mode: "0600" + # EXPERIMENTAL: relay support for networks that can't establish direct connections. relay: # Relays are a list of Nebula IP's that peers can use to relay packets to me. diff --git a/host_query_server.go b/host_query_server.go new file mode 100644 index 00000000..9211d3d9 --- /dev/null +++ b/host_query_server.go @@ -0,0 +1,459 @@ +package nebula + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "log/slog" + "net" + "net/http" + "net/netip" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" +) + +// hostQueryServer owns the local host query API: a small HTTP+JSON listener +// on a unix socket or tcp address that lets other programs on this machine +// resolve a vpn address to its certificate identity (name, groups, networks) +// for making authorization decisions. It mirrors the lifecycle shape of +// statsServer: constructor wires the reload callback, reload records config, +// Start builds and runs the runtime, Stop tears it down. +type hostQueryServer struct { + l *slog.Logger + ctx context.Context + hostMap *HostMap + pki *PKI + + // enabled mirrors `host_query.enabled`. Start consults it so callers + // don't need to know the gating rules. + enabled atomic.Bool + + runMu sync.Mutex + runCfg *hostQueryConfig + run *hostQueryRuntime // non-nil while a runtime is live +} + +// hostQueryRuntime is the live state owned by a single Start invocation. +// Start stashes a pointer under runMu; Stop and Start's own exit path use +// pointer equality to tell "my runtime" apart from one that replaced it +// after a reload. +type hostQueryRuntime struct { + server *http.Server + listener net.Listener +} + +// hostQueryConfig is the snapshot of host_query config that drives the +// runtime. It is comparable with == so reload can detect "no change" cheaply. +type hostQueryConfig struct { + enabled bool + listen string // raw config value, for error messages + network string // "unix" or "tcp" + addr string // socket path or host:port + // socketMode is the file mode applied to the unix socket after bind. + socketMode fs.FileMode +} + +// newHostQueryServerFromConfig builds a hostQueryServer, applies the initial +// config, and registers a reload callback. The reload callback is registered +// before the initial config is applied so a SIGHUP can later enable, fix, or +// disable the listener even if the initial application failed. +// +// Construction never binds the listener; that happens in Start, so config +// tests are side effect free. Start is safe to call unconditionally: it +// no-ops when the host query API is disabled. The returned pointer is always +// non-nil, even on error. +func newHostQueryServerFromConfig(ctx context.Context, l *slog.Logger, pki *PKI, hostMap *HostMap, c *config.C) (*hostQueryServer, error) { + h := &hostQueryServer{ + l: l, + ctx: ctx, + hostMap: hostMap, + pki: pki, + } + + c.RegisterReloadCallback(func(c *config.C) { + if err := h.reload(c, false); err != nil { + h.l.Error("Failed to reload host query API from config", "error", err) + } + }) + + if err := h.reload(c, true); err != nil { + return h, err + } + return h, nil +} + +// reload records the latest config. On the initial call it only records it; +// Control.Start is what launches the first runtime via hostQueryStart. On +// later calls it reconciles the running runtime with the new config: +// +// - newly enabled -> spawn Start +// - newly disabled -> Stop the runtime +// - config changed (still enabled) -> Stop the old, Start the new +// - no change -> no-op +func (h *hostQueryServer) reload(c *config.C, initial bool) error { + newCfg, err := loadHostQueryConfig(c) + if err != nil { + return err + } + + h.runMu.Lock() + sameCfg := h.runCfg != nil && *h.runCfg == newCfg + h.runCfg = &newCfg + running := h.run != nil + h.runMu.Unlock() + + h.enabled.Store(newCfg.enabled) + + if initial || sameCfg { + return nil + } + + if running { + h.Stop() + } + if newCfg.enabled { + go h.Start() + } + return nil +} + +// Start binds the listener from the latest config and serves until Stop is +// called or ctx fires. Safe to call when the host query API is disabled or +// already running (both no-op). +func (h *hostQueryServer) Start() { + if !h.enabled.Load() { + return + } + + h.runMu.Lock() + if h.ctx.Err() != nil || h.run != nil || h.runCfg == nil { + h.runMu.Unlock() + return + } + cfg := *h.runCfg + ln, err := h.listen(cfg) + if err != nil { + // Drop the cached config so a SIGHUP with the same config re-triggers + // Start once the user fixes the underlying problem. + h.runCfg = nil + h.runMu.Unlock() + h.l.Error("Failed to start host query listener", "listen", cfg.listen, "error", err) + return + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /v1/host", h.handleHost) + mux.HandleFunc("GET /v1/self", h.handleSelf) + srv := &http.Server{Handler: mux, ReadHeaderTimeout: 5 * time.Second} + rt := &hostQueryRuntime{server: srv, listener: ln} + h.run = rt + h.runMu.Unlock() + + h.l.Info("Starting host query listener", "network", cfg.network, "addr", ln.Addr()) + cleanExit := h.serve(srv, ln) + + // A Stop that raced our bind shut the server down before Serve could + // adopt the listener; closing it again is harmless and guarantees a unix + // socket file gets unlinked. + _ = ln.Close() + + // Clear our runtime only if nothing has replaced it. Stop races through + // here too but leaves h.run == nil, so the pointer check skips. + h.runMu.Lock() + if h.run == rt { + h.run = nil + // A listener that exited with an error leaves runCfg cached as if it + // were applied. Drop it so a SIGHUP with the same config re-triggers + // Start once the user fixes the underlying problem. + if !cleanExit { + h.runCfg = nil + } + } + h.runMu.Unlock() +} + +// serve runs srv.Serve and ensures ctx cancellation unblocks it. Returns true +// if the listener exited cleanly (Stop, ctx cancellation, or any other +// http.ErrServerClosed path), false on an unexpected error. +func (h *hostQueryServer) serve(srv *http.Server, ln net.Listener) bool { + // Per-invocation watcher: ctx cancellation triggers a server shutdown + // which in turn unblocks Serve. Closing `done` on exit keeps the watcher + // from outliving this call. + done := make(chan struct{}) + go func() { + select { + case <-h.ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + h.l.Warn("Failed to shut down host query listener", "error", err) + } + case <-done: + } + }() + defer close(done) + + err := srv.Serve(ln) + if err == nil || errors.Is(err, http.ErrServerClosed) { + return true + } + h.l.Error("Host query listener exited", "error", err) + return false +} + +// Stop tears down the active runtime, if any. Idempotent. +func (h *hostQueryServer) Stop() { + h.runMu.Lock() + rt := h.run + h.run = nil + h.runMu.Unlock() + if rt == nil { + return + } + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := rt.server.Shutdown(shutdownCtx); err != nil { + h.l.Warn("Failed to shut down host query listener", "error", err) + } +} + +// listen binds the configured address. For unix sockets it also clears a +// stale socket file left by an unclean exit and applies the configured file +// mode. +func (h *hostQueryServer) listen(cfg hostQueryConfig) (net.Listener, error) { + if cfg.network == "unix" { + return h.listenUnix(cfg) + } + + if host, _, err := net.SplitHostPort(cfg.addr); err == nil { + ip, ipErr := netip.ParseAddr(host) + if host == "" || (ipErr == nil && !ip.IsLoopback()) { + h.l.Warn("host_query is listening on a non-loopback tcp address; anything that can reach it can query host identities", "addr", cfg.addr) + } + } + return net.Listen("tcp", cfg.addr) +} + +func (h *hostQueryServer) listenUnix(cfg hostQueryConfig) (net.Listener, error) { + if fi, err := os.Stat(cfg.addr); err == nil { + if fi.Mode()&os.ModeSocket == 0 { + return nil, fmt.Errorf("host_query.listen path %s exists and is not a socket, refusing to replace it", cfg.addr) + } + // A normal shutdown unlinks the socket (unlink-on-close), so a file + // here means a previous process exited uncleanly. Remove it so the + // bind below can succeed. + if err = os.Remove(cfg.addr); err != nil { + return nil, fmt.Errorf("failed to remove stale socket %s: %w", cfg.addr, err) + } + } + + ln, err := net.Listen("unix", cfg.addr) + if err != nil { + return nil, err + } + // The socket is briefly live with umask-derived permissions before this + // chmod lands; tolerated because connections accepted in that window + // still only reach this read-only API. + if err = os.Chmod(cfg.addr, cfg.socketMode); err != nil { + _ = ln.Close() + return nil, fmt.Errorf("failed to set mode on socket %s: %w", cfg.addr, err) + } + return ln, nil +} + +func (h *hostQueryServer) certState() *CertState { + if h.pki == nil { + return nil + } + return h.pki.getCertState() +} + +// handleHost serves GET /v1/host?addr=, answering with the identity +// of the host that owns the address: a peer with an active tunnel, or this +// node itself. addr may include a port, which is ignored, so clients can pass +// a connection's remote address through without parsing it. +func (h *hostQueryServer) handleHost(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query().Get("addr") + if q == "" { + writeJSONError(w, http.StatusBadRequest, "missing addr parameter") + return + } + ip, err := parseQueryAddrParam(q) + if err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid address") + return + } + + crt := findCertificateForVpnAddr(h.certState(), h.hostMap, ip) + if crt == nil { + writeJSONError(w, http.StatusNotFound, "no active tunnel for address") + return + } + h.writeHostIdentity(w, crt) +} + +// handleSelf serves GET /v1/self, answering with this node's own identity. +func (h *hostQueryServer) handleSelf(w http.ResponseWriter, r *http.Request) { + var crt cert.Certificate + if cs := h.certState(); cs != nil { + crt = cs.getCertificate(cs.initiatingVersion) + } + if crt == nil { + writeJSONError(w, http.StatusInternalServerError, "no certificate available") + return + } + h.writeHostIdentity(w, crt) +} + +func (h *hostQueryServer) writeHostIdentity(w http.ResponseWriter, crt cert.Certificate) { + id, err := newHostIdentity(crt) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "failed to fingerprint certificate") + return + } + w.Header().Set("Content-Type", "application/json") + if err = json.NewEncoder(w).Encode(id); err != nil { + h.l.Debug("Failed to write host query response", "error", err) + } +} + +// findCertificateForVpnAddr answers "who owns this vpn address": ourselves +// (from local cert state, since the hostmap never carries an entry for this +// node) or a peer with an active tunnel. Returns nil when the address is +// unknown or the tunnel is mid-teardown. +func findCertificateForVpnAddr(cs *CertState, hostMap *HostMap, ip netip.Addr) cert.Certificate { + if cs != nil && cs.myVpnAddrsTable != nil && cs.myVpnAddrsTable.Contains(ip) { + return cs.getCertificate(cs.initiatingVersion) + } + + hostinfo := hostMap.QueryVpnAddr(ip) + if hostinfo == nil { + return nil + } + cc := hostinfo.GetCert() + if cc == nil { + return nil + } + return cc.Certificate +} + +// hostIdentity is the JSON document served for both /v1/host and /v1/self. +// Every field is derived from the authenticated certificate alone. +type hostIdentity struct { + Name string `json:"name"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` + Networks []netip.Prefix `json:"networks"` + UnsafeNetworks []netip.Prefix `json:"unsafeNetworks"` + Groups []string `json:"groups"` + Fingerprint string `json:"fingerprint"` + Issuer string `json:"issuer"` + NotBefore time.Time `json:"notBefore"` + NotAfter time.Time `json:"notAfter"` + CertVersion int `json:"certVersion"` +} + +func newHostIdentity(crt cert.Certificate) (hostIdentity, error) { + fp, err := crt.Fingerprint() + if err != nil { + return hostIdentity{}, err + } + + // Slices are always allocated so they marshal as [] rather than null; + // consumers iterate groups without a presence check. + networks := crt.Networks() + id := hostIdentity{ + Name: crt.Name(), + VpnAddrs: make([]netip.Addr, 0, len(networks)), + Networks: append(make([]netip.Prefix, 0, len(networks)), networks...), + UnsafeNetworks: append(make([]netip.Prefix, 0, len(crt.UnsafeNetworks())), crt.UnsafeNetworks()...), + Groups: append(make([]string, 0, len(crt.Groups())), crt.Groups()...), + Fingerprint: fp, + Issuer: crt.Issuer(), + NotBefore: crt.NotBefore(), + NotAfter: crt.NotAfter(), + CertVersion: int(crt.Version()), + } + for _, n := range networks { + id.VpnAddrs = append(id.VpnAddrs, n.Addr()) + } + return id, nil +} + +func writeJSONError(w http.ResponseWriter, status int, msg string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} + +// parseQueryAddrParam parses the addr query parameter, accepting a bare +// address or an address with a port (`192.168.100.7:54321`, `[fd00::1]:443`) +// so callers can pass a connection's RemoteAddr straight through. The result +// is unmapped: 4in6 addresses (::ffff:a.b.c.d) are normalized to ipv4. +func parseQueryAddrParam(s string) (netip.Addr, error) { + if ip, err := netip.ParseAddr(s); err == nil { + return ip.Unmap(), nil + } + ap, err := netip.ParseAddrPort(s) + if err != nil { + return netip.Addr{}, err + } + return ap.Addr().Unmap(), nil +} + +func loadHostQueryConfig(c *config.C) (hostQueryConfig, error) { + cfg := hostQueryConfig{ + enabled: c.GetBool("host_query.enabled", false), + listen: c.GetString("host_query.listen", ""), + } + if !cfg.enabled { + return cfg, nil + } + + if cfg.listen == "" { + return cfg, errors.New("host_query.listen can not be empty when host_query is enabled") + } + network, addr, err := parseHostQueryListen(cfg.listen) + if err != nil { + return cfg, err + } + cfg.network = network + cfg.addr = addr + + if network == "unix" { + // Read as a string so YAML can't reinterpret the octal literal. + modeStr := c.GetString("host_query.socket_mode", "0600") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil || fs.FileMode(mode)&^fs.ModePerm != 0 { + return cfg, fmt.Errorf("host_query.socket_mode was not a valid octal file mode: %s", modeStr) + } + cfg.socketMode = fs.FileMode(mode) + } + return cfg, nil +} + +// parseHostQueryListen splits the host_query.listen config value into a +// network and address for net.Listen: `unix:///abs/path.sock` selects a unix +// socket, anything else must be a tcp host:port. +func parseHostQueryListen(listen string) (network string, addr string, err error) { + if path, ok := strings.CutPrefix(listen, "unix://"); ok { + if !filepath.IsAbs(path) { + return "", "", fmt.Errorf("host_query.listen unix socket path must be absolute: %s", listen) + } + return "unix", path, nil + } + + if _, _, err = net.SplitHostPort(listen); err != nil { + return "", "", fmt.Errorf("host_query.listen must be a unix:// socket path or a host:port address: %s", listen) + } + return "tcp", listen, nil +} diff --git a/host_query_server_test.go b/host_query_server_test.go new file mode 100644 index 00000000..e4323950 --- /dev/null +++ b/host_query_server_test.go @@ -0,0 +1,450 @@ +package nebula + +import ( + "context" + "encoding/json" + "io/fs" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_parseHostQueryListen(t *testing.T) { + tests := []struct { + listen string + network string + addr string + wantErr bool + }{ + {listen: "unix:///var/run/nebula.sock", network: "unix", addr: "/var/run/nebula.sock"}, + {listen: "127.0.0.1:8085", network: "tcp", addr: "127.0.0.1:8085"}, + {listen: "[::1]:8085", network: "tcp", addr: "[::1]:8085"}, + {listen: "localhost:8085", network: "tcp", addr: "localhost:8085"}, + {listen: "", wantErr: true}, + {listen: "unix://", wantErr: true}, + {listen: "unix://relative/path.sock", wantErr: true}, + {listen: "not an address", wantErr: true}, + {listen: "127.0.0.1", wantErr: true}, + } + + for _, tt := range tests { + network, addr, err := parseHostQueryListen(tt.listen) + if tt.wantErr { + require.Error(t, err, "listen=%q", tt.listen) + continue + } + require.NoError(t, err, "listen=%q", tt.listen) + assert.Equal(t, tt.network, network, "listen=%q", tt.listen) + assert.Equal(t, tt.addr, addr, "listen=%q", tt.listen) + } +} + +func Test_loadHostQueryConfig(t *testing.T) { + c := config.NewC(nil) + + // Absent section: disabled, no error. + cfg, err := loadHostQueryConfig(c) + require.NoError(t, err) + assert.False(t, cfg.enabled) + + // Enabled without a listen address is an error. + setHostQueryConfig(c, true, "", "") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + // Unix socket gets the default mode. + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, "unix", cfg.network) + assert.Equal(t, "/tmp/hq.sock", cfg.addr) + assert.Equal(t, fs.FileMode(0o600), cfg.socketMode) + + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "0660") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, fs.FileMode(0o660), cfg.socketMode) + + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "withers") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + // Mode bits beyond the permission bits are rejected. + setHostQueryConfig(c, true, "unix:///tmp/hq.sock", "10600") + _, err = loadHostQueryConfig(c) + require.Error(t, err) + + setHostQueryConfig(c, true, "127.0.0.1:8085", "") + cfg, err = loadHostQueryConfig(c) + require.NoError(t, err) + assert.Equal(t, "tcp", cfg.network) + assert.Equal(t, "127.0.0.1:8085", cfg.addr) +} + +func setHostQueryConfig(c *config.C, enabled bool, listen, socketMode string) { + settings := map[string]any{ + "enabled": enabled, + "listen": listen, + } + if socketMode != "" { + settings["socket_mode"] = socketMode + } + c.Settings["host_query"] = settings +} + +func newTestHostQueryServer(t *testing.T) (*hostQueryServer, *config.C) { + t.Helper() + h := &hostQueryServer{ + l: slog.New(slog.DiscardHandler), + ctx: context.Background(), + hostMap: newHostMap(slog.New(slog.DiscardHandler)), + } + h.hostMap.preferredRanges.Store(&[]netip.Prefix{}) + return h, config.NewC(nil) +} + +// addTestPeer creates a certificate for a peer owning each addr (as a /24 or +// /64) and inserts it into the hostmap as an established tunnel. +func addTestPeer(t *testing.T, hm *HostMap, name string, addrs []netip.Addr, unsafeNetworks []netip.Prefix, groups []string) cert.Certificate { + t.Helper() + networks := make([]netip.Prefix, 0, len(addrs)) + for _, a := range addrs { + bits := 24 + if a.Is6() { + bits = 64 + } + networks = append(networks, netip.PrefixFrom(a, bits)) + } + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + crt, _, _, _ := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, ca, caKey, name, time.Time{}, time.Time{}, networks, unsafeNetworks, groups) + fp, err := crt.Fingerprint() + require.NoError(t, err) + + hm.unlockedAddHostInfo(&HostInfo{ + ConnectionState: &ConnectionState{ + peerCert: &cert.CachedCertificate{Certificate: crt, Fingerprint: fp}, + }, + vpnAddrs: addrs, + relayState: RelayState{ + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + }, &Interface{}) + return crt +} + +func getHost(t *testing.T, h *hostQueryServer, addrParam string) (int, map[string]any) { + t.Helper() + r := httptest.NewRequest(http.MethodGet, "/v1/host?addr="+url.QueryEscape(addrParam), nil) + w := httptest.NewRecorder() + h.handleHost(w, r) + return decodeResponse(t, w) +} + +func decodeResponse(t *testing.T, w *httptest.ResponseRecorder) (int, map[string]any) { + t.Helper() + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + return w.Code, body +} + +func TestHostQueryServer_handleHost(t *testing.T) { + h, _ := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + peerV4 := netip.MustParseAddr("10.0.0.99") + peerV6 := netip.MustParseAddr("fd00::99") + addTestPeer(t, h.hostMap, "laptop-alice", []netip.Addr{peerV4, peerV6}, + []netip.Prefix{netip.MustParsePrefix("192.168.50.0/24")}, []string{"eng", "ssh"}) + addTestPeer(t, h.hostMap, "groupless", []netip.Addr{netip.MustParseAddr("10.0.0.77")}, nil, nil) + + // An established peer comes back with its full identity. + code, body := getHost(t, h, "10.0.0.99") + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "laptop-alice", body["name"]) + assert.Equal(t, []any{"10.0.0.99", "fd00::99"}, body["vpnAddrs"]) + assert.Equal(t, []any{"10.0.0.99/24", "fd00::99/64"}, body["networks"]) + assert.Equal(t, []any{"192.168.50.0/24"}, body["unsafeNetworks"]) + assert.Equal(t, []any{"eng", "ssh"}, body["groups"]) + assert.NotEmpty(t, body["fingerprint"]) + assert.Equal(t, float64(2), body["certVersion"]) + assert.NotEmpty(t, body["notBefore"]) + assert.NotEmpty(t, body["notAfter"]) + + // Empty cert slices marshal as [] rather than null. + code, body = getHost(t, h, "10.0.0.77") + require.Equal(t, http.StatusOK, code) + require.NotNil(t, body["groups"]) + assert.Empty(t, body["groups"]) + require.NotNil(t, body["unsafeNetworks"]) + assert.Empty(t, body["unsafeNetworks"]) + + // A port in addr is ignored so RemoteAddr can be passed through directly, + // including the bracketed v6 and 4in6 forms. + for _, q := range []string{"10.0.0.99:54321", "[fd00::99]:443", "::ffff:10.0.0.99"} { + code, body = getHost(t, h, q) + require.Equal(t, http.StatusOK, code, "addr=%q", q) + assert.Equal(t, "laptop-alice", body["name"], "addr=%q", q) + } + + // Our own address answers from the local cert state. + code, body = getHost(t, h, "10.0.0.1") + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "self", body["name"]) + + code, body = getHost(t, h, "10.0.0.42") + assert.Equal(t, http.StatusNotFound, code) + assert.NotEmpty(t, body["error"]) + + // A tunnel mid-teardown (no peer cert) is treated as unknown. + h.hostMap.unlockedAddHostInfo(&HostInfo{ + ConnectionState: &ConnectionState{}, + vpnAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.66")}, + relayState: RelayState{ + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, + }, + }, &Interface{}) + code, _ = getHost(t, h, "10.0.0.66") + assert.Equal(t, http.StatusNotFound, code) + + code, body = getHost(t, h, "not-an-address") + assert.Equal(t, http.StatusBadRequest, code) + assert.NotEmpty(t, body["error"]) + + r := httptest.NewRequest(http.MethodGet, "/v1/host", nil) + w := httptest.NewRecorder() + h.handleHost(w, r) + code, body = decodeResponse(t, w) + assert.Equal(t, http.StatusBadRequest, code) + assert.NotEmpty(t, body["error"]) +} + +func TestHostQueryServer_handleSelf(t *testing.T) { + h, _ := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "lighthouse", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + r := httptest.NewRequest(http.MethodGet, "/v1/self", nil) + w := httptest.NewRecorder() + h.handleSelf(w, r) + code, body := decodeResponse(t, w) + require.Equal(t, http.StatusOK, code) + assert.Equal(t, "lighthouse", body["name"]) + assert.Equal(t, []any{"10.0.0.1"}, body["vpnAddrs"]) + + // No cert state available should be an error, not a panic. + h.pki = nil + w = httptest.NewRecorder() + h.handleSelf(w, r) + code, body = decodeResponse(t, w) + assert.Equal(t, http.StatusInternalServerError, code) + assert.NotEmpty(t, body["error"]) +} + +func unixHTTPClient(path string) *http.Client { + return &http.Client{ + Timeout: time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, "unix", path) + }, + }, + } +} + +// waitForServe polls until a GET /v1/self through client succeeds. +func waitForServe(t *testing.T, client *http.Client) { + t.Helper() + waitFor(t, func() bool { + resp, err := client.Get("http://hostquery/v1/self") + if err != nil { + return false + } + resp.Body.Close() + return resp.StatusCode == http.StatusOK + }) +} + +func skipIfNoUnixSockets(t *testing.T) { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("unix socket tests are not supported on windows CI") + } +} + +func TestHostQueryServer_unixLifecycle(t *testing.T) { + skipIfNoUnixSockets(t) + h, c := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + sock := filepath.Join(t.TempDir(), "hq.sock") + setHostQueryConfig(c, true, "unix://"+sock, "") + require.NoError(t, h.reload(c, true)) + + done := make(chan struct{}) + go func() { + h.Start() + close(done) + }() + + client := unixHTTPClient(sock) + waitForServe(t, client) + + fi, err := os.Stat(sock) + require.NoError(t, err) + assert.Equal(t, fs.FileMode(0o600), fi.Mode().Perm()) + + resp, err := client.Get("http://hostquery/v1/host?addr=10.0.0.1") + require.NoError(t, err) + resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + + h.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } + _, err = os.Stat(sock) + assert.True(t, os.IsNotExist(err), "socket file should be unlinked on shutdown") +} + +func TestHostQueryServer_tcpLifecycle(t *testing.T) { + h, c := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + + setHostQueryConfig(c, true, "127.0.0.1:0", "") + require.NoError(t, h.reload(c, true)) + + done := make(chan struct{}) + go func() { + h.Start() + close(done) + }() + + var addr string + waitFor(t, func() bool { + h.runMu.Lock() + defer h.runMu.Unlock() + if h.run == nil { + return false + } + addr = h.run.listener.Addr().String() + return true + }) + + client := &http.Client{Timeout: time.Second} + resp, err := client.Get("http://" + addr + "/v1/self") + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + var body map[string]any + require.NoError(t, json.NewDecoder(resp.Body).Decode(&body)) + assert.Equal(t, "self", body["name"]) + + h.Stop() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Start did not return after Stop") + } +} + +func TestHostQueryServer_staleSocket(t *testing.T) { + skipIfNoUnixSockets(t) + h, _ := newTestHostQueryServer(t) + sock := filepath.Join(t.TempDir(), "hq.sock") + + // Simulate an unclean exit: a leftover socket file with no listener. + stale, err := net.ListenUnix("unix", &net.UnixAddr{Name: sock, Net: "unix"}) + require.NoError(t, err) + stale.SetUnlinkOnClose(false) + require.NoError(t, stale.Close()) + _, err = os.Stat(sock) + require.NoError(t, err, "stale socket file should exist") + + cfg := hostQueryConfig{network: "unix", addr: sock, socketMode: 0o600} + ln, err := h.listen(cfg) + require.NoError(t, err, "a stale socket should be removed and rebound") + require.NoError(t, ln.Close()) +} + +func TestHostQueryServer_existingFileNotReplaced(t *testing.T) { + skipIfNoUnixSockets(t) + h, _ := newTestHostQueryServer(t) + path := filepath.Join(t.TempDir(), "hq.sock") + require.NoError(t, os.WriteFile(path, []byte("precious"), 0o600)) + + cfg := hostQueryConfig{network: "unix", addr: path, socketMode: 0o600} + _, err := h.listen(cfg) + require.Error(t, err, "a non-socket file at the listen path must not be replaced") + + content, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, "precious", string(content)) +} + +func TestHostQueryServer_reload(t *testing.T) { + skipIfNoUnixSockets(t) + h, c := newTestHostQueryServer(t) + h.pki = newTestPKI(t, "self", []netip.Addr{netip.MustParseAddr("10.0.0.1")}) + dir := t.TempDir() + sock1 := filepath.Join(dir, "hq1.sock") + sock2 := filepath.Join(dir, "hq2.sock") + + // Initial reload only records config; Control.Start launches the runtime. + setHostQueryConfig(c, false, "unix://"+sock1, "") + require.NoError(t, h.reload(c, true)) + assert.False(t, h.enabled.Load()) + h.runMu.Lock() + assert.Nil(t, h.run) + h.runMu.Unlock() + + // Enabling via reload spawns the listener. + setHostQueryConfig(c, true, "unix://"+sock1, "") + require.NoError(t, h.reload(c, false)) + waitForServe(t, unixHTTPClient(sock1)) + + // Changing the listen path restarts on the new address. + setHostQueryConfig(c, true, "unix://"+sock2, "") + require.NoError(t, h.reload(c, false)) + waitForServe(t, unixHTTPClient(sock2)) + waitFor(t, func() bool { + _, err := os.Stat(sock1) + return os.IsNotExist(err) + }) + + // Reloading an unchanged config does not restart the runtime. + h.runMu.Lock() + rt := h.run + h.runMu.Unlock() + require.NoError(t, h.reload(c, false)) + h.runMu.Lock() + assert.Same(t, rt, h.run) + h.runMu.Unlock() + + // Disabling stops the listener. + setHostQueryConfig(c, false, "unix://"+sock2, "") + require.NoError(t, h.reload(c, false)) + assert.False(t, h.enabled.Load()) + waitFor(t, func() bool { + h.runMu.Lock() + defer h.runMu.Unlock() + return h.run == nil + }) +} diff --git a/main.go b/main.go index 7d7a0f72..ae6288b4 100644 --- a/main.go +++ b/main.go @@ -249,6 +249,11 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err) } + hostQuery, err := newHostQueryServerFromConfig(ctx, l, pki, hostMap, c) + if err != nil { + return nil, util.ContextualizeIfNeeded("Failed to configure the host query API", err) + } + if configTest { return nil, nil } @@ -266,6 +271,7 @@ func Main(c *config.C, configTest bool, buildVersion string, l *slog.Logger, dev sshStart: sshStart, statsStart: stats.Start, dnsStart: ds.Start, + hostQueryStart: hostQuery.Start, lighthouseStart: lightHouse.StartUpdateWorker, connectionManagerStart: connManager.Start, }, nil