Skip to content

Commit

Permalink
fix(pkg/process): gracefully handle read operations on aborted proces…
Browse files Browse the repository at this point in the history
…s, Read to return error if not started (#276)

Signed-off-by: Gyuho Lee <[email protected]>
  • Loading branch information
gyuho authored Dec 30, 2024
1 parent fa9c0e5 commit 687ff0a
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 21 deletions.
117 changes: 98 additions & 19 deletions pkg/process/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ type Process interface {

// Starts the process but does not wait for it to exit.
Start(ctx context.Context) error
// Returns true if the process is started.
Started() bool

// Aborts the process and waits for it to exit.
Abort(ctx context.Context) error
// Returns true if the process is aborted.
Aborted() bool

// Waits for the process to exit and returns the error, if any.
// If the command completes successfully, the error will be nil.
Expand All @@ -40,6 +44,9 @@ type Process interface {
//
// If the process exits with a non-zero exit code, stdout/stderr pipes may not work.
// If retry configuration is specified, specify the output file to read all the output.
//
// The returned reader is set to nil upon the abort call on the process,
// to prevent redundant closing of the reader.
StdoutReader() io.Reader

// Returns the stderr reader.
Expand All @@ -48,6 +55,9 @@ type Process interface {
//
// If the process exits with a non-zero exit code, stdout/stderr pipes may not work.
// If retry configuration is specified, specify the output file to read all the output.
//
// The returned reader is set to nil upon the abort call on the process,
// to prevent redundant closing of the reader.
StderrReader() io.Reader
}

Expand All @@ -72,6 +82,12 @@ type process struct {
cmdMu sync.RWMutex
cmd *exec.Cmd

startedMu sync.RWMutex
started bool

abortedMu sync.RWMutex
aborted bool

// error streaming channel, closed on command exit
errc chan error

Expand All @@ -80,9 +96,9 @@ type process struct {
envs []string
runBashFile *os.File

outputFile *os.File
stdoutReader io.ReadCloser
stderrReader io.ReadCloser
outputFile *os.File
stdoutReadCloser io.ReadCloser
stderrReadCloser io.ReadCloser

restartConfig *RestartConfig
}
Expand Down Expand Up @@ -136,9 +152,14 @@ func New(opts ...OpOption) (Process, error) {
errcBuffer = op.restartConfig.Limit
}
return &process{
labels: op.labels,
cmd: nil,
errc: make(chan error, errcBuffer),
labels: op.labels,
cmd: nil,

started: false,
aborted: false,

errc: make(chan error, errcBuffer),

commandArgs: cmdArgs,
envs: op.envs,
runBashFile: bashFile,
Expand All @@ -157,6 +178,20 @@ func (p *process) Labels() map[string]string {
}

func (p *process) Start(ctx context.Context) error {
p.startedMu.RLock()
started := p.started
p.startedMu.RUnlock()
if started { // already started
return nil
}

p.abortedMu.RLock()
aborted := p.aborted
p.abortedMu.RUnlock()
if aborted { // already aborted
return nil
}

p.cmdMu.Lock()
defer p.cmdMu.Unlock()

Expand All @@ -179,6 +214,13 @@ func (p *process) Start(ctx context.Context) error {
return nil
}

func (p *process) Started() bool {
p.startedMu.RLock()
defer p.startedMu.RUnlock()

return p.started
}

func (p *process) startCommand() error {
log.Logger.Debugw("starting command", "command", p.commandArgs)
p.cmd = exec.CommandContext(p.ctx, p.commandArgs[0], p.commandArgs[1:]...)
Expand All @@ -189,16 +231,16 @@ func (p *process) startCommand() error {
p.cmd.Stdout = p.outputFile
p.cmd.Stderr = p.outputFile

p.stdoutReader = p.outputFile
p.stderrReader = p.outputFile
p.stdoutReadCloser = p.outputFile
p.stderrReadCloser = p.outputFile

default:
var err error
p.stdoutReader, err = p.cmd.StdoutPipe()
p.stdoutReadCloser, err = p.cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("failed to get stdout pipe: %w", err)
}
p.stderrReader, err = p.cmd.StderrPipe()
p.stderrReadCloser, err = p.cmd.StderrPipe()
if err != nil {
return fmt.Errorf("failed to get stderr pipe: %w", err)
}
Expand All @@ -209,6 +251,10 @@ func (p *process) startCommand() error {
}
atomic.StoreInt32(&p.pid, int32(p.cmd.Process.Pid))

p.startedMu.Lock()
p.started = true
p.startedMu.Unlock()

return nil
}

Expand Down Expand Up @@ -299,6 +345,20 @@ func (p *process) watchCmd() {
}

func (p *process) Abort(ctx context.Context) error {
p.startedMu.RLock()
started := p.started
p.startedMu.RUnlock()
if !started { // has not started yet
return nil
}

p.abortedMu.RLock()
aborted := p.aborted
p.abortedMu.RUnlock()
if aborted { // already aborted
return nil
}

p.cmdMu.Lock()
defer p.cmdMu.Unlock()

Expand Down Expand Up @@ -331,16 +391,24 @@ func (p *process) Abort(ctx context.Context) error {
if p.runBashFile != nil {
_ = p.runBashFile.Sync()
_ = p.runBashFile.Close()
return os.RemoveAll(p.runBashFile.Name())
if err := os.RemoveAll(p.runBashFile.Name()); err != nil {
log.Logger.Warnw("failed to remove bash file", "error", err)
// Don't return here, continue with cleanup
}
}

if p.stdoutReader != nil {
_ = p.stdoutReader.Close()
p.stdoutReader = nil
if p.stdoutReadCloser != nil {
_ = p.stdoutReadCloser.Close()

// set to nil to prevent redundant closing of the reader
p.stdoutReadCloser = nil
}
if p.stderrReader != nil {
_ = p.stderrReader.Close()
p.stderrReader = nil

if p.stderrReadCloser != nil {
_ = p.stderrReadCloser.Close()

// set to nil to prevent redundant closing of the reader
p.stderrReadCloser = nil
}

if p.cmd.Cancel != nil { // if created with CommandContext
Expand All @@ -351,9 +419,20 @@ func (p *process) Abort(ctx context.Context) error {
// as Wait is still waiting for the process to exit
// p.cmd = nil

p.abortedMu.Lock()
p.aborted = true
p.abortedMu.Unlock()

return nil
}

func (p *process) Aborted() bool {
p.abortedMu.RLock()
defer p.abortedMu.RUnlock()

return p.aborted
}

func (p *process) PID() int32 {
return atomic.LoadInt32(&p.pid)
}
Expand All @@ -365,7 +444,7 @@ func (p *process) StdoutReader() io.Reader {
if p.outputFile != nil {
return p.outputFile
}
return p.stdoutReader
return p.stdoutReadCloser
}

func (p *process) StderrReader() io.Reader {
Expand All @@ -375,7 +454,7 @@ func (p *process) StderrReader() io.Reader {
if p.outputFile != nil {
return p.outputFile
}
return p.stderrReader
return p.stderrReadCloser
}

const bashScriptHeader = `#!/bin/bash
Expand Down
9 changes: 9 additions & 0 deletions pkg/process/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ func TestProcess(t *testing.T) {
}
t.Logf("pid: %d", p.PID())

// redunant start is ok
if err := p.Start(ctx); err != nil {
t.Fatal(err)
}

if err := Read(
ctx,
p,
Expand All @@ -44,6 +49,9 @@ func TestProcess(t *testing.T) {
if err := p.Abort(ctx); err != nil {
t.Fatal(err)
}
if !p.Aborted() {
t.Fatal("process is not aborted")
}
}

func TestProcessRunBashScriptContents(t *testing.T) {
Expand Down Expand Up @@ -96,6 +104,7 @@ echo "hello"
if err := p.Abort(ctx); err != nil {
t.Fatal(err)
}
// redunant abort is ok
if err := p.Abort(ctx); err != nil {
t.Fatal(err)
}
Expand Down
33 changes: 31 additions & 2 deletions pkg/process/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,19 @@ func WithWaitForCmd() ReadOpOption {
}
}

var (
ErrProcessNotStarted = errors.New("process not started")
ErrProcessAborted = errors.New("process aborted")
)

func Read(ctx context.Context, p Process, opts ...ReadOpOption) error {
if !p.Started() {
return ErrProcessNotStarted
}
if p.Aborted() {
return ErrProcessAborted
}

op := &ReadOp{}
if err := op.applyOpts(opts); err != nil {
return err
Expand All @@ -71,14 +83,27 @@ func Read(ctx context.Context, p Process, opts ...ReadOpOption) error {
// combine stdout and stderr into a single reader
readers := []io.Reader{}
if op.readStdout {
readers = append(readers, p.StdoutReader())
// may happen if the process is alread aborted
stdoutReader := p.StdoutReader()
if stdoutReader == nil {
return errors.New("stdout reader is nil")
}
readers = append(readers, stdoutReader)
}
if op.readStderr {
readers = append(readers, p.StderrReader())
// may happen if the process is alread aborted
stderrReader := p.StderrReader()
if stderrReader == nil {
return errors.New("stderr reader is nil")
}
readers = append(readers, stderrReader)
}

combinedReader := io.MultiReader(readers...)
scanner := bufio.NewScanner(combinedReader)
if scanner == nil {
return errors.New("scanner is nil")
}

for scanner.Scan() {
// helps with debugging if command times out in the middle of reading
Expand All @@ -93,6 +118,10 @@ func Read(ctx context.Context, p Process, opts ...ReadOpOption) error {
return ctx.Err()
default:
}

if p.Aborted() {
return errors.New("process aborted")
}
}

if serr := scanner.Err(); serr != nil {
Expand Down
Loading

0 comments on commit 687ff0a

Please sign in to comment.