diff --git a/pkg/process/process.go b/pkg/process/process.go index c4498563..547943c2 100644 --- a/pkg/process/process.go +++ b/pkg/process/process.go @@ -23,9 +23,13 @@ type Process interface { // Starts the process but does not wait for it to exit. Start(ctx context.Context) error + // Returns true if the process is started. + Started() bool // Aborts the process and waits for it to exit. Abort(ctx context.Context) error + // Returns true if the process is aborted. + Aborted() bool // Waits for the process to exit and returns the error, if any. // If the command completes successfully, the error will be nil. @@ -40,6 +44,9 @@ type Process interface { // // If the process exits with a non-zero exit code, stdout/stderr pipes may not work. // If retry configuration is specified, specify the output file to read all the output. + // + // The returned reader is set to nil upon the abort call on the process, + // to prevent redundant closing of the reader. StdoutReader() io.Reader // Returns the stderr reader. @@ -48,6 +55,9 @@ type Process interface { // // If the process exits with a non-zero exit code, stdout/stderr pipes may not work. // If retry configuration is specified, specify the output file to read all the output. + // + // The returned reader is set to nil upon the abort call on the process, + // to prevent redundant closing of the reader. StderrReader() io.Reader } @@ -72,6 +82,12 @@ type process struct { cmdMu sync.RWMutex cmd *exec.Cmd + startedMu sync.RWMutex + started bool + + abortedMu sync.RWMutex + aborted bool + // error streaming channel, closed on command exit errc chan error @@ -80,9 +96,9 @@ type process struct { envs []string runBashFile *os.File - outputFile *os.File - stdoutReader io.ReadCloser - stderrReader io.ReadCloser + outputFile *os.File + stdoutReadCloser io.ReadCloser + stderrReadCloser io.ReadCloser restartConfig *RestartConfig } @@ -136,9 +152,14 @@ func New(opts ...OpOption) (Process, error) { errcBuffer = op.restartConfig.Limit } return &process{ - labels: op.labels, - cmd: nil, - errc: make(chan error, errcBuffer), + labels: op.labels, + cmd: nil, + + started: false, + aborted: false, + + errc: make(chan error, errcBuffer), + commandArgs: cmdArgs, envs: op.envs, runBashFile: bashFile, @@ -157,6 +178,20 @@ func (p *process) Labels() map[string]string { } func (p *process) Start(ctx context.Context) error { + p.startedMu.RLock() + started := p.started + p.startedMu.RUnlock() + if started { // already started + return nil + } + + p.abortedMu.RLock() + aborted := p.aborted + p.abortedMu.RUnlock() + if aborted { // already aborted + return nil + } + p.cmdMu.Lock() defer p.cmdMu.Unlock() @@ -179,6 +214,13 @@ func (p *process) Start(ctx context.Context) error { return nil } +func (p *process) Started() bool { + p.startedMu.RLock() + defer p.startedMu.RUnlock() + + return p.started +} + func (p *process) startCommand() error { log.Logger.Debugw("starting command", "command", p.commandArgs) p.cmd = exec.CommandContext(p.ctx, p.commandArgs[0], p.commandArgs[1:]...) @@ -189,16 +231,16 @@ func (p *process) startCommand() error { p.cmd.Stdout = p.outputFile p.cmd.Stderr = p.outputFile - p.stdoutReader = p.outputFile - p.stderrReader = p.outputFile + p.stdoutReadCloser = p.outputFile + p.stderrReadCloser = p.outputFile default: var err error - p.stdoutReader, err = p.cmd.StdoutPipe() + p.stdoutReadCloser, err = p.cmd.StdoutPipe() if err != nil { return fmt.Errorf("failed to get stdout pipe: %w", err) } - p.stderrReader, err = p.cmd.StderrPipe() + p.stderrReadCloser, err = p.cmd.StderrPipe() if err != nil { return fmt.Errorf("failed to get stderr pipe: %w", err) } @@ -209,6 +251,10 @@ func (p *process) startCommand() error { } atomic.StoreInt32(&p.pid, int32(p.cmd.Process.Pid)) + p.startedMu.Lock() + p.started = true + p.startedMu.Unlock() + return nil } @@ -299,6 +345,20 @@ func (p *process) watchCmd() { } func (p *process) Abort(ctx context.Context) error { + p.startedMu.RLock() + started := p.started + p.startedMu.RUnlock() + if !started { // has not started yet + return nil + } + + p.abortedMu.RLock() + aborted := p.aborted + p.abortedMu.RUnlock() + if aborted { // already aborted + return nil + } + p.cmdMu.Lock() defer p.cmdMu.Unlock() @@ -331,16 +391,24 @@ func (p *process) Abort(ctx context.Context) error { if p.runBashFile != nil { _ = p.runBashFile.Sync() _ = p.runBashFile.Close() - return os.RemoveAll(p.runBashFile.Name()) + if err := os.RemoveAll(p.runBashFile.Name()); err != nil { + log.Logger.Warnw("failed to remove bash file", "error", err) + // Don't return here, continue with cleanup + } } - if p.stdoutReader != nil { - _ = p.stdoutReader.Close() - p.stdoutReader = nil + if p.stdoutReadCloser != nil { + _ = p.stdoutReadCloser.Close() + + // set to nil to prevent redundant closing of the reader + p.stdoutReadCloser = nil } - if p.stderrReader != nil { - _ = p.stderrReader.Close() - p.stderrReader = nil + + if p.stderrReadCloser != nil { + _ = p.stderrReadCloser.Close() + + // set to nil to prevent redundant closing of the reader + p.stderrReadCloser = nil } if p.cmd.Cancel != nil { // if created with CommandContext @@ -351,9 +419,20 @@ func (p *process) Abort(ctx context.Context) error { // as Wait is still waiting for the process to exit // p.cmd = nil + p.abortedMu.Lock() + p.aborted = true + p.abortedMu.Unlock() + return nil } +func (p *process) Aborted() bool { + p.abortedMu.RLock() + defer p.abortedMu.RUnlock() + + return p.aborted +} + func (p *process) PID() int32 { return atomic.LoadInt32(&p.pid) } @@ -365,7 +444,7 @@ func (p *process) StdoutReader() io.Reader { if p.outputFile != nil { return p.outputFile } - return p.stdoutReader + return p.stdoutReadCloser } func (p *process) StderrReader() io.Reader { @@ -375,7 +454,7 @@ func (p *process) StderrReader() io.Reader { if p.outputFile != nil { return p.outputFile } - return p.stderrReader + return p.stderrReadCloser } const bashScriptHeader = `#!/bin/bash diff --git a/pkg/process/process_test.go b/pkg/process/process_test.go index 3cfe1c75..bfe10155 100644 --- a/pkg/process/process_test.go +++ b/pkg/process/process_test.go @@ -27,6 +27,11 @@ func TestProcess(t *testing.T) { } t.Logf("pid: %d", p.PID()) + // redunant start is ok + if err := p.Start(ctx); err != nil { + t.Fatal(err) + } + if err := Read( ctx, p, @@ -44,6 +49,9 @@ func TestProcess(t *testing.T) { if err := p.Abort(ctx); err != nil { t.Fatal(err) } + if !p.Aborted() { + t.Fatal("process is not aborted") + } } func TestProcessRunBashScriptContents(t *testing.T) { @@ -96,6 +104,7 @@ echo "hello" if err := p.Abort(ctx); err != nil { t.Fatal(err) } + // redunant abort is ok if err := p.Abort(ctx); err != nil { t.Fatal(err) } diff --git a/pkg/process/utils.go b/pkg/process/utils.go index 7953fbdb..c146ac88 100644 --- a/pkg/process/utils.go +++ b/pkg/process/utils.go @@ -62,7 +62,19 @@ func WithWaitForCmd() ReadOpOption { } } +var ( + ErrProcessNotStarted = errors.New("process not started") + ErrProcessAborted = errors.New("process aborted") +) + func Read(ctx context.Context, p Process, opts ...ReadOpOption) error { + if !p.Started() { + return ErrProcessNotStarted + } + if p.Aborted() { + return ErrProcessAborted + } + op := &ReadOp{} if err := op.applyOpts(opts); err != nil { return err @@ -71,14 +83,27 @@ func Read(ctx context.Context, p Process, opts ...ReadOpOption) error { // combine stdout and stderr into a single reader readers := []io.Reader{} if op.readStdout { - readers = append(readers, p.StdoutReader()) + // may happen if the process is alread aborted + stdoutReader := p.StdoutReader() + if stdoutReader == nil { + return errors.New("stdout reader is nil") + } + readers = append(readers, stdoutReader) } if op.readStderr { - readers = append(readers, p.StderrReader()) + // may happen if the process is alread aborted + stderrReader := p.StderrReader() + if stderrReader == nil { + return errors.New("stderr reader is nil") + } + readers = append(readers, stderrReader) } combinedReader := io.MultiReader(readers...) scanner := bufio.NewScanner(combinedReader) + if scanner == nil { + return errors.New("scanner is nil") + } for scanner.Scan() { // helps with debugging if command times out in the middle of reading @@ -93,6 +118,10 @@ func Read(ctx context.Context, p Process, opts ...ReadOpOption) error { return ctx.Err() default: } + + if p.Aborted() { + return errors.New("process aborted") + } } if serr := scanner.Err(); serr != nil { diff --git a/pkg/process/utils_test.go b/pkg/process/utils_test.go index 2a113263..ee73f07c 100644 --- a/pkg/process/utils_test.go +++ b/pkg/process/utils_test.go @@ -33,6 +33,10 @@ func (p *testProcess) Start(context.Context) error { return nil } +func (p *testProcess) Started() bool { + return true +} + func (p *testProcess) StdoutReader() io.Reader { p.mu.Lock() defer p.mu.Unlock() @@ -59,6 +63,10 @@ func (p *testProcess) Abort(ctx context.Context) error { return nil } +func (p *testProcess) Aborted() bool { + return false +} + func newTestProcess(command string, args ...string) *testProcess { cmd := exec.Command(command, args...) stdout, _ := cmd.StdoutPipe() @@ -193,3 +201,155 @@ func TestReadAll(t *testing.T) { } }) } + +func TestNilReaders(t *testing.T) { + // Test nil stdout reader + t.Run("nil stdout reader", func(t *testing.T) { + p := &nilReaderProcess{returnNilStdout: true} + err := Read(context.Background(), p, WithReadStdout()) + if err == nil || err.Error() != "stdout reader is nil" { + t.Errorf("expected 'stdout reader is nil' error, got %v", err) + } + }) + + // Test nil stderr reader + t.Run("nil stderr reader", func(t *testing.T) { + p := &nilReaderProcess{returnNilStderr: true} + err := Read(context.Background(), p, WithReadStderr()) + if err == nil || err.Error() != "stderr reader is nil" { + t.Errorf("expected 'stderr reader is nil' error, got %v", err) + } + }) + + // Test both nil readers + t.Run("both nil readers", func(t *testing.T) { + p := &nilReaderProcess{returnNilStdout: true, returnNilStderr: true} + err := Read(context.Background(), p, WithReadStdout(), WithReadStderr()) + if err == nil || err.Error() != "stdout reader is nil" { + t.Errorf("expected 'stdout reader is nil' error, got %v", err) + } + }) +} + +// nilReaderProcess implements Process interface for testing nil reader cases +type nilReaderProcess struct { + returnNilStdout bool + returnNilStderr bool +} + +func (p *nilReaderProcess) Labels() map[string]string { + return nil +} + +func (p *nilReaderProcess) PID() int32 { + return 0 +} + +func (p *nilReaderProcess) Start(context.Context) error { + return nil +} + +func (p *nilReaderProcess) Started() bool { + return true +} + +func (p *nilReaderProcess) StdoutReader() io.Reader { + if p.returnNilStdout { + return nil + } + return strings.NewReader("") +} + +func (p *nilReaderProcess) StderrReader() io.Reader { + if p.returnNilStderr { + return nil + } + return strings.NewReader("") +} + +func (p *nilReaderProcess) Wait() <-chan error { + ch := make(chan error, 1) + close(ch) + return ch +} + +func (p *nilReaderProcess) Abort(context.Context) error { + return nil +} + +func (p *nilReaderProcess) Aborted() bool { + return false +} + +// stateProcess implements Process interface for testing process states +type stateProcess struct { + isStarted bool + isAborted bool +} + +func (p *stateProcess) Labels() map[string]string { + return nil +} + +func (p *stateProcess) PID() int32 { + return 0 +} + +func (p *stateProcess) Start(context.Context) error { + return nil +} + +func (p *stateProcess) Started() bool { + return p.isStarted +} + +func (p *stateProcess) StdoutReader() io.Reader { + return strings.NewReader("") +} + +func (p *stateProcess) StderrReader() io.Reader { + return strings.NewReader("") +} + +func (p *stateProcess) Wait() <-chan error { + ch := make(chan error, 1) + close(ch) + return ch +} + +func (p *stateProcess) Abort(context.Context) error { + return nil +} + +func (p *stateProcess) Aborted() bool { + return p.isAborted +} + +func TestProcessStates(t *testing.T) { + // Test not started process + t.Run("not started process", func(t *testing.T) { + p := &stateProcess{isStarted: false} + err := Read(context.Background(), p, WithReadStdout()) + if err != ErrProcessNotStarted { + t.Errorf("expected ErrProcessNotStarted, got %v", err) + } + }) + + // Test started process + t.Run("started process", func(t *testing.T) { + p := &stateProcess{isStarted: true} + err := Read(context.Background(), p, WithReadStdout()) + if err != nil { + t.Errorf("expected no error for started process, got %v", err) + } + }) + + // Test aborted process + t.Run("aborted process", func(t *testing.T) { + p := &stateProcess{isStarted: true, isAborted: true} + err := Read(context.Background(), p, WithReadStdout()) + if err != ErrProcessAborted { + t.Errorf("expected ErrProcessAborted, got %v", err) + } + }) +}