diff --git a/.github/workflows/elixir.yml b/.github/workflows/elixir.yml index 10b509e..3de2800 100644 --- a/.github/workflows/elixir.yml +++ b/.github/workflows/elixir.yml @@ -89,7 +89,7 @@ jobs: - run: mix test --trace macos: - runs-on: macos-11 + runs-on: macos-14 steps: - uses: actions/checkout@v4 - uses: DeterminateSystems/nix-installer-action@main diff --git a/flake.lock b/flake.lock index 0360468..8b6cba7 100644 --- a/flake.lock +++ b/flake.lock @@ -1,5 +1,23 @@ { "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, "nixpkgs": { "locked": { "lastModified": 1718835956, @@ -18,8 +36,24 @@ }, "root": { "inputs": { + "flake-utils": "flake-utils", "nixpkgs": "nixpkgs" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index fc0f7d1..9b64c19 100644 --- a/flake.nix +++ b/flake.nix @@ -3,22 +3,30 @@ inputs = { nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.05"; + flake-utils.url = "github:numtide/flake-utils"; }; - outputs = { self, nixpkgs, ... }: - let - supportedSystems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; - forAllSystems = f: nixpkgs.lib.genAttrs supportedSystems (system: f { + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + inherit (nixpkgs.lib) optional; pkgs = import nixpkgs { inherit system; }; + + sdk = with pkgs; + lib.optionals stdenv.isDarwin + (with darwin.apple_sdk.frameworks; [ + # needed for compilation + pkgs.libiconv + AppKit + Foundation + CoreFoundation + CoreServices + ]); + + in { + devShell = pkgs.mkShell { + buildInputs = + [ pkgs.elixir sdk ]; + }; }); - in - { - devShells = forAllSystems ({ pkgs }: { - default = pkgs.mkShell { - packages = with pkgs; [ - elixir - ]; - }; - }); - }; } diff --git a/go_src/executor.go b/go_src/executor.go index c8f2b88..fe44ea8 100644 --- a/go_src/executor.go +++ b/go_src/executor.go @@ -1,6 +1,9 @@ package main import ( + "encoding/binary" + "errors" + "io" "os" "os/exec" "os/signal" @@ -8,37 +11,53 @@ import ( "time" ) -func execute(workdir string, args []string) error { - done := make(chan struct{}) +const SendInput = 1 +const SendOutput = 2 +const Output = 3 +const Input = 4 +const CloseInput = 5 +const OutputEOF = 6 +const CommandEnv = 7 +const Pid = 8 +const StartError = 9 - sigs := make(chan os.Signal, 1) - input := make(chan Packet, 1) - outputDemand := make(chan Packet) - inputDemand := make(chan Packet) +// This size is *NOT* related to pipe buffer size +// 4 bytes for payload length + 1 byte for tag +const BufferSize = (1 << 16) - 5 - proc := exec.Command(args[0], args[1:]...) - proc.Dir = workdir - proc.Env = append(os.Environ(), readEnvFromStdin()...) +// fixed buffer for IO +var buf = make([]byte, BufferSize+5) - logger.Printf("Command path: %v\n", proc.Path) +type Packet struct { + tag uint8 + data []byte +} + +type InputDispatcher func(Packet) - output := startCommandPipeline(proc, input, inputDemand, outputDemand) +type OutPacket func() (Packet, bool) + +func execute(workdir string, args []string) error { + writerDone := make(chan struct{}) + // must be buffered so that function can close without blocking + stdinClose := make(chan struct{}, 1) + sigs := make(chan os.Signal) // Capture common signals. // Setting notify for SIGPIPE is important to capture and without that // we won't be able to handle abrupt beam vm terminations - // Also, SIGPIPE behaviour in golang is bit complex, + // Also, SIGPIPE behavior in golang is complex, + // // see: https://pkg.go.dev/os/signal@go1.22.4#hdr-SIGPIPE signal.Notify(sigs, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGPIPE) - // go handleSignals(input, outputDemand, done) - go dispatchStdin(input, outputDemand, done) - go collectStdout(proc.Process.Pid, output, inputDemand, sigs, done) + proc := exec.Command(args[0], args[1:]...) + proc.Dir = workdir + proc.Env = append(os.Environ(), readEnvFromStdin()...) - // wait for pipline to exit - <-done + cmdExit := runPipeline(proc, writerDone, stdinClose) + err := waitPipelineTermination(proc, cmdExit, sigs, stdinClose, writerDone) - err := safeExit(proc) if e, ok := err.(*exec.Error); ok { // This shouldn't really happen in practice because we check for // program existence in Elixir, before launching odu @@ -50,60 +69,246 @@ func execute(workdir string, args []string) error { return err } -func dispatchStdin(input chan<- Packet, outputDemand chan<- Packet, done chan struct{}) { +func runPipeline(proc *exec.Cmd, writerDone chan struct{}, stdinClose chan struct{}) chan error { + cmdInput := make(chan Packet, 1) + cmdOutputDemand := make(chan Packet) + cmdInputDemand := make(chan Packet) + cmdExit := make(chan error) + + cmdOutput := startCommandPipeline(proc, cmdInput, cmdInputDemand, cmdOutputDemand) + + // go handleSignals(input, outputDemand, done) + go stdinReader(cmdInput, cmdOutputDemand, writerDone, stdinClose) + go stdoutWriter(proc.Process.Pid, cmdOutput, cmdInputDemand, writerDone) + + go func() { + cmdExit <- proc.Wait() + }() + + return cmdExit +} + +func stdinReader(cmdInput chan<- Packet, cmdOutputDemand chan<- Packet, writerDone <-chan struct{}, stdinClose chan<- struct{}) { // closeChan := closeInputHandler(input) - var dispatch = func(packet Packet) { + defer func() { + close(cmdInput) + close(cmdOutputDemand) + }() + + for { + select { + case <-writerDone: + return + default: + } + + packet, readErr := readPacketFromStdin() + if readErr == io.EOF { + close(stdinClose) + return + } else if readErr != nil { + fatal(readErr) + } + switch packet.tag { case SendOutput: - outputDemand <- packet + cmdOutputDemand <- packet default: - input <- packet + cmdInput <- packet } } - - defer func() { - close(input) - close(outputDemand) - }() - - stdinReader(dispatch, done) } -func collectStdout(pid int, output <-chan Packet, inputDemand <-chan Packet, sigs <-chan os.Signal, done chan struct{}) { +func stdoutWriter(pid int, cmdStdout <-chan Packet, cmdInputDemand <-chan Packet, writerDone chan<- struct{}) { + var ok bool + var packet Packet + var buf [4]byte + defer func() { - close(done) + close(writerDone) }() - merged := func() (Packet, bool) { + // we first write pid before writing anything + writeUint32Be(buf[:], uint32(pid)) + writePacketToStdout(Pid, buf[:]) + + for { select { - case sig := <-sigs: - logger.Printf("Received OS Signal: ", sig) - return Packet{}, false + case packet, ok = <-cmdInputDemand: + case packet, ok = <-cmdStdout: + } - case v, ok := <-inputDemand: - return v, ok + if !ok { + return + } - case v, ok := <-output: - return v, ok + if len(packet.data) > BufferSize { + fatal("Invalid payloadLen") } + + writePacketToStdout(packet.tag, packet.data) + } +} + +func waitPipelineTermination(proc *exec.Cmd, cmdExit <-chan error, sigs <-chan os.Signal, stdinClose <-chan struct{}, writerDone <-chan struct{}) error { + var err error + + select { + case e := <-cmdExit: + // a program might exit without any IO + err = e + + case sig := <-sigs: + logger.Printf("Received OS Signal: %v\n", sig) + err = safeExit(proc, cmdExit) + + case <-stdinClose: + // When stdin closes it imply that VM is down or process GenServer + // is killed so we must prepare for termination. + // + // Not that stdin close for middleware is different from stdin close + // for the external program (CloseInput) + err = safeExit(proc, cmdExit) + + case <-writerDone: + err = safeExit(proc, cmdExit) } - stdoutWriter(pid, merged, done) + return err } -func safeExit(proc *exec.Cmd) error { - done := make(chan error, 1) - go func() { - done <- proc.Wait() - }() +func safeExit(proc *exec.Cmd, procErr <-chan error) error { + logger.Printf("safe exit\n") + select { + case err := <-procErr: + return err case <-time.After(3 * time.Second): if err := proc.Process.Kill(); err != nil { logger.Fatal("failed to kill process: ", err) } logger.Println("process killed as timeout reached") return nil - case err := <-done: - return err } } + +func writeStartError(reason string) { + writePacketToStdout(StartError, []byte(reason)) +} + +func readEnvFromStdin() []string { + // first packet must be env + packet, err := readPacketFromStdin() + if err != nil { + fatal(err) + } + + if packet.tag != CommandEnv { + fatal("First packet must be command Env") + } + + var env []string + var length int + data := packet.data + + for i := 0; i < len(data); { + length = int(binary.BigEndian.Uint16(data[i : i+2])) + i += 2 + + entry := string(data[i : i+length]) + env = append(env, entry) + + i += length + } + + logger.Printf("Command Env: %v\n", env) + + return env +} + +func readPacketFromStdin() (Packet, error) { + var readErr error + var length uint32 + var tag uint8 + + buf := make([]byte, BufferSize) + + length, readErr = readUint32(os.Stdin) + if readErr == io.EOF { + return Packet{}, io.EOF + } else if readErr != nil { + return Packet{}, readErr + } + + dataLen := length - 1 + if dataLen < 0 || dataLen > BufferSize { // payload must be atleast tag size + return Packet{}, errors.New("input payload size is invalid") + } + + tag, readErr = readUint8(os.Stdin) + if readErr != nil { + return Packet{}, readErr + } + + _, readErr = io.ReadFull(os.Stdin, buf[:dataLen]) + if readErr != nil { + return Packet{}, readErr + } + + return Packet{tag, buf[:dataLen]}, nil +} + +func writePacketToStdout(tag uint8, data []byte) { + payloadLen := len(data) + 1 + + writeUint32Be(buf[:4], uint32(payloadLen)) + writeUint8Be(buf[4:5], tag) + copy(buf[5:], data) + + _, writeErr := os.Stdout.Write(buf[:payloadLen+4]) + if writeErr != nil { + switch writeErr.(type) { + // ignore broken pipe or closed pipe errors here. + // currently readCommandStdout closes output chan, making the + // flow break. + case *os.PathError: + logger.Printf("os.PathError: %v\n", writeErr) + return + default: + fatal(writeErr) + } + } + // logger.Printf("stdout written bytes: %v\n", bytesWritten) +} + +func readUint32(stdin io.Reader) (uint32, error) { + var buf [4]byte + + bytesRead, readErr := io.ReadFull(stdin, buf[:]) + if readErr != nil { + return 0, io.EOF + } else if bytesRead == 0 { + return 0, readErr + } + return binary.BigEndian.Uint32(buf[:]), nil +} + +func readUint8(stdin io.Reader) (uint8, error) { + var buf [1]byte + + bytesRead, readErr := io.ReadFull(stdin, buf[:]) + if readErr != nil { + return 0, io.EOF + } else if bytesRead == 0 { + return 0, readErr + } + return uint8(buf[0]), nil +} + +func writeUint32Be(data []byte, num uint32) { + binary.BigEndian.PutUint32(data, num) +} + +func writeUint8Be(data []byte, num uint8) { + data[0] = byte(num) +} diff --git a/go_src/exit_status.go b/go_src/exit_status.go index a774233..3206fa2 100644 --- a/go_src/exit_status.go +++ b/go_src/exit_status.go @@ -1,3 +1,4 @@ +//go:build !plan9 // +build !plan9 package main diff --git a/go_src/process.go b/go_src/process.go index 5720bc5..0939b88 100644 --- a/go_src/process.go +++ b/go_src/process.go @@ -6,7 +6,9 @@ import ( "os/exec" ) -func startCommandPipeline(proc *exec.Cmd, input <-chan Packet, inputDemand chan<- Packet, outputDemand <-chan Packet) <-chan Packet { +func startCommandPipeline(proc *exec.Cmd, input <-chan Packet, inputDemand chan<- Packet, outputDemand <-chan Packet) chan Packet { + logger.Printf("Command: %v\n", proc.String()) + cmdInput, err := proc.StdinPipe() fatalIf(err) diff --git a/go_src/protocol.go b/go_src/protocol.go deleted file mode 100644 index 7dc60e0..0000000 --- a/go_src/protocol.go +++ /dev/null @@ -1,196 +0,0 @@ -package main - -import ( - "encoding/binary" - "errors" - "io" - "os" -) - -const SendInput = 1 -const SendOutput = 2 -const Output = 3 -const Input = 4 -const CloseInput = 5 -const OutputEOF = 6 -const CommandEnv = 7 -const Pid = 8 -const StartError = 9 - -// This size is *NOT* related to pipe buffer size -// 4 bytes for payload length + 1 byte for tag -const BufferSize = (1 << 16) - 5 - -type Packet struct { - tag uint8 - data []byte -} - -type InputDispatcher func(Packet) - -func stdinReader(dispatch InputDispatcher, done <-chan struct{}) { - for { - select { - case <-done: - return - default: - } - - packet, readErr := readPacket() - if readErr == io.EOF { - return - } - fatalIf(readErr) - - dispatch(packet) - } -} - -type OutPacket func() (Packet, bool) - -func stdoutWriter(pid int, fn OutPacket, done <-chan struct{}) { - var ok bool - var packet Packet - - var buf [4]byte - - // we first write pid before writing anything - writeUint32Be(buf[:], uint32(pid)) - writePacket(Pid, buf[:]) - - for { - packet, ok = fn() - if !ok { - return - } - - if len(packet.data) > BufferSize { - fatal("Invalid payloadLen") - } - - writePacket(packet.tag, packet.data) - } -} - -func writeStartError(reason string) { - writePacket(StartError, []byte(reason)) -} - -func readEnvFromStdin() []string { - // first packet must be env - packet, err := readPacket() - if err != nil { - fatal(err) - } - - if packet.tag != CommandEnv { - fatal("First packet must be command Env") - } - - var env []string - var length int - data := packet.data - - for i := 0; i < len(data); { - length = int(binary.BigEndian.Uint16(data[i : i+2])) - i += 2 - - entry := string(data[i : i+length]) - env = append(env, entry) - - i += length - } - - logger.Printf("Command Env: %v\n", env) - - return env -} - -func readPacket() (Packet, error) { - var readErr error - var length uint32 - var tag uint8 - - buf := make([]byte, BufferSize) - - length, readErr = readUint32(os.Stdin) - if readErr == io.EOF { - return Packet{}, io.EOF - } else if readErr != nil { - return Packet{}, readErr - } - - dataLen := length - 1 - if dataLen < 0 || dataLen > BufferSize { // payload must be atleast tag size - return Packet{}, errors.New("input payload size is invalid") - } - - tag, readErr = readUint8(os.Stdin) - if readErr != nil { - return Packet{}, readErr - } - - _, readErr = io.ReadFull(os.Stdin, buf[:dataLen]) - if readErr != nil { - return Packet{}, readErr - } - - return Packet{tag, buf[:dataLen]}, nil -} - -var buf = make([]byte, BufferSize+5) - -func writePacket(tag uint8, data []byte) { - payloadLen := len(data) + 1 - - writeUint32Be(buf[:4], uint32(payloadLen)) - writeUint8Be(buf[4:5], tag) - copy(buf[5:], data) - - _, writeErr := os.Stdout.Write(buf[:payloadLen+4]) - if writeErr != nil { - switch writeErr.(type) { - // ignore broken pipe or closed pipe errors here. - // currently readCommandStdout closes output chan, making the - // flow break. - case *os.PathError: - logger.Printf("os.PathError: ", writeErr) - return - default: - fatal(writeErr) - } - } - // logger.Printf("stdout written bytes: %v\n", bytesWritten) -} - -func readUint32(stdin io.Reader) (uint32, error) { - var buf [4]byte - - bytesRead, readErr := io.ReadFull(stdin, buf[:]) - if readErr != nil { - return 0, io.EOF - } else if bytesRead == 0 { - return 0, readErr - } - return binary.BigEndian.Uint32(buf[:]), nil -} - -func readUint8(stdin io.Reader) (uint8, error) { - var buf [1]byte - - bytesRead, readErr := io.ReadFull(stdin, buf[:]) - if readErr != nil { - return 0, io.EOF - } else if bytesRead == 0 { - return 0, readErr - } - return uint8(buf[0]), nil -} - -func writeUint32Be(data []byte, num uint32) { - binary.BigEndian.PutUint32(data, num) -} - -func writeUint8Be(data []byte, num uint8) { - data[0] = byte(num) -} diff --git a/mix.exs b/mix.exs index c8f7376..21d367e 100644 --- a/mix.exs +++ b/mix.exs @@ -12,6 +12,7 @@ defmodule ExCmd.MixProject do start_permanent: Mix.env() == :prod, deps: deps(), compilers: Mix.compilers() ++ [:odu], + aliases: aliases(), # Ensure dialyzer sees mix modules dialyzer: [plt_add_apps: [:mix]], @@ -63,4 +64,13 @@ defmodule ExCmd.MixProject do {:ex_doc, ">= 0.0.0", only: :dev} ] end + + defp aliases do + [ + format: [ + "format", + "cmd --cd go_src/ go fmt" + ] + ] + end end diff --git a/test/ex_cmd/process_test.exs b/test/ex_cmd/process_test.exs index cad4764..69c3625 100644 --- a/test/ex_cmd/process_test.exs +++ b/test/ex_cmd/process_test.exs @@ -34,6 +34,14 @@ defmodule ExCmd.ProcessTest do assert {:done, 0} == Process.status(s) end + test "await_exit without read" do + {:ok, s} = Process.start_link(~w(cat)) + assert :ok == Process.write(s, "hello") + assert :ok == Process.close_stdin(s) + :timer.sleep(50) + assert {:ok, 0} = Process.await_exit(s) + end + test "stdin close" do logger = start_events_collector() diff --git a/test/ex_cmd_exit_test.exs b/test/ex_cmd_exit_test.exs index 53fd62e..ed367cc 100644 --- a/test/ex_cmd_exit_test.exs +++ b/test/ex_cmd_exit_test.exs @@ -4,7 +4,7 @@ defmodule ExCmdExitTest do # currently running `elixir` command is not working in Windows @tag os: :unix test "if it kills external command on abnormal vm exit" do - ex_cmd_expr = ~S{ExCmd.stream!(["cat"]) |> Stream.run()} + ex_cmd_expr = ~S{ExCmd.stream!(["sleep", "10"]) |> Stream.run()} port = Port.open( @@ -29,7 +29,7 @@ defmodule ExCmdExitTest do assert {:ok, _msg} = os_process_kill(os_pid) # wait for the cleanup - :timer.sleep(5000) + :timer.sleep(3000) refute os_process_alive?(os_pid) refute os_process_alive?(cmd_pid)