Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handshake timeout #204

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ import (
type serverConn struct {
net.Conn

idleTimeout time.Duration
maxDeadline time.Time
closeCanceler context.CancelFunc
idleTimeout time.Duration
handshakeDeadline time.Time
maxDeadline time.Time
closeCanceler context.CancelFunc
}

func (c *serverConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
if c.idleTimeout > 0 {
c.updateDeadline()
}
n, err = c.Conn.Write(p)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
Expand All @@ -24,7 +27,9 @@ func (c *serverConn) Write(p []byte) (n int, err error) {
}

func (c *serverConn) Read(b []byte) (n int, err error) {
c.updateDeadline()
if c.idleTimeout > 0 {
c.updateDeadline()
}
n, err = c.Conn.Read(b)
if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
c.closeCanceler()
Expand All @@ -41,15 +46,18 @@ func (c *serverConn) Close() (err error) {
}

func (c *serverConn) updateDeadline() {
switch {
case c.idleTimeout > 0:
deadline := c.maxDeadline

if !c.handshakeDeadline.IsZero() && (deadline.IsZero() || c.handshakeDeadline.Before(deadline)) {
deadline = c.handshakeDeadline
}

if c.idleTimeout > 0 {
idleDeadline := time.Now().Add(c.idleTimeout)
if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
c.Conn.SetDeadline(idleDeadline)
return
if deadline.IsZero() || idleDeadline.Before(deadline) {
deadline = idleDeadline
}
fallthrough
default:
c.Conn.SetDeadline(c.maxDeadline)
}

c.Conn.SetDeadline(deadline)
}
12 changes: 9 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ type Server struct {

ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures

IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty
HandshakeTimeout time.Duration // connection timeout until successful handshake, none if empty
IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty

// ChannelHandlers allow overriding the built-in session handlers or provide
// extensions to the protocol, such as tcpip forwarding. By default only the
Expand Down Expand Up @@ -290,6 +291,10 @@ func (srv *Server) HandleConn(newConn net.Conn) {
if srv.MaxTimeout > 0 {
conn.maxDeadline = time.Now().Add(srv.MaxTimeout)
}
if srv.HandshakeTimeout > 0 {
conn.handshakeDeadline = time.Now().Add(srv.HandshakeTimeout)
}
conn.updateDeadline()
defer conn.Close()
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
if err != nil {
Expand All @@ -298,7 +303,8 @@ func (srv *Server) HandleConn(newConn net.Conn) {
}
return
}

conn.handshakeDeadline = time.Time{}
conn.updateDeadline()
srv.trackConn(sshConn, true)
defer srv.trackConn(sshConn, false)

Expand Down
34 changes: 34 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"io"
"net"
"testing"
"time"
)
Expand Down Expand Up @@ -124,3 +125,36 @@ func TestServerClose(t *testing.T) {
return
}
}

func TestServerHandshakeTimeout(t *testing.T) {
l := newLocalListener()

s := &Server{
HandshakeTimeout: time.Millisecond,
}
go func() {
if err := s.Serve(l); err != nil {
t.Error(err)
}
}()

conn, err := net.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()

ch := make(chan struct{})
go func() {
defer close(ch)
io.Copy(io.Discard, conn)
}()

select {
case <-ch:
return
case <-time.After(time.Second):
t.Fatal("client connection was not force-closed")
return
}
}