SSH handshake in goroutine and defer close (#1640)
Some checks failed
gofmt / Run gofmt (push) Failing after 2s
smoke-extra / Run extra smoke tests (push) Failing after 3s
smoke / Run multi node smoke test (push) Failing after 2s
Build and test / Build all and test on ubuntu-linux (push) Failing after 3s
Build and test / Build and test on linux with boringcrypto (push) Failing after 2s
Build and test / Build and test on linux with pkcs11 (push) Failing after 2s
Build and test / Build and test on macos-latest (push) Has been cancelled
Build and test / Build and test on windows-latest (push) Has been cancelled

* SSH handshake in goroutine and defer close
This commit is contained in:
brad-defined
2026-04-23 14:53:52 -04:00
committed by GitHub
parent db9218b0be
commit db85d61c23
2 changed files with 47 additions and 52 deletions

View File

@@ -2,10 +2,10 @@ package sshd
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"sync"
"github.com/armon/go-radix" "github.com/armon/go-radix"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -27,20 +27,21 @@ type SSHServer struct {
commands *radix.Tree commands *radix.Tree
listener net.Listener listener net.Listener
// Locks the conns/counter to avoid concurrent map access // Call the cancel() function to stop all active sessions
connsLock sync.Mutex ctx context.Context
conns map[int]*session cancel func()
counter int
} }
// NewSSHServer creates a new ssh server rigged with default commands and prepares to listen // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen
func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { func NewSSHServer(l *logrus.Entry) (*SSHServer, error) {
ctx, cancel := context.WithCancel(context.Background())
s := &SSHServer{ s := &SSHServer{
trustedKeys: make(map[string]map[string]bool), trustedKeys: make(map[string]map[string]bool),
l: l, l: l,
commands: radix.New(), commands: radix.New(),
conns: make(map[int]*session), ctx: ctx,
cancel: cancel,
} }
cc := ssh.CertChecker{ cc := ssh.CertChecker{
@@ -175,7 +176,16 @@ func (s *SSHServer) run() {
} }
return return
} }
go func(c net.Conn) {
// NewServerConn may block while waiting for the client to complete the handshake.
// Ensure that a bad client doesn't hurt us by checking for the parent context
// cancellation before calling NewServerConn, and forcing the socket to close when
// the context is cancelled.
sessionContext, sessionCancel := context.WithCancel(s.ctx)
go func() {
<-sessionContext.Done()
c.Close()
}()
conn, chans, reqs, err := ssh.NewServerConn(c, s.config) conn, chans, reqs, err := ssh.NewServerConn(c, s.config)
fp := "" fp := ""
if conn != nil { if conn != nil {
@@ -192,27 +202,18 @@ func (s *SSHServer) run() {
l = l.WithField("sshFingerprint", fp) l = l.WithField("sshFingerprint", fp)
} }
l.Warn("failed to handshake") l.Warn("failed to handshake")
continue sessionCancel()
return
} }
l := s.l.WithField("sshUser", conn.User()) l := s.l.WithField("sshUser", conn.User())
l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in")
session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session"))
s.connsLock.Lock()
s.counter++
counter := s.counter
s.conns[counter] = session
s.connsLock.Unlock()
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
go func() {
<-session.exitChan }(c)
s.l.WithField("id", counter).Debug("closing conn")
s.connsLock.Lock()
delete(s.conns, counter)
s.connsLock.Unlock()
}()
} }
} }
@@ -226,9 +227,5 @@ func (s *SSHServer) Stop() {
} }
func (s *SSHServer) closeSessions() { func (s *SSHServer) closeSessions() {
s.connsLock.Lock() s.cancel()
for _, c := range s.conns {
c.Close()
}
s.connsLock.Unlock()
} }

View File

@@ -17,15 +17,15 @@ type session struct {
c *ssh.ServerConn c *ssh.ServerConn
term *term.Terminal term *term.Terminal
commands *radix.Tree commands *radix.Tree
exitChan chan bool cancel func()
} }
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session {
s := &session{ s := &session{
commands: radix.NewFromMap(commands.ToMap()), commands: radix.NewFromMap(commands.ToMap()),
l: l, l: l,
c: conn, c: conn,
exitChan: make(chan bool), cancel: cancel,
} }
s.commands.Insert("logout", &Command{ s.commands.Insert("logout", &Command{
@@ -42,6 +42,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
} }
func (s *session) handleChannels(chans <-chan ssh.NewChannel) { func (s *session) handleChannels(chans <-chan ssh.NewChannel) {
defer s.Close()
for newChannel := range chans { for newChannel := range chans {
if newChannel.ChannelType() != "session" { if newChannel.ChannelType() != "session" {
s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type")
@@ -100,7 +101,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
if err != nil { if err != nil {
s.l.WithError(err).Info("Error handling ssh session requests") s.l.WithError(err).Info("Error handling ssh session requests")
s.Close()
return return
} }
} }
@@ -123,12 +123,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
return "", 0, false return "", 0, false
} }
go s.handleInput(channel) go s.handleInput()
return term return term
} }
func (s *session) handleInput(channel ssh.Channel) { func (s *session) handleInput() {
defer s.Close()
w := &stringWriter{w: s.term} w := &stringWriter{w: s.term}
for { for {
line, err := s.term.ReadLine() line, err := s.term.ReadLine()
@@ -170,10 +169,9 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
} }
_ = execCommand(c, args[1:], w) _ = execCommand(c, args[1:], w)
return
} }
func (s *session) Close() { func (s *session) Close() {
s.c.Close() s.c.Close()
s.exitChan <- true s.cancel()
} }