diff --git a/errors.go b/errors.go index 9c32a49..add5dbe 100644 --- a/errors.go +++ b/errors.go @@ -14,6 +14,10 @@ package fsm +import ( + "context" +) + // InvalidEventError is returned by FSM.Event() when the event cannot be called // in the current state. type InvalidEventError struct { @@ -82,6 +86,9 @@ func (e CanceledError) Error() string { // asynchronous state transition. type AsyncError struct { Err error + + Ctx context.Context + CancelTransition func() } func (e AsyncError) Error() string { diff --git a/examples/cancel_async_transition.go b/examples/cancel_async_transition.go new file mode 100644 index 0000000..d67cb4c --- /dev/null +++ b/examples/cancel_async_transition.go @@ -0,0 +1,54 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "fmt" + "time" + + "github.com/looplab/fsm" +) + +func main() { + f := fsm.NewFSM( + "start", + fsm.Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + fsm.Callbacks{ + "leave_start": func(_ context.Context, e *fsm.Event) { + e.Async() + }, + }, + ) + + err := f.Event(context.Background(), "run") + asyncError, ok := err.(fsm.AsyncError) + if !ok { + panic(fmt.Sprintf("expected error to be 'AsyncError', got %v", err)) + } + var asyncStateTransitionWasCanceled bool + go func() { + <-asyncError.Ctx.Done() + asyncStateTransitionWasCanceled = true + if asyncError.Ctx.Err() != context.Canceled { + panic(fmt.Sprintf("Expected error to be '%v' but was '%v'", context.Canceled, asyncError.Ctx.Err())) + } + }() + asyncError.CancelTransition() + time.Sleep(20 * time.Millisecond) + + if err = f.Transition(); err != nil { + panic(fmt.Sprintf("Error encountered when transitioning: %v", err)) + } + if !asyncStateTransitionWasCanceled { + panic("expected async state transition cancelation to have propagated") + } + if f.Current() != "start" { + panic("expected state to be 'start'") + } + + fmt.Println("Successfully ran state machine.") +} diff --git a/examples/transition_callbacks.go b/examples/transition_callbacks.go new file mode 100644 index 0000000..e84f795 --- /dev/null +++ b/examples/transition_callbacks.go @@ -0,0 +1,54 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "fmt" + + "github.com/looplab/fsm" +) + +func main() { + var afterFinishCalled bool + fsm := fsm.NewFSM( + "start", + fsm.Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + {Name: "finish", Src: []string{"end"}, Dst: "finished"}, + {Name: "reset", Src: []string{"end", "finished"}, Dst: "start"}, + }, + fsm.Callbacks{ + "enter_end": func(ctx context.Context, e *fsm.Event) { + if err := e.FSM.Event(ctx, "finish"); err != nil { + fmt.Println(err) + } + }, + "after_finish": func(ctx context.Context, e *fsm.Event) { + afterFinishCalled = true + if e.Src != "end" { + panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src)) + } + if err := e.FSM.Event(ctx, "reset"); err != nil { + fmt.Println(err) + } + }, + }, + ) + + if err := fsm.Event(context.Background(), "run"); err != nil { + panic(fmt.Sprintf("Error encountered when triggering the run event: %v", err)) + } + + if !afterFinishCalled { + panic(fmt.Sprintf("After finish callback should have run, current state: '%s'", fsm.Current())) + } + + currentState := fsm.Current() + if currentState != "start" { + panic(fmt.Sprintf("expected state to be 'start', was '%s'", currentState)) + } + + fmt.Println("Successfully ran state machine.") +} diff --git a/fsm.go b/fsm.go index 8ab2726..c5d9efd 100644 --- a/fsm.go +++ b/fsm.go @@ -289,7 +289,16 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) { // internal bug. func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) error { f.eventMu.Lock() - defer f.eventMu.Unlock() + // in order to always unlock the event mutex, the defer is added + // in case the state transition goes through and enter/after callbacks + // are called; because these must be able to trigger new state + // transitions, it is explicitly unlocked in the code below + var unlocked bool + defer func() { + if !unlocked { + f.eventMu.Unlock() + } + }() f.stateMu.RLock() defer f.stateMu.RUnlock() @@ -323,18 +332,48 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro } // Setup the transition, call it later. - f.transition = func() { - f.stateMu.Lock() - f.current = dst - f.stateMu.Unlock() + transitionFunc := func(ctx context.Context, async bool) func() { + return func() { + if ctx.Err() != nil { + if e.Err == nil { + e.Err = ctx.Err() + } + return + } - f.enterStateCallbacks(ctx, e) - f.afterEventCallbacks(ctx, e) + f.stateMu.Lock() + f.current = dst + f.stateMu.Unlock() + + // at this point, we unlock the event mutex in order to allow + // enter state callbacks to trigger another transition + // for aynchronous state transitions this doesn't happen because + // the event mutex has already been unlocked + if !async { + f.eventMu.Unlock() + unlocked = true + } + f.transition = nil // treat the state transition as done + f.enterStateCallbacks(ctx, e) + f.afterEventCallbacks(ctx, e) + } } + f.transition = transitionFunc(ctx, false) + if err = f.leaveStateCallbacks(ctx, e); err != nil { if _, ok := err.(CanceledError); ok { f.transition = nil + } else if asyncError, ok := err.(AsyncError); ok { + // setup a new context in order for async state transitions to work correctly + // this "uncancels" the original context which ignores its cancelation + // but keeps the values of the original context available to callers + ctx, cancel := uncancelContext(ctx) + e.cancelFunc = cancel + asyncError.Ctx = ctx + asyncError.CancelTransition = cancel + f.transition = transitionFunc(ctx, true) + return asyncError } return err } @@ -405,7 +444,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok { @@ -413,7 +452,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } return nil diff --git a/fsm_test.go b/fsm_test.go index c5e90bc..985b0ea 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -16,6 +16,7 @@ package fsm import ( "context" + "errors" "fmt" "sort" "sync" @@ -526,6 +527,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) { } } +func TestCancelAsyncTransition(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(_ context.Context, e *Event) { + e.Async() + }, + }, + ) + err := fsm.Event(context.Background(), "run") + asyncError, ok := err.(AsyncError) + if !ok { + t.Errorf("expected error to be 'AsyncError', got %v", err) + } + var asyncStateTransitionWasCanceled bool + go func() { + <-asyncError.Ctx.Done() + asyncStateTransitionWasCanceled = true + }() + asyncError.CancelTransition() + time.Sleep(20 * time.Millisecond) + + if err = fsm.Transition(); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !asyncStateTransitionWasCanceled { + t.Error("expected async state transition cancelation to have propagated") + } + if fsm.Current() != "start" { + t.Error("expected state to be 'start'") + } +} + func TestCallbackNoError(t *testing.T) { fsm := NewFSM( "start", @@ -695,6 +732,47 @@ func TestDoubleTransition(t *testing.T) { wg.Wait() } +func TestTransitionInCallbacks(t *testing.T) { + var fsm *FSM + var afterFinishCalled bool + fsm = NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + {Name: "finish", Src: []string{"end"}, Dst: "finished"}, + {Name: "reset", Src: []string{"end", "finished"}, Dst: "start"}, + }, + Callbacks{ + "enter_end": func(ctx context.Context, e *Event) { + if err := e.FSM.Event(ctx, "finish"); err != nil { + fmt.Println(err) + } + }, + "after_finish": func(ctx context.Context, e *Event) { + afterFinishCalled = true + if e.Src != "end" { + panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src)) + } + if err := e.FSM.Event(ctx, "reset"); err != nil { + fmt.Println(err) + } + }, + }, + ) + + if err := fsm.Event(context.Background(), "run"); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !afterFinishCalled { + t.Error("expected after_finish callback to have been executed but it wasn't") + } + + currentState := fsm.Current() + if currentState != "start" { + t.Errorf("expected state to be 'start', was '%s'", currentState) + } +} + func TestContextInCallbacks(t *testing.T) { var fsm *FSM var enterEndAsyncWorkDone bool @@ -711,6 +789,11 @@ func TestContextInCallbacks(t *testing.T) { <-ctx.Done() enterEndAsyncWorkDone = true }() + + <-ctx.Done() + if err := e.FSM.Event(ctx, "finish"); err != nil { + e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err) + } }, }, ) @@ -719,7 +802,10 @@ func TestContextInCallbacks(t *testing.T) { go func() { cancel() }() - fsm.Event(ctx, "run") + err := fsm.Event(ctx, "run") + if !errors.Is(err, context.Canceled) { + t.Errorf("expected 'context canceled' error, got %v", err) + } time.Sleep(20 * time.Millisecond) if !enterEndAsyncWorkDone { diff --git a/uncancel_context.go b/uncancel_context.go new file mode 100644 index 0000000..eabdeac --- /dev/null +++ b/uncancel_context.go @@ -0,0 +1,21 @@ +package fsm + +import ( + "context" + "time" +) + +type uncancel struct { + context.Context +} + +func (*uncancel) Deadline() (deadline time.Time, ok bool) { return } +func (*uncancel) Done() <-chan struct{} { return nil } +func (*uncancel) Err() error { return nil } + +// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values. +// Also returns a new cancel function. +// This is useful to keep a background task running while the initial request is finished. +func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) { + return context.WithCancel(&uncancel{ctx}) +} diff --git a/uncancel_context_test.go b/uncancel_context_test.go new file mode 100644 index 0000000..086da78 --- /dev/null +++ b/uncancel_context_test.go @@ -0,0 +1,91 @@ +package fsm + +import ( + "context" + "testing" +) + +func TestUncancel(t *testing.T) { + t.Run("create a new context", func(t *testing.T) { + t.Run("and cancel it", func(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "key1", "value1") + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and uncancel it", func(t *testing.T) { + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("now it should still contain the values", func(t *testing.T) { + if ctx.Value("key1") != "value1" { + t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1")) + } + }) + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + }) + }) + }) + t.Run("and uncancel it", func(t *testing.T) { + ctx := context.Background() + parent := ctx + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and ensure the parent is not affected", func(t *testing.T) { + if parent.Err() != nil { + t.Errorf("expected parent context error to be nil, got %v", ctx.Err()) + } + select { + case <-parent.Done(): + t.Fail() + default: + } + }) + }) + }) + }) +}