From 01ed97d1355213fd37ca50b843bd9edbac773dd1 Mon Sep 17 00:00:00 2001 From: masahide Date: Sun, 25 Aug 2024 13:17:54 +0000 Subject: [PATCH] add agent-bench --- .github/workflows/build.yml | 15 +- cmd/agent-bench/main.go | 165 ++++++++++++++++++ cmd/agent-bench/unix.go | 42 +++++ cmd/agent-bench/win.go | 44 +++++ cmd/wsl2-ssh-agent-proxy/main.go | 259 +++++++++++++++++++++------- cmd/wsl2-ssh-agent-proxy/pwsh.ps1 | 78 +++++---- hack/ubuntu.setup.sh | 57 +++--- hack/ubuntu.wsl2-ssh-agent-proxy.sh | 40 +++++ 8 files changed, 571 insertions(+), 129 deletions(-) create mode 100644 cmd/agent-bench/main.go create mode 100644 cmd/agent-bench/unix.go create mode 100644 cmd/agent-bench/win.go create mode 100644 hack/ubuntu.wsl2-ssh-agent-proxy.sh diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 680a12c..7d71e61 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,8 +37,11 @@ jobs: run: | wails build -nsis go build -o build/bin/omni-socat.exe ./cmd/omni-socat + go build -o build/bin/agent-bench.exe ./cmd/agent-bench powershell Compress-Archive -Path build/bin/omni-socat.exe -DestinationPath build/bin/omni-socat.zip rm build/bin/omni-socat.exe + powershell Compress-Archive -Path build/bin/agent-bench.exe -DestinationPath build/bin/agent-bench.zip + rm build/bin/agent-bench.exe powershell Compress-Archive -Path build/bin/OmniSSHAgent.exe -DestinationPath build/bin/OmniSSHAgent.zip rm build/bin/OmniSSHAgent.exe @@ -52,7 +55,7 @@ jobs: run: | echo "ls build/bin" >> $env:GITHUB_STEP_SUMMARY ls "build/bin">> $env:GITHUB_STEP_SUMMARY - build-wsl2-ssh-agent-proxy: + build-unix: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -63,7 +66,15 @@ jobs: - name: build run: | CGO_ENABLED=0 go build -o build/bin/wsl2-ssh-agent-proxy ./cmd/wsl2-ssh-agent-proxy + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -o build/bin/agent-bench-linux-arm64 ./cmd/agent-bench + CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o build/bin/agent-bench-linux-amd64 ./cmd/agent-bench + CGO_ENABLED=0 GOOS=darwin GOARCH=arm64 go build -o build/bin/agent-bench-mac-arm64 ./cmd/agent-bench + CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o build/bin/agent-bench-mac-amd64 ./cmd/agent-bench gzip build/bin/wsl2-ssh-agent-proxy + gzip build/bin/agent-bench-linux-arm64 + gzip build/bin/agent-bench-linux-amd64 + gzip build/bin/agent-bench-mac-arm64 + gzip build/bin/agent-bench-mac-amd64 - uses: actions/upload-artifact@v4 with: name: build-files-linux @@ -76,7 +87,7 @@ jobs: create-release: if: startsWith(github.ref, 'refs/tags/') runs-on: ubuntu-22.04 - needs: [build-exe, build-wsl2-ssh-agent-proxy] + needs: [build-exe, build-unix] steps: - uses: actions/checkout@v4 - name: Download All Artifacts diff --git a/cmd/agent-bench/main.go b/cmd/agent-bench/main.go new file mode 100644 index 0000000..7f9e3ff --- /dev/null +++ b/cmd/agent-bench/main.go @@ -0,0 +1,165 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "sort" + "sync" + "time" + + "github.com/kelseyhightower/envconfig" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +type Specification struct { + PERSISTENT bool `default:"false"` + CONCURRENCY int `default:"10"` + RUN_COUNT int `default:"100"` +} + +type sshAgent interface { + agent.Agent + Close() error +} +type Agent struct { + agent.ExtendedAgent + net.Conn +} + +type exAgent struct { + agent.Agent +} + +func (e *exAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) { + return nil, nil +} + +func (e *exAgent) Extension(string, []byte) ([]byte, error) { + return nil, nil +} + +func getKey() *agent.Key { + var key *agent.Key + + a, err := newAgent() + if err != nil { + log.Fatal(err) + } + keys, err := a.List() + if err != nil { + log.Fatalf("Failed to list keys: %v", err) + } + if len(keys) == 0 { + log.Fatalf("No keys found in SSH agent") + } + key = keys[0] + a.Close() + return key +} + +func main() { + s := Specification{} + err := envconfig.Process("", &s) + if err != nil { + log.Fatal(err) + } + flag.BoolVar(&s.PERSISTENT, "persistent", s.PERSISTENT, "persistent mode") + flag.IntVar(&s.CONCURRENCY, "c", s.CONCURRENCY, "Number of concurrency processing") + flag.IntVar(&s.RUN_COUNT, "n", s.RUN_COUNT, "run count") + flag.Parse() + taskCh := make(chan struct{}) + doneCh := make(chan []time.Duration, s.CONCURRENCY) + + var wg sync.WaitGroup + key := getKey() + fmt.Printf("The key used for measurement:%s\n", key.String()) + fmt.Printf("Start %d worker\n", s.CONCURRENCY) + for i := 0; i < s.CONCURRENCY; i++ { + wg.Add(1) + go func() { + defer wg.Done() + worker(key, taskCh, doneCh, s.PERSISTENT) + }() + } + + start := time.Now() + go func() { + for i := 0; i < s.RUN_COUNT; i++ { + taskCh <- struct{}{} + } + close(taskCh) + }() + + go func() { + wg.Wait() + close(doneCh) + }() + var allExecutionTimes []time.Duration + for times := range doneCh { + allExecutionTimes = append(allExecutionTimes, times...) + } + fmt.Printf("\ndone.\n") + totalTime := time.Duration(0) + var minTime, maxTime time.Duration + minTime = allExecutionTimes[0] + maxTime = allExecutionTimes[0] + + for _, t := range allExecutionTimes { + totalTime += t + if t < minTime { + minTime = t + } + if t > maxTime { + maxTime = t + } + } + + averageTime := totalTime / time.Duration(len(allExecutionTimes)) + + sort.Slice(allExecutionTimes, func(i, j int) bool { + return allExecutionTimes[i] < allExecutionTimes[j] + }) + p99Time := allExecutionTimes[int(float64(len(allExecutionTimes))*0.99)-1] + + fmt.Printf("Real Time: %v\n", time.Since(start)) + fmt.Printf("Total Executions: %d\n", s.RUN_COUNT) + fmt.Printf("Concurrency: %d\n", s.CONCURRENCY) + fmt.Printf("Persistent Mode: %v\n", s.PERSISTENT) + fmt.Printf("Total Time: %v\n", totalTime) + fmt.Printf("Average Execution Time: %v\n", averageTime) + fmt.Printf("Min Execution Time: %v\n", minTime) + fmt.Printf("Max Execution Time: %v\n", maxTime) + fmt.Printf("99th Percentile Execution Time: %v\n", p99Time) +} + +func worker(key *agent.Key, taskCh <-chan struct{}, doneCh chan<- []time.Duration, persistent bool) { + var executionTimes []time.Duration + var err error + var agentClient sshAgent + for range taskCh { + start := time.Now() + if agentClient == nil { + agentClient, err = newAgent() + if err != nil { + log.Fatal(err) + } + } + data := []byte("Benchmark data") + _, err := agentClient.Sign(key, data) + if !persistent { + agentClient.Close() + agentClient = nil + } + if err != nil { + log.Printf("Failed to sign data: %v", err) + continue + } + duration := time.Since(start) + executionTimes = append(executionTimes, duration) + fmt.Print(".") + } + doneCh <- executionTimes +} diff --git a/cmd/agent-bench/unix.go b/cmd/agent-bench/unix.go new file mode 100644 index 0000000..5f48044 --- /dev/null +++ b/cmd/agent-bench/unix.go @@ -0,0 +1,42 @@ +//go:build unix + +package main + +import ( + "log" + "net" + "os" + + "golang.org/x/crypto/ssh/agent" +) + +func listKeys() { + socketPath := os.Getenv("SSH_AUTH_SOCK") + conn, err := net.Dial("unix", socketPath) + if err != nil { + log.Fatal(err) + } + agentClient := agent.NewClient(conn) + list, err := agentClient.List() + if err != nil { + log.Fatal(err) + } + + for _, key := range list { + log.Println(key.String()) + } +} + +func NewUnixDomain() (sshAgent, error) { + socketPath := os.Getenv("SSH_AUTH_SOCK") + conn, err := net.Dial("unix", socketPath) + if err != nil { + log.Fatal(err) + } + a := agent.NewClient(conn) + return &Agent{ExtendedAgent: &exAgent{a}, Conn: conn}, nil +} + +func newAgent() (sshAgent, error) { + return NewUnixDomain() +} diff --git a/cmd/agent-bench/win.go b/cmd/agent-bench/win.go new file mode 100644 index 0000000..60078de --- /dev/null +++ b/cmd/agent-bench/win.go @@ -0,0 +1,44 @@ +//go:build windows + +package main + +import ( + "errors" + + "github.com/Microsoft/go-winio" + "github.com/davidmz/go-pageant" + "golang.org/x/crypto/ssh/agent" +) + +const ( + sshAgentPipe = `\\.\pipe\openssh-ssh-agent` +) + +func NewPageant() (sshAgent, error) { + ok := pageant.Available() + if !ok { + return nil, errors.New("pageant is not available") + } + p := pageant.New() + return &Agent{ExtendedAgent: &exAgent{p}}, nil +} + +func (a *Agent) Close() error { + if a.Conn != nil { + a.Conn.Close() + } + return nil +} + +func NewNamedPipe() (sshAgent, error) { + conn, err := winio.DialPipe(sshAgentPipe, nil) + if err != nil { + return nil, err + } + a := agent.NewClient(conn) + return &Agent{ExtendedAgent: &exAgent{a}, Conn: conn}, nil +} + +func newAgent() (sshAgent, error) { + return NewNamedPipe() +} diff --git a/cmd/wsl2-ssh-agent-proxy/main.go b/cmd/wsl2-ssh-agent-proxy/main.go index 0214811..c0fe5a9 100644 --- a/cmd/wsl2-ssh-agent-proxy/main.go +++ b/cmd/wsl2-ssh-agent-proxy/main.go @@ -3,23 +3,28 @@ package main import ( "bufio" "bytes" + "context" _ "embed" "encoding/binary" + "errors" + "flag" "io" "log" "net" "os" "os/exec" + "os/signal" "path/filepath" "strings" "sync" + "syscall" ) //go:embed pwsh.ps1 var pwshScript string +var debug bool const ( - DEBUG = false HeaderSize = 12 // 4 bytes for channel ID, 4 bytes for message length PacketTypeConnectSend = uint32(0) @@ -27,6 +32,12 @@ const ( PacketTypeClose = uint32(2) ) +type Packet struct { + PacketType uint32 + ChannelID uint32 + Payload []byte +} + type Multiplexer struct { writer io.Writer reader *bufio.Reader @@ -40,30 +51,31 @@ func NewMultiplexer(writer io.Writer, reader io.Reader) *Multiplexer { reader: bufio.NewReader(reader), channels: make(map[uint32]chan []byte), } - go mux.readLoop() + //go mux.readLoop() return mux } -type Packet struct { - PacketType uint32 - ChannelID uint32 - Payload []byte -} - -func (mux *Multiplexer) readLoop() { - if DEBUG { - for { - buf := make([]byte, 4096) - n, err := mux.reader.Read(buf) - log.Printf("debug read buf:[%s] n:%d,err=%v", buf, n, err) - } - } +func (mux *Multiplexer) readLoop(ctx context.Context) error { for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } header := make([]byte, HeaderSize) _, err := io.ReadFull(mux.reader, header) if err != nil { + if errors.Is(err, io.EOF) { + if debug { + log.Println("Connection closed") + } + return err + } log.Println("Error reading header:", err) - return + return err + } + if debug { + log.Printf("mux readFull header:[%v]", header) } packetType := binary.LittleEndian.Uint32(header[:4]) @@ -73,19 +85,27 @@ func (mux *Multiplexer) readLoop() { _, err = io.ReadFull(mux.reader, payload) if err != nil { log.Println("Error reading payload:", err) - return + return err + } + if debug { + log.Printf("mux readFull payload type:%d, ch:%d, len:%d ", packetType, channelID, length) } - //log.Printf("mux readFull payload type:%d, ch:%d, len:%d ", packetType, channelID, length) switch packetType { case PacketTypeSend: mux.channelsMu.Lock() - if ch, ok := mux.channels[channelID]; ok { + ch, ok := mux.channels[channelID] + mux.channelsMu.Unlock() + if ok { ch <- payload + } else { + log.Printf("mux readFull error: channel %d not found", channelID) } - mux.channelsMu.Unlock() case PacketTypeClose: mux.CloseChannel(channelID) + if debug { + log.Printf("mux readFull close channel %d", channelID) + } } } } @@ -121,40 +141,140 @@ func (mux *Multiplexer) CloseChannel(channelID uint32) { } } -func handleConnection(conn net.Conn, mux *Multiplexer, channelID uint32) { +func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, channelID uint32) { defer conn.Close() - ch := mux.OpenChannel(channelID) + ch := ps.OpenChannel(channelID) go func() { - //reader := bufio.NewReader(conn) + defer func() { + ps.WriteChannel(Packet{PacketType: PacketTypeClose, ChannelID: channelID, Payload: []byte{}}) + ps.CloseChannel(channelID) + }() packetType := PacketTypeConnectSend for { payload := make([]byte, 4096) n, err := conn.Read(payload) if err != nil { if err == io.EOF { + if debug { + log.Printf("DomainSocket.read ch:%d io.EOF", channelID) + } break } log.Println("Error reading from connection:", err) break } - //log.Printf("handleCoonection read: channelID:%d byte[%v]", channelID, payload[:n]) - mux.WriteChannel(Packet{PacketType: packetType, ChannelID: channelID, Payload: payload[:n]}) + ps.WriteChannel(Packet{PacketType: packetType, ChannelID: channelID, Payload: payload[:n]}) packetType = PacketTypeSend + select { + case <-ctx.Done(): + return + default: + } } - mux.WriteChannel(Packet{PacketType: PacketTypeClose, ChannelID: channelID, Payload: []byte{}}) - mux.CloseChannel(channelID) }() - writer := bufio.NewWriter(conn) + domainSocketWriter := bufio.NewWriter(conn) for msg := range ch { - _, err := writer.Write(msg) + _, err := domainSocketWriter.Write(msg) if err != nil { log.Println("Error writing to connection:", err) break } - writer.Flush() + domainSocketWriter.Flush() + if debug { + log.Printf("DomainSocketWriter.Write ch:%d len:%d", channelID, len(msg)) + } + } + if debug { + log.Printf("Close DomainSocket ch:%d", channelID) + } +} + +type pwshIOStream struct { + *Multiplexer + exePath string + + cmd *exec.Cmd + out io.ReadCloser + in io.WriteCloser + cancel context.CancelFunc +} + +func (ps *pwshIOStream) setCancel(cancel context.CancelFunc) { + ps.cancel = cancel +} + +func NewPwshIOStream(exePath string) *pwshIOStream { + return &pwshIOStream{ + exePath: exePath, + } +} + +/* +func (ps *pwshIOStream) readLoop() error { + return ps.Multiplexer.readLoop() +} +*/ + +func (ps *pwshIOStream) sigStopWorker() { + // Capture kill signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigChan + log.Printf("Received signal: %s. Shutting down PowerShell process...", sig) + ps.killPwsh() + os.Exit(0) +} + +func (ps *pwshIOStream) startPowerShellProces(ctx context.Context) { + log.Println("start PowerShell process...") + ps.cmd = exec.Command(ps.exePath, "-NoProfile", "-Command", "-") + + defer ps.cancel() + var err error + ps.in, err = ps.cmd.StdinPipe() + if err != nil { + log.Printf("cmd.StdinPipe() err:%s", err) + return + } + ps.out, err = ps.cmd.StdoutPipe() + if err != nil { + log.Printf("cmd.StdoutPipe() err:%s", err) + return + } + ps.cmd.Stderr = os.Stderr + + if err := ps.cmd.Start(); err != nil { + log.Printf("cmd.Start() err:%s", err) + return + } + log.Printf("Started PowerShell process with PID: %d", ps.cmd.Process.Pid) + + if debug { + pwshScript = uncommentWriteLines(pwshScript) + } + io.WriteString(ps.in, pwshScript) + ps.Multiplexer = NewMultiplexer(ps.in, ps.out) + checkStartAgent(ps.out) + if err = ps.readLoop(ctx); err != nil { + log.Printf("readLoop() err:%s", err) + return } } + +func (ps *pwshIOStream) killPwsh() { + if err := ps.cmd.Process.Signal(syscall.SIGTERM); err != nil { + log.Printf("send signal err:%s", err) + } + if ps.in != nil { + ps.in.Close() + } + if ps.out != nil { + ps.out.Close() + } + log.Printf("PowerShell process exited") +} + func checkStartAgent(r io.Reader) { buf := make([]byte, 1024) n, err := r.Read(buf) @@ -162,10 +282,6 @@ func checkStartAgent(r io.Reader) { log.Printf("read err: %s", err) return } - if DEBUG { - log.Printf("debug read: %s", string(buf[:n])) - return - } if !bytes.Equal(buf[:n], []byte("startAgent")) { log.Printf(" startAgent err:?:[%s]", string(buf[:n])) os.Exit(1) @@ -185,6 +301,9 @@ func getSystem32Path() string { } func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + flag.BoolVar(&debug, "debug", false, "debug mode") + flag.Parse() exePath := getSystem32Path() if len(exePath) > 0 { exePath = filepath.Join(exePath, "WindowsPowerShell/v1.0/powershell.exe") @@ -193,34 +312,6 @@ func main() { if err != nil { exePath = ("powershell.exe") } - cmd := exec.Command(exePath, "-NoProfile", "-Command", "-") - - psIn, err := cmd.StdinPipe() - if err != nil { - log.Println("Error creating stdin pipe:", err) - return - } - psOut, err := cmd.StdoutPipe() - if err != nil { - log.Println("Error creating stdout pipe:", err) - return - } - cmd.Stderr = os.Stderr - - if err := cmd.Start(); err != nil { - log.Println("Error starting command:", err) - return - } - - io.WriteString(psIn, pwshScript) - - log.SetFlags(log.LstdFlags | log.Lshortfile) - log.Printf("writeString powershellScript") - - checkStartAgent(psOut) - - mux := NewMultiplexer(psIn, psOut) - socketPath := os.Getenv("SSH_AUTH_SOCK") if len(socketPath) == 0 { log.Fatal("env SSH_AUTH_SOCK is not set") @@ -232,23 +323,59 @@ func main() { log.Fatal("remove old socket err:", err) } } - log.Printf("listen socket:%s", socketPath) + + ps := NewPwshIOStream(exePath) + defer ps.killPwsh() + go ps.sigStopWorker() listener, err := net.Listen("unix", socketPath) if err != nil { log.Println("Error creating Unix domain socket:", err) return } defer listener.Close() + for { + // Start PowerShell process + ctx, cancel := context.WithCancel(context.Background()) + ps.setCancel(cancel) + go ps.startPowerShellProces(ctx) + log.Printf("listen socket:%s", socketPath) + ps.listenLoop(ctx, listener) + cancel() + } +} +func (ps *pwshIOStream) listenLoop(ctx context.Context, listener net.Listener) error { var channelID uint32 = 1 for { conn, err := listener.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + log.Println("Listener has been closed.") + return err + } log.Println("Error accepting connection:", err) continue } - //log.Printf("accept: %v", conn.LocalAddr()) - go handleConnection(conn, mux, channelID) + if debug { + log.Printf("domainSocket:%v accept ch:%d", conn.LocalAddr(), channelID) + } + go ps.handleConnection(ctx, conn, channelID) channelID++ + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + } +} + +func uncommentWriteLines(script string) string { + lines := strings.Split(script, "\n") + for i, line := range lines { + trimmedLine := strings.TrimSpace(line) + if strings.HasPrefix(trimmedLine, "# [Console]::Error.WriteLine(") { + lines[i] = strings.Replace(line, "# ", "", 1) + } } + return strings.Join(lines, "\n") } diff --git a/cmd/wsl2-ssh-agent-proxy/pwsh.ps1 b/cmd/wsl2-ssh-agent-proxy/pwsh.ps1 index 2c2a1aa..9e515ac 100644 --- a/cmd/wsl2-ssh-agent-proxy/pwsh.ps1 +++ b/cmd/wsl2-ssh-agent-proxy/pwsh.ps1 @@ -8,20 +8,26 @@ $WritePacketWorker = { [System.IO.StreamWriter] $OutputStreamWriter ) - #[Console]::Error.WriteLine("WritePacketWorker started.") + [Console]::Error.WriteLine("WritePacketWorker started.") while ($true) { $null = $MainPacketQueueSignal.WaitOne() - #[Console]::Error.WriteLine("WritePacketWorker: Signal received, processing packet queue.") + # [Console]::Error.WriteLine("WritePacketWorker: Signal received, processing packet queue.") $Packet = $null - if ($PacketQueue.TryDequeue([ref]$Packet)) { - #[Console]::Error.WriteLine("WritePacketWorker: Packet dequeued. Length: $($Packet.Length), Channel ID: $($Packet.ChannelID), Type: $($Packet.Type).") + while ($PacketQueue.TryDequeue([ref]$Packet)) { + # [Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Packet dequeued. Length: $($Packet.Length)") $Header = [BitConverter]::GetBytes($Packet.Type) + [BitConverter]::GetBytes($Packet.ChannelID) + [BitConverter]::GetBytes($Packet.Payload.Length) - $OutputStreamWriter.BaseStream.Write($Header, 0, $Header.Length) - $OutputStreamWriter.BaseStream.Write($Packet.Payload, 0, $Packet.Payload.Length) - $OutputStreamWriter.Flush() - #[Console]::Error.WriteLine("WritePacketWorker: Packet written to output stream.") + try { + $OutputStreamWriter.BaseStream.Write($Header, 0, $Header.Length) + $OutputStreamWriter.BaseStream.Write($Packet.Payload, 0, $Packet.Payload.Length) + $OutputStreamWriter.Flush() + } + catch { + [Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Write error:[$error]") + continue + } + # [Console]::Error.WriteLine("WritePacketWorker [ch$($Packet.ChannelID),type:$($Packet.Type)]: Packet written to output stream.") } } } @@ -43,31 +49,31 @@ $PacketWorkerScript = { [void]SendResponse([hashtable]$Packet) { $null = $this.WorkerInstance.MainPacketQueue.Enqueue($Packet) $null = $this.WorkerInstance.MainPacketQueueSignal.Set() - #[Console]::Error.WriteLine("PacketWorker: Response sent for Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Response sent.") } [void]StopWorker([Int32]$ChannelID) { $this.SendResponse(@{ Type = 2; Payload = [byte[]]::new(0); ChannelID = $ChannelID }) $null = $this.WorkerInstance.WorkerQueue.Enqueue($this.WorkerInstance) - #[Console]::Error.WriteLine("PacketWorker: Worker stopped for Channel ID: $ChannelID.") + # [Console]::Error.WriteLine("PacketWorker [ch:$($ChannelID)]: Worker stopped.") } [void]Run() { - #[Console]::Error.WriteLine("PacketWorker started.") + # [Console]::Error.WriteLine("PacketWorker started.") while ($true) { $null = $this.WorkerInstance.PacketQueueSignal.WaitOne() $Packet = $null - if ($this.WorkerInstance.PacketQueue.TryDequeue([ref]$Packet)) { - #[Console]::Error.WriteLine("PacketWorker: Packet received. Channel ID: $($Packet.ChannelID).") + while ($this.WorkerInstance.PacketQueue.TryDequeue([ref]$Packet)) { + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Packet received.") try { if (!$this.ProcessPacket($Packet)) { - #[Console]::Error.WriteLine("PacketWorker: Processing failed for Channel ID: $($Packet.ChannelID). Worker will stop.") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Processing failed. Worker will stop.") $this.StopWorker($Packet.ChannelID) continue } } catch { - [Console]::Error.WriteLine("PacketWorker: Exception occurred while processing Channel ID: $($Packet.ChannelID). Error: $($_.Exception.Message). Worker will stop.") + [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Exception occurred while processing. Error: $($_.Exception.Message). Worker will stop.") $this.StopWorker($Packet.ChannelID) continue } @@ -76,10 +82,10 @@ $PacketWorkerScript = { } [bool]ProcessPacket([hashtable]$Packet) { - #[Console]::Error.WriteLine("PacketWorker: Processing packet. Type: $($Packet.TypeNum), Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Processing packet.") if (0 -eq $Packet.TypeNum) { if ($null -ne $this.NamedPipeStream) { - [Console]::Error.WriteLine("PacketWorker: Named pipe connection already closed. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection already closed.") $this.NamedPipeStream.Close() $this.NamedPipeStream = $null return $false @@ -87,28 +93,28 @@ $PacketWorkerScript = { $this.NamedPipeStream = [System.IO.Pipes.NamedPipeClientStream]::new(".", "openssh-ssh-agent", [System.IO.Pipes.PipeDirection]::InOut) $this.NamedPipeStream.Connect() $this.WorkerInstance.ChannelID = $Packet.ChannelID - #[Console]::Error.WriteLine("PacketWorker: Named pipe connection established. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection established.") } elseif (2 -eq $Packet.TypeNum) { if ($null -eq $this.NamedPipeStream) { - #[Console]::Error.WriteLine("PacketWorker: No active named pipe connection to close. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: No active named pipe connection to close.") return $false } $this.NamedPipeStream.Close() $this.NamedPipeStream = $null - #[Console]::Error.WriteLine("PacketWorker: Named pipe connection closed. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe connection closed.") return $false } $this.NamedPipeStream.Write($Packet.Payload, 0, $Packet.Payload.Length) $this.NamedPipeStream.Flush() - #[Console]::Error.WriteLine("PacketWorker: Data written to named pipe. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Data written to named pipe.") $Payload = [byte[]]::new(10240) $n = $this.NamedPipeStream.Read($Payload, 0, $Payload.Length) if ($n -gt 0) { $Payload = $Payload[0..($n - 1)] $this.SendResponse(@{ Type = 1; Payload = $Payload; ChannelID = $Packet.ChannelID }) - #[Console]::Error.WriteLine("PacketWorker: Response read from named pipe and sent. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Response read from named pipe and sent.") } return $true } @@ -139,38 +145,46 @@ class PacketReader { } [Hashtable] ReadPacket ([System.IO.Stream] $InputStreamReader) { - #[Console]::Error.WriteLine("PacketReader: Reading packet from input stream.") + # [Console]::Error.WriteLine("PacketReader: Reading packet from input stream.") $Header = [byte[]]::new(12) $n = $InputStreamReader.Read($Header, 0, $Header.Length) if ($n -eq 0) { return @{Error = "PacketReader: Failed to read header (length zero)." } } - #[Console]::Error.WriteLine("PacketReader: Header read successfully. Length: $n.") $Res = @{ TypeNum = [BitConverter]::ToInt32($Header, 0) ChannelID = [BitConverter]::ToInt32($Header, 4) Length = [BitConverter]::ToInt32($Header, 8) Error = $null } + # [Console]::Error.WriteLine("PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Header read successfully. Length: $n.") $Res.Payload = [byte[]]::new($Res.Length) $n = $InputStreamReader.Read($Res.Payload, 0, $Res.Length) if ($n -ne $Res.Length) { - $Res = @{Error = "PacketReader: Incomplete payload read. Expected: $($Res.Length), Actual: $n. Channel ID: $($Res.ChannelID), Type: $($Res.TypeNum)." } + $Res = @{Error = "PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Incomplete payload read. Expected: $($Res.Length), Actual: $n." } } - #[Console]::Error.WriteLine("PacketReader: Packet read completed. Channel ID: $($Res.ChannelID), Type: $($Res.TypeNum), Length: $($Res.Length).") + # [Console]::Error.WriteLine("PacketReader [ch:$($Res.ChannelID) type:$($Res.TypeNum)]: Packet read completed. Length: $($Res.Length).") return $Res } [void] Run() { [Console]::Error.WriteLine("PacketReader started.") while ($true) { - #[Console]::Error.WriteLine("PacketReader: Waiting for packets.") - $Packet = $this.ReadPacket($this.InputStreamReader) + # [Console]::Error.WriteLine("PacketReader: Waiting for packets.") + $Packet = $null + try { + $Packet = $this.ReadPacket($this.InputStreamReader) + } + catch { + [Console]::Error.WriteLine("InputStreamRead error: [$error]") + return + } if ($null -ne $Packet.Error) { [Console]::Error.WriteLine($Packet.Error) + Start-Sleep -Seconds 1.0 continue } - #[Console]::Error.WriteLine("PacketReader: Packet received. Type: $($Packet.TypeNum), Channel ID: $($Packet.ChannelID), Length: $($Packet.Length).") + # [Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Packet received. Length: $($Packet.Length).") $WorkerInstance = $null if ($this.Channels.ContainsKey($Packet.ChannelID)) { $WorkerInstance = $this.Channels[$Packet.ChannelID] @@ -179,7 +193,7 @@ class PacketReader { if ($this.WorkerQueue.TryDequeue([ref]$WorkerInstance)) { $this.Channels.Remove($WorkerInstance.ChannelID) $WorkerInstance.ChannelID = $Packet.ChannelID - #[Console]::Error.WriteLine("PacketReader: Reusing existing worker for Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Reusing existing worker.") } else { $WorkerInstance = @{ @@ -192,13 +206,13 @@ class PacketReader { } $null = [PowerShell]::Create().AddScript($this.PacketWorkerScript). AddArgument($WorkerInstance).BeginInvoke() - #[Console]::Error.WriteLine("PacketReader: New worker initialized for Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: New worker initialized.") } $this.Channels[$WorkerInstance.ChannelID] = $WorkerInstance } $WorkerInstance.PacketQueue.Enqueue($Packet) $WorkerInstance.PacketQueueSignal.Set() - #[Console]::Error.WriteLine("PacketReader: Packet dispatched to worker. Channel ID: $($Packet.ChannelID).") + # [Console]::Error.WriteLine("PacketReader [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Packet dispatched to worker.") } } } diff --git a/hack/ubuntu.setup.sh b/hack/ubuntu.setup.sh index 5f039a8..3650b7a 100644 --- a/hack/ubuntu.setup.sh +++ b/hack/ubuntu.setup.sh @@ -1,41 +1,40 @@ #!/bin/sh -set -ex -NAME=wsl2-ssh-agent-proxy -SSH_AUTH_SOCK="${HOME}/.ssh/${NAME}/${NAME}.sock" -PROXYCMD_DIR="${HOME}/${NAME}" -CMD="${PROXYCMD_DIR}/${NAME}" - -RELEASE_NAME=$1 -REPO_URL=https://github.com/masahide/OmniSSHAgent -if [ -z "${RELEASE_NAME}" ]; then - VER_PATH="releases/latest" -else - VER_PATH="download/${RELEASE_NAME}" -fi - -__get_proxy() { - echo "Downloading ${NAME}.gz" - mkdir -p "${PROXYCMD_DIR}" - curl "${REPO_URL}/releases/${VER_PATH}/${NAME}.gz" -sL | gunzip >"${CMD}" - chmod +x "${CMD}" +OMNISOCATCMD="$HOME/omni-socat/omni-socat.exe" +export SSH_AUTH_SOCK="$HOME/.ssh/agent.sock" + +__get_omnisocat() { + echo "Downloading omni-socat.exe" + sudo apt -y install unzip + curl https://github.com/masahide/OmniSSHAgent/releases/latest/download/omni-socat.zip \ + -sLo omni-socat.zip + unzip -o omni-socat.zip -d "$(dirname "$OMNISOCATCMD")" + chmod +x "$OMNISOCATCMD" + rm -f omni-socat.zip } -setup_proxy() { - [ -f "${CMD}" ] || __get_proxy +__get_socat() { + echo "Installing socat" + sudo apt -y install socat +} + +setup_omnisocat() { + [ -f "$OMNISOCATCMD" ] || __get_omnisocat + command -v socat > /dev/null 2>&1 || __get_socat # Checks wether $SSH_AUTH_SOCK is a socket or not - (ps aux | grep "${CMD}" | grep -qv "grep") && [ -S "${SSH_AUTH_SOCK}" ] && return + (ss -a | grep -q "$SSH_AUTH_SOCK") && [ -S "$SSH_AUTH_SOCK" ] && return # Create directory for the socket, if it is missing - SSH_AUTH_SOCK_DIR="$(dirname "${SSH_AUTH_SOCK}")" - mkdir -p "${SSH_AUTH_SOCK_DIR}" - # Applying best-practice permissions if we are creating ${HOME}/.ssh - if [ "${SSH_AUTH_SOCK_DIR}" = "${HOME}/.ssh" ]; then - chmod 700 "${SSH_AUTH_SOCK_DIR}" + SSH_AUTH_SOCK_DIR="$(dirname "$SSH_AUTH_SOCK")" + mkdir -p "$SSH_AUTH_SOCK_DIR" + # Applying best-practice permissions if we are creating $HOME/.ssh + if [ "$SSH_AUTH_SOCK_DIR" = "$HOME/.ssh" ]; then + chmod 700 "$SSH_AUTH_SOCK_DIR" fi - ${CMD} >>"${PROXYCMD_DIR}/${NAME}.log" 2>&1 & + rm -f "$SSH_AUTH_SOCK" + (setsid socat UNIX-LISTEN:"$SSH_AUTH_SOCK",fork EXEC:"$OMNISOCATCMD",nofork &) > /dev/null 2>&1 } -setup_proxy +setup_omnisocat diff --git a/hack/ubuntu.wsl2-ssh-agent-proxy.sh b/hack/ubuntu.wsl2-ssh-agent-proxy.sh new file mode 100644 index 0000000..5fb4286 --- /dev/null +++ b/hack/ubuntu.wsl2-ssh-agent-proxy.sh @@ -0,0 +1,40 @@ +#!/bin/sh + +NAME=wsl2-ssh-agent-proxy +export SSH_AUTH_SOCK="${HOME}/.ssh/${NAME}/${NAME}.sock" +PROXYCMD_DIR="${HOME}/${NAME}" +CMD="${PROXYCMD_DIR}/${NAME}" + +RELEASE_NAME=$1 +REPO_URL=https://github.com/masahide/OmniSSHAgent +if [ -z "${RELEASE_NAME}" ]; then + VER_PATH="releases/latest" +else + VER_PATH="download/${RELEASE_NAME}" +fi + +__get_proxy() { + echo "Downloading ${NAME}.gz" + mkdir -p "${PROXYCMD_DIR}" + curl "${REPO_URL}/releases/${VER_PATH}/${NAME}.gz" -sL | gunzip > "${CMD}" + chmod +x "${CMD}" +} + +setup_proxy() { + [ -f "${CMD}" ] || __get_proxy + + # Checks wether $SSH_AUTH_SOCK is a socket or not + (ps aux | grep "${CMD}" | grep -qv "grep") && [ -S "${SSH_AUTH_SOCK}" ] && return + + # Create directory for the socket, if it is missing + SSH_AUTH_SOCK_DIR="$(dirname "${SSH_AUTH_SOCK}")" + mkdir -p "${SSH_AUTH_SOCK_DIR}" + # Applying best-practice permissions if we are creating ${HOME}/.ssh + if [ "${SSH_AUTH_SOCK_DIR}" = "${HOME}/.ssh" ]; then + chmod 700 "${SSH_AUTH_SOCK_DIR}" + fi + + (setsid "${CMD}" >> "${PROXYCMD_DIR}/${NAME}.log" 2>&1 &) +} + +setup_proxy