diff --git a/sshd/server.go b/sshd/server.go index a8b60ba7..4b5cc3e0 100644 --- a/sshd/server.go +++ b/sshd/server.go @@ -2,10 +2,10 @@ package sshd import ( "bytes" + "context" "errors" "fmt" "net" - "sync" "github.com/armon/go-radix" "github.com/sirupsen/logrus" @@ -27,20 +27,21 @@ type SSHServer struct { commands *radix.Tree listener net.Listener - // Locks the conns/counter to avoid concurrent map access - connsLock sync.Mutex - conns map[int]*session - counter int + // Call the cancel() function to stop all active sessions + ctx context.Context + cancel func() } // NewSSHServer creates a new ssh server rigged with default commands and prepares to listen func NewSSHServer(l *logrus.Entry) (*SSHServer, error) { + ctx, cancel := context.WithCancel(context.Background()) s := &SSHServer{ trustedKeys: make(map[string]map[string]bool), l: l, commands: radix.New(), - conns: make(map[int]*session), + ctx: ctx, + cancel: cancel, } cc := ssh.CertChecker{ @@ -175,44 +176,44 @@ func (s *SSHServer) run() { } return } - - conn, chans, reqs, err := ssh.NewServerConn(c, s.config) - fp := "" - if conn != nil { - fp = conn.Permissions.Extensions["fp"] - } - - if err != nil { - l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + go func(c net.Conn) { + // NewServerConn may block while waiting for the client to complete the handshake. + // Ensure that a bad client doesn't hurt us by checking for the parent context + // cancellation before calling NewServerConn, and forcing the socket to close when + // the context is cancelled. + sessionContext, sessionCancel := context.WithCancel(s.ctx) + go func() { + <-sessionContext.Done() + c.Close() + }() + conn, chans, reqs, err := ssh.NewServerConn(c, s.config) + fp := "" if conn != nil { - l = l.WithField("sshUser", conn.User()) - conn.Close() + fp = conn.Permissions.Extensions["fp"] } - if fp != "" { - l = l.WithField("sshFingerprint", fp) + + if err != nil { + l := s.l.WithError(err).WithField("remoteAddress", c.RemoteAddr()) + if conn != nil { + l = l.WithField("sshUser", conn.User()) + conn.Close() + } + if fp != "" { + l = l.WithField("sshFingerprint", fp) + } + l.Warn("failed to handshake") + sessionCancel() + return } - l.Warn("failed to handshake") - continue - } - l := s.l.WithField("sshUser", conn.User()) - l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") + l := s.l.WithField("sshUser", conn.User()) + l.WithField("remoteAddress", c.RemoteAddr()).WithField("sshFingerprint", fp).Info("ssh user logged in") - session := NewSession(s.commands, conn, chans, l.WithField("subsystem", "sshd.session")) - s.connsLock.Lock() - s.counter++ - counter := s.counter - s.conns[counter] = session - s.connsLock.Unlock() + NewSession(s.commands, conn, chans, sessionCancel, l.WithField("subsystem", "sshd.session")) - go ssh.DiscardRequests(reqs) - go func() { - <-session.exitChan - s.l.WithField("id", counter).Debug("closing conn") - s.connsLock.Lock() - delete(s.conns, counter) - s.connsLock.Unlock() - }() + go ssh.DiscardRequests(reqs) + + }(c) } } @@ -226,9 +227,5 @@ func (s *SSHServer) Stop() { } func (s *SSHServer) closeSessions() { - s.connsLock.Lock() - for _, c := range s.conns { - c.Close() - } - s.connsLock.Unlock() + s.cancel() } diff --git a/sshd/session.go b/sshd/session.go index 87cc216f..39c81bd0 100644 --- a/sshd/session.go +++ b/sshd/session.go @@ -17,15 +17,15 @@ type session struct { c *ssh.ServerConn term *term.Terminal commands *radix.Tree - exitChan chan bool + cancel func() } -func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, l *logrus.Entry) *session { +func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.NewChannel, cancel func(), l *logrus.Entry) *session { s := &session{ commands: radix.NewFromMap(commands.ToMap()), l: l, c: conn, - exitChan: make(chan bool), + cancel: cancel, } s.commands.Insert("logout", &Command{ @@ -42,6 +42,7 @@ func NewSession(commands *radix.Tree, conn *ssh.ServerConn, chans <-chan ssh.New } func (s *session) handleChannels(chans <-chan ssh.NewChannel) { + defer s.Close() for newChannel := range chans { if newChannel.ChannelType() != "session" { s.l.WithField("sshChannelType", newChannel.ChannelType()).Error("unknown channel type") @@ -100,7 +101,6 @@ func (s *session) handleRequests(in <-chan *ssh.Request, channel ssh.Channel) { if err != nil { s.l.WithError(err).Info("Error handling ssh session requests") - s.Close() return } } @@ -123,12 +123,11 @@ func (s *session) createTerm(channel ssh.Channel) *term.Terminal { return "", 0, false } - go s.handleInput(channel) + go s.handleInput() return term } -func (s *session) handleInput(channel ssh.Channel) { - defer s.Close() +func (s *session) handleInput() { w := &stringWriter{w: s.term} for { line, err := s.term.ReadLine() @@ -170,10 +169,9 @@ func (s *session) dispatchCommand(line string, w StringWriter) { } _ = execCommand(c, args[1:], w) - return } func (s *session) Close() { s.c.Close() - s.exitChan <- true + s.cancel() }