diff --git a/_examples/ssh-sftpserver/sftp.go b/_examples/ssh-sftpserver/sftp.go index 120df2c..f51322d 100644 --- a/_examples/ssh-sftpserver/sftp.go +++ b/_examples/ssh-sftpserver/sftp.go @@ -3,7 +3,6 @@ package main import ( "fmt" "io" - "io/ioutil" "log" "github.com/gliderlabs/ssh" @@ -12,7 +11,7 @@ import ( // SftpHandler handler for SFTP subsystem func SftpHandler(sess ssh.Session) { - debugStream := ioutil.Discard + debugStream := io.Discard serverOptions := []sftp.ServerOption{ sftp.WithDebug(debugStream), } diff --git a/agent.go b/agent.go index d8dcb9a..99e84c1 100644 --- a/agent.go +++ b/agent.go @@ -2,8 +2,8 @@ package ssh import ( "io" - "io/ioutil" "net" + "os" "path" "sync" @@ -36,7 +36,7 @@ func AgentRequested(sess Session) bool { // NewAgentListener sets up a temporary Unix socket that can be communicated // to the session environment and used for forwarding connections. func NewAgentListener() (net.Listener, error) { - dir, err := ioutil.TempDir("", agentTempDir) + dir, err := os.MkdirTemp("", agentTempDir) if err != nil { return nil, err } diff --git a/example_test.go b/example_test.go index 972d3ef..36745ea 100644 --- a/example_test.go +++ b/example_test.go @@ -2,7 +2,7 @@ package ssh_test import ( "io" - "io/ioutil" + "os" "github.com/gliderlabs/ssh" ) @@ -28,7 +28,7 @@ func ExampleNoPty() { func ExamplePublicKeyAuth() { ssh.ListenAndServe(":2222", nil, ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - data, _ := ioutil.ReadFile("/path/to/allowed/key.pub") + data, _ := os.ReadFile("/path/to/allowed/key.pub") allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data) return ssh.KeysEqual(key, allowed) }), diff --git a/options.go b/options.go index 303dcc3..29c8ef1 100644 --- a/options.go +++ b/options.go @@ -1,7 +1,7 @@ package ssh import ( - "io/ioutil" + "os" gossh "golang.org/x/crypto/ssh" ) @@ -26,7 +26,7 @@ func PublicKeyAuth(fn PublicKeyHandler) Option { // from a PEM file at filepath. func HostKeyFile(filepath string) Option { return func(srv *Server) error { - pemBytes, err := ioutil.ReadFile(filepath) + pemBytes, err := os.ReadFile(filepath) if err != nil { return err } diff --git a/server.go b/server.go index d017269..f783ee5 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,7 @@ type Server struct { Version string // server version to be sent before the initial handshake Banner string // server banner + BannerHandler BannerHandler // server banner handler, overrides Banner KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler PasswordHandler PasswordHandler // password authentication handler PublicKeyHandler PublicKeyHandler // public key authentication handler @@ -134,10 +135,16 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { config.ServerVersion = "SSH-2.0-" + srv.Version } if srv.Banner != "" { - config.BannerCallback = func(conn gossh.ConnMetadata) string { + config.BannerCallback = func(_ gossh.ConnMetadata) string { return srv.Banner } } + if srv.BannerHandler != nil { + config.BannerCallback = func(conn gossh.ConnMetadata) string { + applyConnMetadata(ctx, conn) + return srv.BannerHandler(ctx) + } + } if srv.PasswordHandler != nil { config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { applyConnMetadata(ctx, conn) diff --git a/ssh.go b/ssh.go index fbeb150..775b454 100644 --- a/ssh.go +++ b/ssh.go @@ -35,6 +35,9 @@ type Option func(*Server) error // Handler is a callback for handling established SSH sessions. type Handler func(Session) +// BannerHandler is a callback for displaying the server banner. +type BannerHandler func(ctx Context) string + // PublicKeyHandler is a callback for performing public key authentication. type PublicKeyHandler func(ctx Context, key PublicKey) bool @@ -115,8 +118,7 @@ func Handle(handler Handler) { // KeysEqual is constant time compare of the keys to avoid timing attacks. func KeysEqual(ak, bk PublicKey) bool { - - //avoid panic if one of the keys is nil, return false instead + // avoid panic if one of the keys is nil, return false instead if ak == nil || bk == nil { return false } diff --git a/tcpip_test.go b/tcpip_test.go index 3c27eb1..4ddf40e 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -2,7 +2,7 @@ package ssh import ( "bytes" - "io/ioutil" + "io" "net" "strconv" "strings" @@ -58,7 +58,7 @@ func TestLocalPortForwardingWorks(t *testing.T) { if err != nil { t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) } - result, err := ioutil.ReadAll(conn) + result, err := io.ReadAll(conn) if err != nil { t.Fatal(err) }