diff --git a/go.mod b/go.mod index 3cde2e1c40..5e2d77c58f 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/google/go-containerregistry v0.15.2 github.com/google/go-github/v49 v49.1.0 github.com/google/uuid v1.3.1 + github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 github.com/heroku/color v0.0.6 github.com/hinshun/vt10x v0.0.0-20220228203356-1ab2cad5fd82 github.com/manifestival/client-go-client v0.5.0 diff --git a/go.sum b/go.sum index 62ff7c9752..e22ea37c93 100644 --- a/go.sum +++ b/go.sum @@ -569,6 +569,8 @@ github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iP github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 h1:S4qyfL2sEm5Budr4KVMyEniCy+PbS55651I/a+Kn/NQ= +github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95/go.mod h1:QiyDdbZLaJ/mZP4Zwc9g2QsfaEA4o7XvvgZegSci5/E= github.com/heroku/color v0.0.6 h1:UTFFMrmMLFcL3OweqP1lAdp8i1y/9oHqkeHjQ/b/Ny0= github.com/heroku/color v0.0.6/go.mod h1:ZBvOcx7cTF2QKOv4LbmoBtNl5uB17qWxGuzZrsi1wLU= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= @@ -1155,6 +1157,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190529164535-6a60838ec259/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/ssh/server_test.go b/pkg/ssh/server_test.go new file mode 100644 index 0000000000..d8202a6570 --- /dev/null +++ b/pkg/ssh/server_test.go @@ -0,0 +1,597 @@ +package ssh_test + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rsa" + "encoding/binary" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +type SSHServer struct { + lock sync.Locker + dockerServer http.Server + dockerListener listener + dockerHost string + hostIPv4 string + hostIPv6 string + portIPv4 int + portIPv6 int + hasDialStdio bool + isWin bool + serverKeys []any + authorizedKeys []any +} + +func (s *SSHServer) SetIsWindows(v bool) { + s.lock.Lock() + defer s.lock.Unlock() + s.isWin = v +} + +func (s *SSHServer) IsWindows() bool { + s.lock.Lock() + defer s.lock.Unlock() + return s.isWin +} + +func (s *SSHServer) SetDockerHostEnvVar(host string) { + s.lock.Lock() + defer s.lock.Unlock() + s.dockerHost = host +} + +func (s *SSHServer) GetDockerHostEnvVar() string { + s.lock.Lock() + defer s.lock.Unlock() + return s.dockerHost +} + +func (s *SSHServer) HasDialStdio() bool { + s.lock.Lock() + defer s.lock.Unlock() + return s.hasDialStdio +} + +func (s *SSHServer) SetHasDialStdio(v bool) { + s.lock.Lock() + defer s.lock.Unlock() + s.hasDialStdio = v +} + +const dockerUnixSocket = "/home/testuser/test.sock" +const dockerTCPSocket = "localhost:1234" + +// We need to set up SSH server against which we will run the tests. +// This will return SSHServer structure representing the state of the testing server. +func prepareSSHServer(t *testing.T, authorizedKeys ...any) (sshServer *SSHServer, err error) { + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + cancel() + } + }() + + httpServerErrChan := make(chan error) + pollingLoopErr := make(chan error) + pollingLoopIPv6Err := make(chan error) + + handlePing := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.Header().Add("Content-Type", "text/plain") + writer.WriteHeader(200) + _, _ = writer.Write([]byte("OK")) + }) + + sshServer = &SSHServer{ + dockerServer: http.Server{ + Handler: handlePing, + }, + dockerListener: listener{conns: make(chan net.Conn), closed: make(chan struct{})}, + lock: &sync.Mutex{}, + authorizedKeys: authorizedKeys, + } + + rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048) + if err != nil { + t.Fatal(err) + } + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano()))) + if err != nil { + t.Fatal(err) + } + sshServer.serverKeys = []any{rsaKey, ecdsaKey} + + sshTCPListener, err := net.Listen("tcp4", "localhost:0") + if err != nil { + return + } + + hasIPv6 := true + sshTCP6Listener, err := net.Listen("tcp6", "localhost:0") + if err != nil { + hasIPv6 = false + t.Log(err) + } + + host, p, err := net.SplitHostPort(sshTCPListener.Addr().String()) + if err != nil { + return + } + port, err := strconv.ParseInt(p, 10, 32) + if err != nil { + return + } + sshServer.hostIPv4 = host + sshServer.portIPv4 = int(port) + + if hasIPv6 { + host, p, err = net.SplitHostPort(sshTCP6Listener.Addr().String()) + if err != nil { + return + } + port, err = strconv.ParseInt(p, 10, 32) + if err != nil { + return + } + sshServer.hostIPv6 = host + sshServer.portIPv6 = int(port) + } + + t.Logf("Listening on %s", sshTCPListener.Addr()) + if hasIPv6 { + t.Logf("Listening on %s", sshTCP6Listener.Addr()) + } + + go func() { + httpServerErrChan <- sshServer.dockerServer.Serve(&sshServer.dockerListener) + }() + + stopSSH := func() { + var err error + cancel() + + stopCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + err = sshServer.dockerServer.Shutdown(stopCtx) + if err != nil { + t.Error(err) + } + + err = <-httpServerErrChan + if err != nil && !strings.Contains(err.Error(), "Server closed") { + t.Error(err) + } + + sshTCPListener.Close() + err = <-pollingLoopErr + if err != nil && !errors.Is(err, net.ErrClosed) { + t.Error(err) + } + + if hasIPv6 { + sshTCP6Listener.Close() + err = <-pollingLoopIPv6Err + if err != nil && !errors.Is(err, net.ErrClosed) { + t.Error(err) + } + } + } + t.Cleanup(stopSSH) + + connChan := make(chan net.Conn) + + go func() { + for { + tcpConn, err := sshTCPListener.Accept() + if err != nil { + pollingLoopErr <- err + return + } + connChan <- tcpConn + } + }() + + if hasIPv6 { + go func() { + for { + tcpConn, err := sshTCP6Listener.Accept() + if err != nil { + pollingLoopIPv6Err <- err + return + } + connChan <- tcpConn + } + }() + } + + go func() { + for { + conn := <-connChan + go func(conn net.Conn) { + err := sshServer.handleConnection(ctx, conn) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + }(conn) + } + }() + + return sshServer, err +} + +func (s *SSHServer) setupServerAuth() (conf *ssh.ServerConfig, err error) { + passwd := map[string]string{ + "testuser": "idkfa", + "root": "iddqd", + } + + authorizedKeys := make(map[[16]byte][]byte, len(s.authorizedKeys)) + for _, key := range s.authorizedKeys { + var pk ssh.PublicKey + pk, err = ssh.NewPublicKey(key) + if err != nil { + return + } + bs := pk.Marshal() + authorizedKeys[md5.Sum(bs)] = bs + } + + conf = &ssh.ServerConfig{ + PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if p, ok := passwd[conn.User()]; ok && p == string(password) { + return nil, nil + } + return nil, fmt.Errorf("incorrect password %q for user %q", string(password), conn.User()) + }, + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + keyBytes := key.Marshal() + if b, ok := authorizedKeys[md5.Sum(keyBytes)]; ok && bytes.Equal(b, keyBytes) { + return &ssh.Permissions{}, nil + } + return nil, fmt.Errorf("untrusted public key: %q", string(keyBytes)) + }, + } + + for _, k := range s.serverKeys { + signer, e := ssh.NewSignerFromKey(k) + if e != nil { + return nil, e + } + conf.AddHostKey(signer) + } + + return conf, nil +} + +func (s *SSHServer) handleConnection(ctx context.Context, conn net.Conn) error { + config, err := s.setupServerAuth() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "cannot load auth: %v\n", err) + } + sshConn, newChannels, reqs, err := ssh.NewServerConn(conn, config) + if err != nil { + return err + } + + go func() { + <-ctx.Done() + err = sshConn.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + }() + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + ssh.DiscardRequests(reqs) + }() + + for newChannel := range newChannels { + wg.Add(1) + go func(newChannel ssh.NewChannel) { + defer wg.Done() + s.handleChannel(newChannel) + }(newChannel) + } + + wg.Wait() + + return nil +} + +func (s *SSHServer) handleChannel(newChannel ssh.NewChannel) { + var err error + switch newChannel.ChannelType() { + case "session": + s.handleSession(newChannel) + case "direct-streamlocal@openssh.com", "direct-tcpip": + s.handleTunnel(newChannel) + default: + err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType())) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + } +} + +func (s *SSHServer) handleSession(newChannel ssh.NewChannel) { + ch, reqs, err := newChannel.Accept() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + + defer ch.Close() + for req := range reqs { + if req.Type == "exec" { + s.handleExec(ch, req) + break + } + } +} + +func (s *SSHServer) handleExec(ch ssh.Channel, req *ssh.Request) { + var err error + err = req.Reply(true, nil) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + execData := struct { + Command string + }{} + err = ssh.Unmarshal(req.Payload, &execData) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + + sendExitCode := func(ret uint32) { + msg := []byte{0, 0, 0, 0} + binary.BigEndian.PutUint32(msg, ret) + _, err = ch.SendRequest("exit-status", false, msg) + if err != nil && !errors.Is(err, io.EOF) { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + } + + var ret uint32 + switch { + case execData.Command == "set": + ret = 0 + dh := s.GetDockerHostEnvVar() + if dh != "" { + _, _ = fmt.Fprintf(ch, "DOCKER_HOST=%s\n", dh) + } + case execData.Command == "systeminfo" && s.IsWindows(): + _, _ = fmt.Fprintln(ch, "something Windows something") + ret = 0 + case execData.Command == "docker system dial-stdio --help" && s.HasDialStdio(): + _, _ = fmt.Fprintln(ch, "\nUsage: docker system dial-stdio\n\nProxy the stdio stream to the daemon connection. Should not be invoked manually.") + ret = 0 + case execData.Command == "docker system dial-stdio" && s.HasDialStdio(): + pr, pw, conn := newPipeConn() + + select { + case s.dockerListener.conns <- conn: + case <-s.dockerListener.closed: + err = ch.Close() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + } + + cpDone := make(chan struct{}) + go func() { + var err error + _, err = io.Copy(pw, ch) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + err = pw.Close() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + cpDone <- struct{}{} + }() + + _, err = io.Copy(ch, pr) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + err = pr.Close() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + + <-cpDone + + <-conn.closed + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + + ret = 0 + default: + _, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", execData.Command) + ret = 127 + } + sendExitCode(ret) +} + +func newPipeConn() (*io.PipeReader, *io.PipeWriter, *rwcConn) { + pr0, pw0 := io.Pipe() + pr1, pw1 := io.Pipe() + rwc := pipeReaderWriterCloser{r: pr0, w: pw1} + return pr1, pw0, newRWCConn(rwc) +} + +type pipeReaderWriterCloser struct { + r *io.PipeReader + w *io.PipeWriter +} + +func (d pipeReaderWriterCloser) Read(p []byte) (n int, err error) { + return d.r.Read(p) +} + +func (d pipeReaderWriterCloser) Write(p []byte) (n int, err error) { + return d.w.Write(p) +} + +func (d pipeReaderWriterCloser) Close() error { + err := d.r.Close() + if err != nil { + return err + } + return d.w.Close() +} + +func (s *SSHServer) handleTunnel(newChannel ssh.NewChannel) { + var err error + + switch newChannel.ChannelType() { + case "direct-streamlocal@openssh.com": + bs := newChannel.ExtraData() + unixExtraData := struct { + SocketPath string + Reserved0 string + Reserved1 uint32 + }{} + err = ssh.Unmarshal(bs, &unixExtraData) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + if unixExtraData.SocketPath != dockerUnixSocket { + err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", unixExtraData.SocketPath)) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + return + } + case "direct-tcpip": + bs := newChannel.ExtraData() + tcpExtraData := struct { //nolint:maligned + HostLocal string + PortLocal uint32 + HostRemote string + PortRemote uint32 + }{} + err = ssh.Unmarshal(bs, &tcpExtraData) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + + hostPort := fmt.Sprintf("%s:%d", tcpExtraData.HostLocal, tcpExtraData.PortLocal) + if hostPort != dockerTCPSocket { + err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: '%s:%d'", tcpExtraData.HostLocal, tcpExtraData.PortLocal)) + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + return + } + } + + ch, _, err := newChannel.Accept() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + return + } + conn := newRWCConn(ch) + select { + case s.dockerListener.conns <- conn: + case <-s.dockerListener.closed: + err = ch.Close() + if err != nil { + fmt.Fprintf(os.Stderr, "err: %v\n", err) + } + return + } + <-conn.closed +} + +type listener struct { + conns chan net.Conn + closed chan struct{} + o sync.Once +} + +func (l *listener) Accept() (net.Conn, error) { + select { + case <-l.closed: + return nil, net.ErrClosed + case conn := <-l.conns: + return conn, nil + } +} + +func (l *listener) Close() error { + l.o.Do(func() { + close(l.closed) + }) + return nil +} + +func (l *listener) Addr() net.Addr { + return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"} +} + +func newRWCConn(rwc io.ReadWriteCloser) *rwcConn { + return &rwcConn{rwc: rwc, closed: make(chan struct{})} +} + +type rwcConn struct { + rwc io.ReadWriteCloser + closed chan struct{} + o sync.Once +} + +func (c *rwcConn) Read(b []byte) (n int, err error) { + return c.rwc.Read(b) +} + +func (c *rwcConn) Write(b []byte) (n int, err error) { + return c.rwc.Write(b) +} + +func (c *rwcConn) Close() error { + c.o.Do(func() { + close(c.closed) + }) + return c.rwc.Close() +} + +func (c *rwcConn) LocalAddr() net.Addr { + return &net.UnixAddr{Name: dockerUnixSocket, Net: "unix"} +} + +func (c *rwcConn) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "@", Net: "unix"} +} + +func (c *rwcConn) SetDeadline(t time.Time) error { return nil } + +func (c *rwcConn) SetReadDeadline(t time.Time) error { return nil } + +func (c *rwcConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/pkg/ssh/ssh_dialer_test.go b/pkg/ssh/ssh_dialer_test.go new file mode 100644 index 0000000000..1369e89ce3 --- /dev/null +++ b/pkg/ssh/ssh_dialer_test.go @@ -0,0 +1,1026 @@ +package ssh_test + +import ( + "bytes" + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + "text/template" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + + th "github.com/buildpacks/pack/testhelpers" + "github.com/docker/docker/pkg/homedir" + "github.com/pkg/errors" + + funcssh "knative.dev/func/pkg/ssh" +) + +type args struct { + connStr string + credentialConfig funcssh.Config +} +type testParams struct { + name string + args args + setUpEnv setUpEnvFn + skipOnWin bool + skipOnRoot bool + CreateError string + DialError string +} + +func TestCreateDialer(t *testing.T) { + + clientPrivKeyRSA, clientPrivKeyECDSA := generateClientKeys(t) + + withoutSSHAgent(t) + withCleanHome(t) + + connConfig, err := prepareSSHServer(t, &clientPrivKeyRSA.PublicKey, &clientPrivKeyECDSA.PublicKey) + th.AssertNil(t, err) + + time.Sleep(time.Second * 1) + + tests := []testParams{ + { + name: "read password from input", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{PasswordCallback: func() (string, error) { + return "idkfa", nil + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + }, + { + name: "password in url", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + }, + { + name: "server key is not in known_hosts (the file doesn't exists)", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome), + CreateError: funcssh.ErrUnknownServerKeyMsg, + }, + { + name: "server key is not in known_hosts (the file exists)", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withEmptyKnownHosts), + CreateError: funcssh.ErrUnknownServerKeyMsg, + }, + { + name: "server key is not in known_hosts (the filed doesn't exists) - user force trust", + args: args{ + connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{HostKeyCallback: func(hostPort string, pubKey ssh.PublicKey) error { + return nil + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome), + }, + { + name: "server key is not in known_hosts (the file exists) - user force trust", + args: args{ + connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{HostKeyCallback: func(hostPort string, pubKey ssh.PublicKey) error { + return nil + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withEmptyKnownHosts), + }, + { + name: "server key does not match the respective key in known_host", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withBadKnownHosts(connConfig)), + CreateError: funcssh.ErrBadServerKeyMsg, + }, + { + name: "key from identity parameter", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + }, + { + name: "key at standard location with need to read passphrase", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{PassPhraseCallback: func() (string, error) { + return "nbusr123", nil + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKey(clientPrivKeyRSA, "id_rsa", "nbusr123"), withKnowHosts(connConfig)), + }, + { + name: "key at standard location with explicitly set passphrase", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{PassPhrase: "nbusr123"}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKey(clientPrivKeyECDSA, "id_ecdsa", "nbusr123"), withKnowHosts(connConfig)), + }, + { + name: "key at standard location with no passphrase", + args: args{connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKey(clientPrivKeyECDSA, "id_ecdsa", ""), withKnowHosts(connConfig)), + }, + { + name: "key from ssh-agent", + args: args{connStr: fmt.Sprintf("ssh://testuser@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withGoodSSHAgent(clientPrivKeyRSA, clientPrivKeyECDSA), withCleanHome, withKnowHosts(connConfig)), + }, + { + name: "password in url with IPv6", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@[%s]:%d/home/testuser/test.sock", + connConfig.hostIPv6, + connConfig.portIPv6, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + }, + { + name: "broken known host", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withBrokenKnownHosts), + CreateError: "invalid entry in known_hosts", + }, + { + name: "inaccessible known host", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withInaccessibleKnownHosts), + skipOnWin: true, + skipOnRoot: true, + CreateError: "permission denied", + }, + { + name: "failing pass phrase cbk", + args: args{ + connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{PassPhraseCallback: func() (string, error) { + return "", errors.New("test_error_msg") + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKey(clientPrivKeyRSA, "id_rsa", "nbusr123"), withKnowHosts(connConfig)), + CreateError: "test_error_msg", + }, + { + name: "with broken key at default location", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withGibberishKey("id_dsa"), withKnowHosts(connConfig)), + CreateError: "failed to parse private key", + }, + { + name: "with broken key explicit", + args: args{ + connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: gibberishKey(t)}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + CreateError: "failed to parse private key", + }, + { + name: "with inaccessible key", + args: args{connStr: fmt.Sprintf("ssh://testuser:idkfa@%s:%d/home/testuser/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + )}, + setUpEnv: all(withoutSSHAgent, withCleanHome, withInaccessibleKey("id_rsa"), withKnowHosts(connConfig)), + skipOnWin: true, + skipOnRoot: true, + CreateError: "failed to read key file", + }, + { + name: "socket doesn't exist in remote", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/does/not/exist/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{PasswordCallback: func() (string, error) { + return "idkfa", nil + }}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig)), + DialError: "failed to dial unix socket in the remote", + }, + { + name: "ssh agent non-existent socket", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/does/not/exist/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + }, + setUpEnv: all(withBadSSHAgentSocket, withCleanHome, withKnowHosts(connConfig)), + CreateError: "failed to connect to ssh-agent's socket", + }, + { + name: "bad ssh agent", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d/does/not/exist/test.sock", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + }, + setUpEnv: all(withBadSSHAgent, withCleanHome, withKnowHosts(connConfig)), + CreateError: "failed to get signers from ssh-agent", + }, + { + name: "use docker host from remote unix", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig), + withRemoteDockerHost("unix:///home/testuser/test.sock", connConfig)), + }, + { + name: "use docker host from remote tcp", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig), + withRemoteDockerHost("tcp://localhost:1234", connConfig)), + }, + { + name: "use docker host from remote fd", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig), + withRemoteDockerHost("fd://localhost:1234", connConfig)), + }, + { + name: "windows without docker system dial-stdio", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withKnowHosts(connConfig), + withEmulatingWindows(connConfig)), + CreateError: "cannot use dial-stdio", + }, + { + name: "windows with system dial-stdio", + args: args{ + connStr: fmt.Sprintf("ssh://testuser@%s:%d", + connConfig.hostIPv4, + connConfig.portIPv4, + ), + credentialConfig: funcssh.Config{Identity: tempKey(t, clientPrivKeyECDSA, "")}, + }, + setUpEnv: all(withoutSSHAgent, withCleanHome, withEmulatingWindows(connConfig), withKnowHosts(connConfig), + withEmulatedDockerSystemDialStdio(connConfig), withFixedUpSSHCLI), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u, err := url.Parse(tt.args.connStr) + th.AssertNil(t, err) + + if net.ParseIP(u.Hostname()).To4() == nil && connConfig.hostIPv6 == "" { + t.Skip("skipping ipv6 test since test environment doesn't support ipv6 connection") + } + + if tt.skipOnWin && runtime.GOOS == "windows" { + t.Skip("skipping this test on windows") + } + + if tt.skipOnRoot && os.Geteuid() == 0 { + t.Skip("skipping this test when running as a root") + } + + tt.setUpEnv(t) + + dialContext, _, err := funcssh.NewDialContext(u, tt.args.credentialConfig) + + if tt.CreateError == "" { + th.AssertEq(t, err, nil) + } else { + // I wish I could use errors.Is(), + // however foreign code is not wrapping errors thoroughly + if err != nil { + th.AssertContains(t, err.Error(), tt.CreateError) + } else { + t.Error("expected error but got nil") + } + } + if err != nil { + return + } + + transport := http.Transport{DialContext: dialContext.DialContext} + httpClient := http.Client{Transport: &transport} + defer httpClient.CloseIdleConnections() + resp, err := httpClient.Get("http://docker/") + if tt.DialError == "" { + th.AssertNil(t, err) + } else { + // I wish I could use errors.Is(), + // however foreign code is not wrapping errors thoroughly + if err != nil { + th.AssertContains(t, err.Error(), tt.CreateError) + } else { + t.Error("expected error but got nil") + } + } + if err != nil { + return + } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + th.AssertTrue(t, err == nil) + if err != nil { + return + } + th.AssertEq(t, string(b), "OK") + }) + } +} + +// function that prepares testing environment and returns clean up function +// this should be used in conjunction with defer: `defer fn()()` +// e.g. sets environment variables or starts mock up services +// it returns clean up procedure that restores old values of environment variables +// or shuts down mock up services +type setUpEnvFn func(t *testing.T) + +// combines multiple setUp routines into one setUp routine +func all(fns ...setUpEnvFn) setUpEnvFn { + return func(t *testing.T) { + //t.Helper() + + for _, fn := range fns { + fn(t) + } + } +} + +// puts private key to $HOME/.ssh/{keyName} +func withKey(key any, keyName, passphrase string) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + + home, err := os.UserHomeDir() + th.AssertNil(t, err) + + err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700) + th.AssertNil(t, err) + + keyDest := filepath.Join(home, ".ssh", keyName) + + marshallKey(t, key, keyDest, passphrase) + + t.Cleanup(func() { + _ = os.Remove(keyDest) + }) + } +} + +func gibberishKey(t *testing.T) string { + t.Helper() + p := filepath.Join(t.TempDir(), "id") + err := os.WriteFile(p, []byte("definetelynotakey"), 0600) + th.AssertNil(t, err) + return p +} + +func withGibberishKey(keyName string) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + + home, err := os.UserHomeDir() + th.AssertNil(t, err) + + err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700) + th.AssertNil(t, err) + + keyDest := filepath.Join(home, ".ssh", keyName) + err = os.WriteFile(keyDest, []byte("definetelynotakey"), 0600) + th.AssertNil(t, err) + } +} + +// this function marshals key to temporary file and returns its path +func tempKey(t *testing.T, key any, passphrase string) string { + p := filepath.Join(t.TempDir(), "id") + marshallKey(t, key, p, passphrase) + return p +} + +func marshallKey(t *testing.T, key any, destPath, passphrase string) { + var ( + err error + raw []byte + pemType string + ) + + if k, ok := key.(*rsa.PrivateKey); ok { + pemType = "RSA PRIVATE KEY" + raw = x509.MarshalPKCS1PrivateKey(k) + } else if k, ok := key.(*ecdsa.PrivateKey); ok { + pemType = "EC PRIVATE KEY" + raw, err = x509.MarshalECPrivateKey(k) + th.AssertNil(t, err) + } else { + panic("unsupported key type") + } + + blk := &pem.Block{ + Type: pemType, + Bytes: raw, + } + + if passphrase != "" { + //nolint:staticcheck + blk, err = x509.EncryptPEMBlock(rand.New(rand.NewSource(time.Now().UnixNano())), blk.Type, blk.Bytes, []byte(passphrase), x509.PEMCipherAES256) + th.AssertNil(t, err) + } + + f, err := os.OpenFile(destPath, os.O_CREATE|os.O_WRONLY, 0600) + th.AssertNil(t, err) + defer f.Close() + + err = pem.Encode(f, blk) + th.AssertNil(t, err) + _ = f.Close() + + fixupPrivateKeyMod(destPath) +} + +// withInaccessibleKey creates inaccessible key of give type (specified by keyName) +func withInaccessibleKey(keyName string) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + var err error + + home, err := os.UserHomeDir() + th.AssertNil(t, err) + + err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700) + th.AssertNil(t, err) + + keyDest := filepath.Join(home, ".ssh", keyName) + f, err := os.OpenFile(keyDest, os.O_CREATE|os.O_WRONLY, 0000) + th.AssertNil(t, err) + f.Close() + + t.Cleanup(func() { + _ = os.Remove(keyDest) + }) + } +} + +// sets clean temporary $HOME for test +// this prevents interaction with actual user home which may contain .ssh/ +func withCleanHome(t *testing.T) { + t.Helper() + homeName := "HOME" + if runtime.GOOS == "windows" { + homeName = "USERPROFILE" + } + tempHome := t.TempDir() + t.Setenv(homeName, tempHome) +} + +// withKnowHosts creates $HOME/.ssh/known_hosts with correct entries +func withKnowHosts(connConfig *SSHServer) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + + knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts") + + err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700) + th.AssertNil(t, err) + + _, err = os.Stat(knownHosts) + if err == nil || !errors.Is(err, os.ErrNotExist) { + t.Fatal("known_hosts already exists") + } + + f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600) + th.AssertNil(t, err) + defer f.Close() + + // generate known_hosts + for _, privKey := range connConfig.serverKeys { + pubKey := publicKey(privKey) + k, err := ssh.NewPublicKey(pubKey) + if err != nil { + t.Fatal(err) + } + bs := ssh.MarshalAuthorizedKey(k) + + fmt.Fprintf(f, "%s %s", connConfig.hostIPv4, string(bs)) + fmt.Fprintf(f, "[%s]:%d %s", connConfig.hostIPv4, connConfig.portIPv4, string(bs)) + + if connConfig.hostIPv6 != "" { + fmt.Fprintf(f, "%s %s", connConfig.hostIPv6, string(bs)) + fmt.Fprintf(f, "[%s]:%d %s", connConfig.hostIPv6, connConfig.portIPv6, string(bs)) + } + } + t.Cleanup(func() { + _ = os.Remove(knownHosts) + }) + } +} + +func publicKey(privKey any) any { + switch privKey := privKey.(type) { + case *rsa.PrivateKey: + return &privKey.PublicKey + case *ecdsa.PrivateKey: + return &privKey.PublicKey + default: + panic("unsupported key type") + } +} + +// withBadKnownHosts creates $HOME/.ssh/known_hosts with incorrect entries +func withBadKnownHosts(connConfig *SSHServer) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + + knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts") + + err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700) + th.AssertNil(t, err) + + _, err = os.Stat(knownHosts) + if err == nil || !errors.Is(err, os.ErrNotExist) { + t.Fatal("known_hosts already exists") + } + + f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600) + th.AssertNil(t, err) + defer f.Close() + + knownHostTemplate := `{{range $host := .}}{{$host}} ssh-dss AAAAB3NzaC1kc3MAAACBAKH4ufS3ABVb780oTgEL1eu+pI1p6YOq/1KJn5s3zm+L3cXXq76r5OM/roGEYrXWUDGRtfVpzYTAKoMWuqcVc0AZ2zOdYkoy1fSjJ3MqDGF53QEO3TXIUt3gUzmLOewwmZWle0RgMa9GHccv7XVVIZB36RR68ZEUswLaTnlVhXQ1AAAAFQCl4t/LnY7kuUI+tL2qT2XmxmiyqwAAAIB72XaO+LfyIiqBOaTkQf+5rvH1i6y6LDO1QD9pzGWUYw3y03AEveHJMjW0EjnYBKJjK39wcZNTieRyU54lhH/HWeWABn9NcQ3duEf1WSO/s7SPsFO2R6quqVSsStkqf2Yfdy4fl24mH41olwtNA6ft5nkVfkqrIa51si4jU8fBVAAAAIB8SSvyYBcyMGLUlQjzQqhhhAHer9x/1YbknVz+y5PHJLLjHjMC4ZRfLgNEojvMKQW46Te9Pwnudcwv19ho4F+kkCOfss7xjyH70gQm6Sj76DxClmnnPoSRq3qEAOMy5Oh+7vyzxm68KHqd/aOmUaiT1LgqgViS9+kNdCoVMGAMOg== mvasek@bellatrix +{{$host}} ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBKPrqGp4c5ZstymDqXOxPsIEH6e6a4Pi8qcTRUkbyQllWjyQVx0A/o4yA8cd222x3t9gsiGa+mNgCYkyFehH0nKO7gk057jNmALc9xhbj25EdmREjdex+yUrmxdxcG9mtQ== +{{$host}} ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOKymJNQszrxetVffPZRfZGKWK786r0mNcg/Wah4+2wn mvasek@bellatrix +{{$host}} ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQC/1/OCwec2Gyv5goNYYvos4iOA+a0NolOGsZA/93jmSArPY1zZS1UWeJ6dDTmxGoL/e7jm9lM6NJY7a/zM0C/GqCNRGR/aCUHBJTIgGtH+79FDKO/LWY6ClGY7Lw8qNgZpugbBw3N3HqTtyb2lELhFLT0FEb+le4WUbryooLK2zsz6DnqV4JvTYyyHcanS0h68iSXC7XbkZchvL99l5LT0gD1oDteBPKKFdNOwIjpMkk/IrbFM24xoNkaTDXN87EpQPQzYDfsoGymprc5OZZ8kzrtErQR+yfuunHfzzqDHWi7ga5pbgkuxNt10djWgCfBRsy07FTEgV0JirS0TCfwTBbqRzdjf3dgi8AP+WtkW3mcv4a1XYeqoBo2o9TbfyiA9kERs79UBN0mCe3KNX3Ns0PvutsRLaHmdJ49eaKWkJ6GgL37aqSlIwTixz2xY3eoDSkqHoZpx6Q1MdpSIl5gGVzlaobM/PNM1jqVdyUj+xpjHyiXwHQMKc3eJna7s8Jc= mvasek@bellatrix +{{end}}` + + tmpl := template.New(knownHostTemplate) + tmpl, err = tmpl.Parse(knownHostTemplate) + th.AssertNil(t, err) + + hosts := make([]string, 0, 4) + hosts = append(hosts, connConfig.hostIPv4, fmt.Sprintf("[%s]:%d", connConfig.hostIPv4, connConfig.portIPv4)) + if connConfig.hostIPv6 != "" { + hosts = append(hosts, connConfig.hostIPv6, fmt.Sprintf("[%s]:%d", connConfig.hostIPv6, connConfig.portIPv4)) + } + + err = tmpl.Execute(f, hosts) + th.AssertNil(t, err) + + t.Cleanup(func() { + _ = os.Remove(knownHosts) + }) + } +} + +// withBrokenKnownHosts creates broken $HOME/.ssh/known_hosts +func withBrokenKnownHosts(t *testing.T) { + t.Helper() + + knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts") + + err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700) + th.AssertNil(t, err) + + _, err = os.Stat(knownHosts) + if err == nil || !errors.Is(err, os.ErrNotExist) { + t.Fatal("known_hosts already exists") + } + + f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600) + th.AssertNil(t, err) + defer f.Close() + + _, err = f.WriteString("somegarbage\nsome rubish\n stuff\tqwerty") + th.AssertNil(t, err) + + t.Cleanup(func() { + os.Remove(knownHosts) + }) +} + +// withInaccessibleKnownHosts creates inaccessible $HOME/.ssh/known_hosts +func withInaccessibleKnownHosts(t *testing.T) { + t.Helper() + + knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts") + + err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700) + th.AssertNil(t, err) + + _, err = os.Stat(knownHosts) + if err == nil || !errors.Is(err, os.ErrNotExist) { + t.Fatal("known_hosts already exists") + } + + f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0000) + th.AssertNil(t, err) + defer f.Close() + + t.Cleanup(func() { + _ = os.Remove(knownHosts) + }) +} + +// withEmptyKnownHosts creates empty $HOME/.ssh/known_hosts +func withEmptyKnownHosts(t *testing.T) { + t.Helper() + + knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts") + + err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700) + th.AssertNil(t, err) + + _, err = os.Stat(knownHosts) + if err == nil || !errors.Is(err, os.ErrNotExist) { + t.Fatal("known_hosts already exists") + } + + err = os.WriteFile(knownHosts, []byte{}, 0644) + th.AssertNil(t, err) + + t.Cleanup(func() { + _ = os.Remove(knownHosts) + }) +} + +// withoutSSHAgent unsets the SSH_AUTH_SOCK environment variable so ssh-agent is not used by test +func withoutSSHAgent(t *testing.T) { + t.Helper() + t.Setenv("SSH_AUTH_SOCK", "") +} + +// withBadSSHAgentSocket sets the SSH_AUTH_SOCK environment variable to non-existing file +func withBadSSHAgentSocket(t *testing.T) { + t.Helper() + t.Setenv("SSH_AUTH_SOCK", "/does/not/exists.sock") +} + +// withGoodSSHAgent starts serving ssh-agent on temporary unix socket. +// It sets the SSH_AUTH_SOCK environment variable to the temporary socket. +// The agent will return correct keys for the testing ssh server. +func withGoodSSHAgent(keys ...any) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + withSSHAgent(t, signerAgent{keys}) + } +} + +// withBadSSHAgent starts serving ssh-agent on temporary unix socket. +// It sets the SSH_AUTH_SOCK environment variable to the temporary socket. +// The agent will return incorrect keys for the testing ssh server. +func withBadSSHAgent(t *testing.T) { + withSSHAgent(t, badAgent{}) +} + +func withSSHAgent(t *testing.T, ag agent.Agent) { + var err error + t.Helper() + + var tmpDirForSocket string + var agentSocketPath string + if runtime.GOOS == "windows" { + agentSocketPath = `\\.\pipe\openssh-ssh-agent-test` + } else { + tmpDirForSocket, err = os.MkdirTemp("", "forAuthSock") + th.AssertNil(t, err) + + agentSocketPath = filepath.Join(tmpDirForSocket, "agent.sock") + } + + unixListener, err := listen(agentSocketPath) + th.AssertNil(t, err) + + os.Setenv("SSH_AUTH_SOCK", agentSocketPath) + + ctx, cancel := context.WithCancel(context.Background()) + errChan := make(chan error, 1) + var wg sync.WaitGroup + + go func() { + for { + conn, err := unixListener.Accept() + if err != nil { + errChan <- err + + return + } + + wg.Add(1) + go func(conn net.Conn) { + defer wg.Done() + go func() { + <-ctx.Done() + conn.Close() + }() + err := agent.ServeAgent(ag, conn) + if err != nil { + if !isErrClosed(err) { + fmt.Fprintf(os.Stderr, "agent.ServeAgent() failed: %v\n", err) + } + } + }(conn) + } + }() + + t.Cleanup(func() { + os.Unsetenv("SSH_AUTH_SOCK") + + err := unixListener.Close() + th.AssertNil(t, err) + + err = <-errChan + + if !isErrClosed(err) { + t.Fatal(err) + } + cancel() + wg.Wait() + if tmpDirForSocket != "" { + os.RemoveAll(tmpDirForSocket) + } + }) +} + +type signerAgent struct { + keys []any +} + +func (a signerAgent) List() ([]*agent.Key, error) { + result := make([]*agent.Key, 0, len(a.keys)) + for _, key := range a.keys { + signer, err := ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + result = append(result, &agent.Key{ + Format: signer.PublicKey().Type(), + Blob: signer.PublicKey().Marshal(), + }) + } + return result, nil +} + +func (a signerAgent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + for _, k := range a.keys { + signer, err := ssh.NewSignerFromKey(k) + if err != nil { + return nil, err + } + if signer.PublicKey().Type() == key.Type() && + bytes.Equal(signer.PublicKey().Marshal(), key.Marshal()) { + return signer.Sign(rand.New(rand.NewSource(time.Now().UnixNano())), data) + } + } + return nil, errors.New("key not found") +} + +func (a signerAgent) Add(key agent.AddedKey) error { + panic("implement me") +} + +func (a signerAgent) Remove(key ssh.PublicKey) error { + panic("implement me") +} + +func (a signerAgent) RemoveAll() error { + panic("implement me") +} + +func (a signerAgent) Lock(passphrase []byte) error { + panic("implement me") +} + +func (a signerAgent) Unlock(passphrase []byte) error { + panic("implement me") +} + +func (a signerAgent) Signers() ([]ssh.Signer, error) { + panic("implement me") +} + +var errBadAgent = errors.New("bad agent error") + +type badAgent struct{} + +func (b badAgent) List() ([]*agent.Key, error) { + return nil, errBadAgent +} + +func (b badAgent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { + return nil, errBadAgent +} + +func (b badAgent) Add(key agent.AddedKey) error { + return errBadAgent +} + +func (b badAgent) Remove(key ssh.PublicKey) error { + return errBadAgent +} + +func (b badAgent) RemoveAll() error { + return errBadAgent +} + +func (b badAgent) Lock(passphrase []byte) error { + return errBadAgent +} + +func (b badAgent) Unlock(passphrase []byte) error { + return errBadAgent +} + +func (b badAgent) Signers() ([]ssh.Signer, error) { + return nil, errBadAgent +} + +// openSSH CLI doesn't take the HOME/USERPROFILE environment variable into account. +// It gets user home in different way (e.g. reading /etc/passwd). +// This means tests cannot mock home dir just by setting environment variable. +// withFixedUpSSHCLI works around the problem, it forces usage of known_hosts from HOME/USERPROFILE. +func withFixedUpSSHCLI(t *testing.T) { + t.Helper() + + sshAbsPath, err := exec.LookPath("ssh") + th.AssertNil(t, err) + + sshScript := `#!/bin/sh +SSH_BIN -o PasswordAuthentication=no -o ConnectTimeout=3 -o UserKnownHostsFile="$HOME/.ssh/known_hosts" $@ +` + if runtime.GOOS == "windows" { + sshScript = `@echo off +"SSH_BIN" -o PasswordAuthentication=no -o ConnectTimeout=3 -o UserKnownHostsFile=%USERPROFILE%\.ssh\known_hosts %* +` + } + sshScript = strings.ReplaceAll(sshScript, "SSH_BIN", sshAbsPath) + + home, err := os.UserHomeDir() + th.AssertNil(t, err) + + homeBin := filepath.Join(home, "bin") + err = os.MkdirAll(homeBin, 0700) + th.AssertNil(t, err) + + sshScriptName := "ssh" + if runtime.GOOS == "windows" { + sshScriptName = "ssh.bat" + } + + sshScriptFullPath := filepath.Join(homeBin, sshScriptName) + err = os.WriteFile(sshScriptFullPath, []byte(sshScript), 0700) + th.AssertNil(t, err) + + t.Setenv("PATH", homeBin+string(os.PathListSeparator)+os.Getenv("PATH")) + t.Cleanup(func() { + os.RemoveAll(homeBin) + }) +} + +// withEmulatedDockerSystemDialStdio makes `docker system dial-stdio` viable in the testing ssh server. +// It does so by appending definition of shell function named `docker` into .bashrc . +func withEmulatedDockerSystemDialStdio(sshServer *SSHServer) setUpEnvFn { + return func(t *testing.T) { + t.Helper() + + oldHasDialStdio := sshServer.HasDialStdio() + sshServer.SetHasDialStdio(true) + t.Cleanup(func() { + sshServer.SetHasDialStdio(oldHasDialStdio) + }) + } +} + +// withEmulatingWindows makes changes to the testing ssh server such that +// the server appears to be Windows server for simple check done calling the `systeminfo` command +func withEmulatingWindows(sshServer *SSHServer) setUpEnvFn { + return func(t *testing.T) { + oldIsWindows := sshServer.IsWindows() + sshServer.SetIsWindows(true) + t.Cleanup(func() { + sshServer.SetIsWindows(oldIsWindows) + }) + } +} + +// withRemoteDockerHost makes changes to the testing ssh server such that +// the DOCKER_HOST environment is set to host parameter +func withRemoteDockerHost(host string, sshServer *SSHServer) setUpEnvFn { + return func(t *testing.T) { + oldHost := sshServer.GetDockerHostEnvVar() + sshServer.SetDockerHostEnvVar(host) + t.Cleanup(func() { + sshServer.SetDockerHostEnvVar(oldHost) + }) + } +} + +func generateClientKeys(t *testing.T) (privKeyRSA *rsa.PrivateKey, privKeyECDSA *ecdsa.PrivateKey) { + var err error + + privKeyRSA, err = rsa.GenerateKey(rand.New(rand.NewSource(time.Now().UnixNano())), 2048) + if err != nil { + t.Fatal(err) + } + + privKeyECDSA, err = ecdsa.GenerateKey(elliptic.P384(), rand.New(rand.NewSource(time.Now().UnixNano()))) + if err != nil { + t.Fatal(err) + } + + return privKeyRSA, privKeyECDSA +} diff --git a/pkg/ssh/ssh_posix_test.go b/pkg/ssh/ssh_posix_test.go new file mode 100644 index 0000000000..f3ebe64064 --- /dev/null +++ b/pkg/ssh/ssh_posix_test.go @@ -0,0 +1,25 @@ +//go:build !windows +// +build !windows + +package ssh_test + +import ( + "errors" + "net" + "os" +) + +func fixupPrivateKeyMod(path string) { + err := os.Chmod(path, 0600) + if err != nil { + panic(err) + } +} + +func listen(addr string) (net.Listener, error) { + return net.Listen("unix", addr) +} + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) +} diff --git a/pkg/ssh/ssh_windows_test.go b/pkg/ssh/ssh_windows_test.go new file mode 100644 index 0000000000..94172c0d03 --- /dev/null +++ b/pkg/ssh/ssh_windows_test.go @@ -0,0 +1,39 @@ +package ssh_test + +import ( + "errors" + "net" + "os/user" + "strings" + + "github.com/Microsoft/go-winio" + "github.com/hectane/go-acl" +) + +func fixupPrivateKeyMod(path string) { + usr, err := user.Current() + if err != nil { + panic(err) + } + mode := uint32(0600) + err = acl.Apply(path, + true, + false, + acl.GrantName(((mode&0700)<<23)|((mode&0200)<<9), usr.Username)) + + // See https://github.com/hectane/go-acl/issues/1 + if err != nil && err.Error() != "The operation completed successfully." { + panic(err) + } +} + +func listen(addr string) (net.Listener, error) { + if strings.Contains(addr, "\\pipe\\") { + return winio.ListenPipe(addr, nil) + } + return net.Listen("unix", addr) +} + +func isErrClosed(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, winio.ErrPipeListenerClosed) || errors.Is(err, winio.ErrFileClosed) +}