Fix sshd goroutine leak and other cleanup

This commit is contained in:
Nate Brown
2026-03-17 21:12:38 -05:00
parent 1aa1a0476f
commit 0c5f48d695
2 changed files with 26 additions and 17 deletions

View File

@@ -16,16 +16,13 @@ type SSHServer struct {
config *ssh.ServerConfig config *ssh.ServerConfig
l *logrus.Entry l *logrus.Entry
certChecker *ssh.CertChecker
// Map of user -> authorized keys // Map of user -> authorized keys
trustedKeys map[string]map[string]bool trustedKeys map[string]map[string]bool
trustedCAs []ssh.PublicKey trustedCAs []ssh.PublicKey
// List of available commands // List of available commands
helpCommand *Command commands *radix.Tree
commands *radix.Tree listener net.Listener
listener net.Listener
// Locks the conns/counter to avoid concurrent map access // Locks the conns/counter to avoid concurrent map access
connsLock sync.Mutex connsLock sync.Mutex
@@ -184,7 +181,11 @@ func (s *SSHServer) run() {
if err != nil { if err != nil {
l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr())
if conn != nil { if conn == nil {
// conn is nil when the handshake failed before authentication
// close the raw TCP connection to avoid leaking the file descriptor.
c.Close()
} else {
l = l.WithField("sshUser", conn.User()) l = l.WithField("sshUser", conn.User())
conn.Close() conn.Close()
} }

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
"sync"
"github.com/anmitsu/go-shlex" "github.com/anmitsu/go-shlex"
"github.com/armon/go-radix" "github.com/armon/go-radix"
@@ -13,11 +14,12 @@ import (
) )
type session struct { type session struct {
l *logrus.Entry l *logrus.Entry
c *ssh.ServerConn c *ssh.ServerConn
term *term.Terminal term *term.Terminal
commands *radix.Tree commands *radix.Tree
exitChan chan bool exitChan chan struct{}
closeOnce sync.Once
} }
func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session {
@@ -25,7 +27,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
commands: radix.NewFromMap(commands.ToMap()), commands: radix.NewFromMap(commands.ToMap()),
l: l, l: l,
c: conn, c: conn,
exitChan: make(chan bool), exitChan: make(chan struct{}),
} }
s.commands.Insert("logout", &Command{ s.commands.Insert("logout", &Command{
@@ -37,7 +39,10 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New
}, },
}) })
go s.handleChannels(chans) go func() {
s.handleChannels(chans)
s.Close()
}()
return s return s
} }
@@ -82,6 +87,7 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) {
cErr := ssh.Unmarshal(req.Payload, &payload) cErr := ssh.Unmarshal(req.Payload, &payload)
if cErr != nil { if cErr != nil {
req.Reply(false, nil) req.Reply(false, nil)
channel.Close()
return return
} }
@@ -123,11 +129,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal {
return "", 0, false return "", 0, false
} }
go s.handleInput(channel) go s.handleInput()
return term return term
} }
func (s *session) handleInput(channel ssh.Channel) { func (s *session) handleInput() {
defer s.Close() defer s.Close()
w := &stringWriter{w: s.term} w := &stringWriter{w: s.term}
for { for {
@@ -174,6 +180,8 @@ func (s *session) dispatchCommand(line string, w StringWriter) {
} }
func (s *session) Close() { func (s *session) Close() {
s.c.Close() s.closeOnce.Do(func() {
s.exitChan <- true s.c.Close()
close(s.exitChan)
})
} }