From c82db210ef7a31940412044b4cad0e372ea23658 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 7 May 2026 11:30:26 -0500 Subject: [PATCH] Change windows unsafe routes to link routes, fix sshd reload bug (#1709) --- e2e/sshd_test.go | 125 +++++++++++++++++++++++++++++++++++++++++ overlay/tun_windows.go | 16 ++++-- sshd/server.go | 56 +++++++++--------- 3 files changed, 162 insertions(+), 35 deletions(-) create mode 100644 e2e/sshd_test.go diff --git a/e2e/sshd_test.go b/e2e/sshd_test.go new file mode 100644 index 00000000..e91f1bd0 --- /dev/null +++ b/e2e/sshd_test.go @@ -0,0 +1,125 @@ +//go:build e2e_testing +// +build e2e_testing + +package e2e + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "net" + "strings" + "testing" + "time" + + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestSSHDLifecycle(t *testing.T) { + // TestSSHDLifecycle exercises the in-process sshd through several config reloads and a Control.Stop. + ca, _, caKey, _ := cert_test.NewTestCaCert( + cert.Version1, cert.Curve_CURVE25519, + time.Now(), time.Now().Add(10*time.Minute), + nil, nil, []string{}, + ) + + hostKeyPEM := generateSSHHostKey(t) + clientSigner, clientAuthKey := generateSSHClientKey(t) + sshdAddr := allocLoopbackPort(t) + + overrides := m{ + "sshd": m{ + "enabled": true, + "listen": sshdAddr, + "host_key": hostKeyPEM, + "authorized_users": []m{{ + "user": "tester", + "keys": []string{clientAuthKey}, + }}, + }, + } + control, _, _, _ := newSimpleServer(cert.Version1, ca, caKey, "sshd-test", "10.222.0.1/24", overrides) + control.Start() + t.Cleanup(func() { control.Stop() }) + + // sshd binds in a goroutine after Start returns; wait for it. + require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd never started listening") + + for i := 1; i <= 3; i++ { + out := sshExecReload(t, sshdAddr, clientSigner) + assert.Contains(t, out, "Reloading config", "reload cycle %d", i) + require.Eventually(t, func() bool { return canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd not listening after reload cycle %d", i) + } + + control.Stop() + require.Eventually(t, func() bool { return !canDial(sshdAddr) }, 2*time.Second, 25*time.Millisecond, + "sshd still listening after Control.Stop") +} + +func canDial(addr string) bool { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err != nil { + return false + } + _ = c.Close() + return true +} + +// allocLoopbackPort grabs an unused TCP port on 127.0.0.1, closes it, and returns the address. There +// is a small race between releasing the port and the sshd reclaiming it; in practice the OS keeps the +// port available long enough for the test to bind it. +func allocLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func generateSSHHostKey(t *testing.T) string { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + block, err := ssh.MarshalPrivateKey(priv, "nebula-e2e-host") + require.NoError(t, err) + return string(pem.EncodeToMemory(block)) +} + +func generateSSHClientKey(t *testing.T) (ssh.Signer, string) { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + signer, err := ssh.NewSignerFromKey(priv) + require.NoError(t, err) + auth := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey()))) + return signer, auth +} + +func sshExecReload(t *testing.T, addr string, signer ssh.Signer) string { + t.Helper() + cfg := &ssh.ClientConfig{ + User: "tester", + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + client, err := ssh.Dial("tcp", addr, cfg) + require.NoError(t, err) + defer client.Close() + + sess, err := client.NewSession() + require.NoError(t, err) + defer sess.Close() + + // reload tears the channel down before sending exit-status, so Output returns an error on the + // channel close. The output buffer still has whatever the reload callback wrote before that. + out, _ := sess.Output("reload") + return string(out) +} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 680dddb3..14c8d499 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -156,11 +156,8 @@ func (t *winTun) addRoutes(logErrors bool) error { continue } - // Add our unsafe route - // Windows does not support multipath routes natively, so we install only a single route. - // This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally. - // In effect this provides multipath routing support to windows supporting loadbalancing and redundancy. - err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric)) + // Add our unsafe route as an on-link route to the nebula tun device. + err := luid.AddRoute(r.Cidr, unspecifiedNextHop(r.Cidr), uint32(r.Metric)) if err != nil { retErr := util.NewContextualError("Failed to add route", map[string]any{"route": r}, err) if logErrors { @@ -206,7 +203,7 @@ func (t *winTun) removeRoutes(routes []Route) error { } // See comment on luid.AddRoute - err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr()) + err := luid.DeleteRoute(r.Cidr, unspecifiedNextHop(r.Cidr)) if err != nil { t.l.Error("Failed to remove route", "error", err, "route", r) } else { @@ -261,6 +258,13 @@ func (t *winTun) Close() error { return t.tun.Close() } +func unspecifiedNextHop(p netip.Prefix) netip.Addr { + if p.Addr().Is4() { + return netip.IPv4Unspecified() + } + return netip.IPv6Unspecified() +} + func generateGUIDByDeviceName(name string) (*windows.GUID, error) { // GUID is 128 bit hash := crypto.MD5.New() diff --git a/sshd/server.go b/sshd/server.go index ff954bf5..86c52961 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -27,23 +27,20 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Call the cancel() function to stop all active sessions - ctx context.Context - cancel func() + // ctx parents per-Run contexts. Cancelling it (e.g. via Control.Stop) tears the server down even + // across reloads, since each Run derives a fresh child rather than reusing this one directly. + ctx context.Context } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen. // The ssh server's context is parented off the supplied ctx so cancelling it // (e.g. on Control.Stop) tears down active sessions and closes the listener. func NewSSHServer(ctx context.Context, l *slog.Logger) (*SSHServer, error) { - - ctx, cancel := context.WithCancel(ctx) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), ctx: ctx, - cancel: cancel, } cc := ssh.CertChecker{ @@ -153,45 +150,51 @@ func (s *SSHServer) RegisterCommand(c *Command) { s.commands.Insert(c.Name, c) } -// Run begins listening and accepting connections +// Run begins listening and accepting connections. Each invocation derives a fresh per-Run context +// from the constructor-supplied ctx so a Stop+Run sequence (used by config reload) starts clean +// rather than carrying a permanently-cancelled context across runs. func (s *SSHServer) Run(addr string) error { if s.ctx.Err() != nil { return s.ctx.Err() } - var err error - s.listener, err = net.Listen("tcp", addr) + listener, err := net.Listen("tcp", addr) if err != nil { return err } + // s.listener is the public handle Stop uses to interrupt the active run; listener (the local) is what + // this run owns. They start equal but a fast reload may overwrite s.listener with the next run's + // listener before this run's watcher fires, so each run must close its own listener via the local + // reference. + s.listener = listener - s.l.Info("SSH server is listening", "sshListener", addr) + runCtx, cancel := context.WithCancel(s.ctx) + defer cancel() - // Per-invocation watcher: cancellation of the parent context (e.g. - // Control.Stop) closes the listener so Accept unblocks and run returns. - // Closing `done` on exit keeps the watcher from outliving this Run call. - done := make(chan struct{}) + // Close the listener when this run's context is cancelled. That can come from the parent + // (Control.Stop), from Run returning normally (defer cancel above), or transitively when a sibling + // run cancels through Stop closing the listener. net.Listener.Close is idempotent so a duplicate + // close from Stop is benign. go func() { - select { - case <-s.ctx.Done(): - s.Stop() - case <-done: + <-runCtx.Done() + if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + s.l.Warn("Failed to close the sshd listener", "error", err) } }() + s.l.Info("SSH server is listening", "sshListener", addr) + // Run loops until there is an error - s.run() - close(done) - s.closeSessions() + s.run(runCtx, listener) s.l.Info("SSH server stopped listening") // We don't return an error because run logs for us return nil } -func (s *SSHServer) run() { +func (s *SSHServer) run(ctx context.Context, listener net.Listener) { for { - c, err := s.listener.Accept() + c, err := listener.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { s.l.Warn("Error in listener, shutting down", "error", err) @@ -203,7 +206,7 @@ func (s *SSHServer) run() { // Ensure that a bad client doesn't hurt us by checking for the parent context // cancellation before calling NewServerConn, and forcing the socket to close when // the context is cancelled. - sessionContext, sessionCancel := context.WithCancel(s.ctx) + sessionContext, sessionCancel := context.WithCancel(ctx) go func() { <-sessionContext.Done() c.Close() @@ -246,14 +249,9 @@ func (s *SSHServer) run() { } func (s *SSHServer) Stop() { - // Close the listener, this will cause all session to terminate as well, see SSHServer.Run if s.listener != nil { if err := s.listener.Close(); err != nil { s.l.Warn("Failed to close the sshd listener", "error", err) } } } - -func (s *SSHServer) closeSessions() { - s.cancel() -}