Skip to content

Commit

Permalink
fix possible race condition in (*os.Cmd).Wait()
Browse files Browse the repository at this point in the history
Connects-to: #33
Change-type: patch
Signed-off-by: Will Boyce <[email protected]>
  • Loading branch information
wrboyce committed May 6, 2019
1 parent 61db426 commit 11da2b2
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 15 deletions.
1 change: 1 addition & 0 deletions .errcheck.exclude
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
io.Copy
(*os.Process).Kill
(net.Conn).Close
fmt.Fprintf
fmt.Fprintln
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ test-dep: dep
go test -i -v ./...

test: test-dep
go test -v ./...
go test -race -v ./...

release: $(addsuffix .tar.gz,$(addprefix build/$(EXECUTABLE)-$(VERSION)_,$(subst /,_,$(BUILD_PLATFORMS))))
release: $(addsuffix .tar.gz.sha256,$(addprefix build/$(EXECUTABLE)-$(VERSION)_,$(subst /,_,$(BUILD_PLATFORMS))))
Expand Down
25 changes: 11 additions & 14 deletions sshproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,24 +272,21 @@ func (s *Server) handleRequests(reqs <-chan *ssh.Request, channel ssh.Channel, c
go func() {
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()
Loop:
for {
select {
case <-time.After(10 * time.Second):
if _, err := channel.SendRequest("ping", false, []byte{}); err != nil {
// Channel is dead, kill process
if err := cmd.Process.Kill(); err != nil {
s.handleError(err, nil)
for {
select {
case <-time.After(10 * time.Second):
if _, err := channel.SendRequest("ping", false, []byte{}); err != nil {
// Channel is dead, attempt to kill process
cmd.Process.Kill()
break
}
break Loop
case <-done:
break
}
case <-done:
break Loop
}
}
}()

done <- cmd.Wait()
exitStatusPayload := make([]byte, 4)
exitStatus := uint32(1)
if cmd.ProcessState != nil {
Expand Down
62 changes: 62 additions & 0 deletions sshproxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package sshproxy_test

import (
"net"
"testing"
"time"

"github.com/balena-io/sshproxy"
"golang.org/x/crypto/ssh"
)

func TestRace(t *testing.T) {
server, err := sshproxy.New(
"/tmp",
"/bin/bash",
false,
nil,
3,
nil,
func(err error, tags map[string]string) {
t.Logf("uncaught error: %s", err)
})

if err != nil {
t.Fatalf("error calling sshproxy.New :( %s", err)
}

go func() {
if err := server.Listen("12345"); err != nil {
t.Fatalf("Cannot start server! %s", err)
}
}()

config := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
for i := 0; i < 10; i++ {
client, err := ssh.Dial("tcp", "localhost:12345", config)
if err != nil {
t.Errorf("Cannot connect to server :( %s", err)
}
session, err := client.NewSession()
if err != nil {
t.Errorf("Cannot create session :( %s", err)
}
time.Sleep(time.Second)
_, err = session.SendRequest("exec", false, []byte{0, 0, 0, 4, 't', 'e', 's', 't'})
if err != nil {
t.Errorf("Cannot send exec request :( %q", err)
}
time.Sleep(time.Duration(i*100) * time.Millisecond)
if err := client.Close(); err != nil {
t.Errorf("Error closing client - %s", err)
}
}
}

0 comments on commit 11da2b2

Please sign in to comment.