diff --git a/conn.go b/conn.go index ebef884..c3fa824 100644 --- a/conn.go +++ b/conn.go @@ -9,13 +9,16 @@ import ( type serverConn struct { net.Conn - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc + idleTimeout time.Duration + handshakeDeadline time.Time + maxDeadline time.Time + closeCanceler context.CancelFunc } func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() + if c.idleTimeout > 0 { + c.updateDeadline() + } n, err = c.Conn.Write(p) if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { c.closeCanceler() @@ -24,7 +27,9 @@ func (c *serverConn) Write(p []byte) (n int, err error) { } func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() + if c.idleTimeout > 0 { + c.updateDeadline() + } n, err = c.Conn.Read(b) if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { c.closeCanceler() @@ -41,15 +46,18 @@ func (c *serverConn) Close() (err error) { } func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: + deadline := c.maxDeadline + + if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) { + deadline = c.handshakeDeadline + } + + if c.idleTimeout > 0 { idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return + if deadline.IsZero() || idleDeadline.Before(deadline) { + deadline = idleDeadline } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) } + + c.Conn.SetDeadline(deadline) } diff --git a/server.go b/server.go index f783ee5..7dbaa0f 100644 --- a/server.go +++ b/server.go @@ -52,8 +52,9 @@ type Server struct { ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures - IdleTimeout time.Duration // connection timeout when no activity, none if empty - MaxTimeout time.Duration // absolute connection timeout, none if empty + HandshakeTimeout time.Duration // connection timeout until successful handshake, none if empty + IdleTimeout time.Duration // connection timeout when no activity, none if empty + MaxTimeout time.Duration // absolute connection timeout, none if empty // ChannelHandlers allow overriding the built-in session handlers or provide // extensions to the protocol, such as tcpip forwarding. By default only the @@ -290,6 +291,10 @@ func (srv *Server) HandleConn(newConn net.Conn) { if srv.MaxTimeout > 0 { conn.maxDeadline = time.Now().Add(srv.MaxTimeout) } + if srv.HandshakeTimeout > 0 { + conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout) + } + conn.updateDeadline() defer conn.Close() sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) if err != nil { @@ -298,7 +303,8 @@ func (srv *Server) HandleConn(newConn net.Conn) { } return } - + conn.handshakeDeadline = time.Time{} + conn.updateDeadline() srv.trackConn(sshConn, true) defer srv.trackConn(sshConn, false) diff --git a/server_test.go b/server_test.go index 8028a3a..63fe694 100644 --- a/server_test.go +++ b/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net" "testing" "time" ) @@ -124,3 +125,36 @@ func TestServerClose(t *testing.T) { return } } + +func TestServerHandshakeTimeout(t *testing.T) { + l := newLocalListener() + + s := &Server{ + HandshakeTimeout: time.Millisecond, + } + go func() { + if err := s.Serve(l); err != nil { + t.Error(err) + } + }() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + ch := make(chan struct{}) + go func() { + defer close(ch) + io.Copy(io.Discard, conn) + }() + + select { + case <-ch: + return + case <-time.After(time.Second): + t.Fatal("client connection was not force-closed") + return + } +}