From 3637340ce5e584a1ff164a85c477f31ad58bb33d Mon Sep 17 00:00:00 2001 From: Daniel Lohse Date: Thu, 6 Oct 2022 23:18:58 +0200 Subject: [PATCH] Allow transitions in callbacks (#88) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow state transitions in callbacks This adds the possibility of "starting" a state machine and have it execute multiple state transitions in succession, given that no errors occur. The equivalent code without this change: ```go var errTransition error for errTransition == nil { transitions := request.FSM.AvailableTransitions() if len(transitions) == 0 { break } if len(transitions) > 1 { errTransition = errors.New("only 1 transition should be available") } errTransition = request.FSM.Event(transitions[0]) } if errTransition != nil { fmt.Println(errTransition) } ``` Arguably, that’s bad because of several reasons: 1. The state machine is used like a puppet. 2. The state transitions that make up the "happy path" are encoded outside the state machine. 3. The code really isn’t good. 4. There’s no way to intervene or make different decisions on which state to transition to next (reinforces bullet point 2). 5. There’s no way to add proper error handling. It is possible to fix a certain number of those problems but not all of them, especially 2 and 4 but also 1. The added test is green and uses both an enter state and an after event callback. No other test case was touched in any way (besides enhancing the context one that was added in the previous commit). * Allow async state transition to be canceled This adds a context and cancelation facility to the type `AsyncError`. Async state transitions can now be canceled by calling `CancelTransition` on the AsyncError returned by `fsm.Event`. The context on that error can also be handed off as described in https://github.com/looplab/fsm/pull/77#issuecomment-930173282. * Add example for triggering transitions in callbacks * Add example for canceling an async transition --- errors.go | 7 +++ examples/cancel_async_transition.go | 54 +++++++++++++++++ examples/transition_callbacks.go | 54 +++++++++++++++++ fsm.go | 57 +++++++++++++++--- fsm_test.go | 88 +++++++++++++++++++++++++++- uncancel_context.go | 21 +++++++ uncancel_context_test.go | 91 +++++++++++++++++++++++++++++ 7 files changed, 362 insertions(+), 10 deletions(-) create mode 100644 examples/cancel_async_transition.go create mode 100644 examples/transition_callbacks.go create mode 100644 uncancel_context.go create mode 100644 uncancel_context_test.go diff --git a/errors.go b/errors.go index 9c32a49..add5dbe 100644 --- a/errors.go +++ b/errors.go @@ -14,6 +14,10 @@ package fsm +import ( + "context" +) + // InvalidEventError is returned by FSM.Event() when the event cannot be called // in the current state. type InvalidEventError struct { @@ -82,6 +86,9 @@ func (e CanceledError) Error() string { // asynchronous state transition. type AsyncError struct { Err error + + Ctx context.Context + CancelTransition func() } func (e AsyncError) Error() string { diff --git a/examples/cancel_async_transition.go b/examples/cancel_async_transition.go new file mode 100644 index 0000000..d67cb4c --- /dev/null +++ b/examples/cancel_async_transition.go @@ -0,0 +1,54 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "fmt" + "time" + + "github.com/looplab/fsm" +) + +func main() { + f := fsm.NewFSM( + "start", + fsm.Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + fsm.Callbacks{ + "leave_start": func(_ context.Context, e *fsm.Event) { + e.Async() + }, + }, + ) + + err := f.Event(context.Background(), "run") + asyncError, ok := err.(fsm.AsyncError) + if !ok { + panic(fmt.Sprintf("expected error to be 'AsyncError', got %v", err)) + } + var asyncStateTransitionWasCanceled bool + go func() { + <-asyncError.Ctx.Done() + asyncStateTransitionWasCanceled = true + if asyncError.Ctx.Err() != context.Canceled { + panic(fmt.Sprintf("Expected error to be '%v' but was '%v'", context.Canceled, asyncError.Ctx.Err())) + } + }() + asyncError.CancelTransition() + time.Sleep(20 * time.Millisecond) + + if err = f.Transition(); err != nil { + panic(fmt.Sprintf("Error encountered when transitioning: %v", err)) + } + if !asyncStateTransitionWasCanceled { + panic("expected async state transition cancelation to have propagated") + } + if f.Current() != "start" { + panic("expected state to be 'start'") + } + + fmt.Println("Successfully ran state machine.") +} diff --git a/examples/transition_callbacks.go b/examples/transition_callbacks.go new file mode 100644 index 0000000..e84f795 --- /dev/null +++ b/examples/transition_callbacks.go @@ -0,0 +1,54 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "fmt" + + "github.com/looplab/fsm" +) + +func main() { + var afterFinishCalled bool + fsm := fsm.NewFSM( + "start", + fsm.Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + {Name: "finish", Src: []string{"end"}, Dst: "finished"}, + {Name: "reset", Src: []string{"end", "finished"}, Dst: "start"}, + }, + fsm.Callbacks{ + "enter_end": func(ctx context.Context, e *fsm.Event) { + if err := e.FSM.Event(ctx, "finish"); err != nil { + fmt.Println(err) + } + }, + "after_finish": func(ctx context.Context, e *fsm.Event) { + afterFinishCalled = true + if e.Src != "end" { + panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src)) + } + if err := e.FSM.Event(ctx, "reset"); err != nil { + fmt.Println(err) + } + }, + }, + ) + + if err := fsm.Event(context.Background(), "run"); err != nil { + panic(fmt.Sprintf("Error encountered when triggering the run event: %v", err)) + } + + if !afterFinishCalled { + panic(fmt.Sprintf("After finish callback should have run, current state: '%s'", fsm.Current())) + } + + currentState := fsm.Current() + if currentState != "start" { + panic(fmt.Sprintf("expected state to be 'start', was '%s'", currentState)) + } + + fmt.Println("Successfully ran state machine.") +} diff --git a/fsm.go b/fsm.go index 8ab2726..c5d9efd 100644 --- a/fsm.go +++ b/fsm.go @@ -289,7 +289,16 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) { // internal bug. func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) error { f.eventMu.Lock() - defer f.eventMu.Unlock() + // in order to always unlock the event mutex, the defer is added + // in case the state transition goes through and enter/after callbacks + // are called; because these must be able to trigger new state + // transitions, it is explicitly unlocked in the code below + var unlocked bool + defer func() { + if !unlocked { + f.eventMu.Unlock() + } + }() f.stateMu.RLock() defer f.stateMu.RUnlock() @@ -323,18 +332,48 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro } // Setup the transition, call it later. - f.transition = func() { - f.stateMu.Lock() - f.current = dst - f.stateMu.Unlock() + transitionFunc := func(ctx context.Context, async bool) func() { + return func() { + if ctx.Err() != nil { + if e.Err == nil { + e.Err = ctx.Err() + } + return + } - f.enterStateCallbacks(ctx, e) - f.afterEventCallbacks(ctx, e) + f.stateMu.Lock() + f.current = dst + f.stateMu.Unlock() + + // at this point, we unlock the event mutex in order to allow + // enter state callbacks to trigger another transition + // for aynchronous state transitions this doesn't happen because + // the event mutex has already been unlocked + if !async { + f.eventMu.Unlock() + unlocked = true + } + f.transition = nil // treat the state transition as done + f.enterStateCallbacks(ctx, e) + f.afterEventCallbacks(ctx, e) + } } + f.transition = transitionFunc(ctx, false) + if err = f.leaveStateCallbacks(ctx, e); err != nil { if _, ok := err.(CanceledError); ok { f.transition = nil + } else if asyncError, ok := err.(AsyncError); ok { + // setup a new context in order for async state transitions to work correctly + // this "uncancels" the original context which ignores its cancelation + // but keeps the values of the original context available to callers + ctx, cancel := uncancelContext(ctx) + e.cancelFunc = cancel + asyncError.Ctx = ctx + asyncError.CancelTransition = cancel + f.transition = transitionFunc(ctx, true) + return asyncError } return err } @@ -405,7 +444,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok { @@ -413,7 +452,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } return nil diff --git a/fsm_test.go b/fsm_test.go index c5e90bc..985b0ea 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -16,6 +16,7 @@ package fsm import ( "context" + "errors" "fmt" "sort" "sync" @@ -526,6 +527,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) { } } +func TestCancelAsyncTransition(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(_ context.Context, e *Event) { + e.Async() + }, + }, + ) + err := fsm.Event(context.Background(), "run") + asyncError, ok := err.(AsyncError) + if !ok { + t.Errorf("expected error to be 'AsyncError', got %v", err) + } + var asyncStateTransitionWasCanceled bool + go func() { + <-asyncError.Ctx.Done() + asyncStateTransitionWasCanceled = true + }() + asyncError.CancelTransition() + time.Sleep(20 * time.Millisecond) + + if err = fsm.Transition(); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !asyncStateTransitionWasCanceled { + t.Error("expected async state transition cancelation to have propagated") + } + if fsm.Current() != "start" { + t.Error("expected state to be 'start'") + } +} + func TestCallbackNoError(t *testing.T) { fsm := NewFSM( "start", @@ -695,6 +732,47 @@ func TestDoubleTransition(t *testing.T) { wg.Wait() } +func TestTransitionInCallbacks(t *testing.T) { + var fsm *FSM + var afterFinishCalled bool + fsm = NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + {Name: "finish", Src: []string{"end"}, Dst: "finished"}, + {Name: "reset", Src: []string{"end", "finished"}, Dst: "start"}, + }, + Callbacks{ + "enter_end": func(ctx context.Context, e *Event) { + if err := e.FSM.Event(ctx, "finish"); err != nil { + fmt.Println(err) + } + }, + "after_finish": func(ctx context.Context, e *Event) { + afterFinishCalled = true + if e.Src != "end" { + panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src)) + } + if err := e.FSM.Event(ctx, "reset"); err != nil { + fmt.Println(err) + } + }, + }, + ) + + if err := fsm.Event(context.Background(), "run"); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !afterFinishCalled { + t.Error("expected after_finish callback to have been executed but it wasn't") + } + + currentState := fsm.Current() + if currentState != "start" { + t.Errorf("expected state to be 'start', was '%s'", currentState) + } +} + func TestContextInCallbacks(t *testing.T) { var fsm *FSM var enterEndAsyncWorkDone bool @@ -711,6 +789,11 @@ func TestContextInCallbacks(t *testing.T) { <-ctx.Done() enterEndAsyncWorkDone = true }() + + <-ctx.Done() + if err := e.FSM.Event(ctx, "finish"); err != nil { + e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err) + } }, }, ) @@ -719,7 +802,10 @@ func TestContextInCallbacks(t *testing.T) { go func() { cancel() }() - fsm.Event(ctx, "run") + err := fsm.Event(ctx, "run") + if !errors.Is(err, context.Canceled) { + t.Errorf("expected 'context canceled' error, got %v", err) + } time.Sleep(20 * time.Millisecond) if !enterEndAsyncWorkDone { diff --git a/uncancel_context.go b/uncancel_context.go new file mode 100644 index 0000000..eabdeac --- /dev/null +++ b/uncancel_context.go @@ -0,0 +1,21 @@ +package fsm + +import ( + "context" + "time" +) + +type uncancel struct { + context.Context +} + +func (*uncancel) Deadline() (deadline time.Time, ok bool) { return } +func (*uncancel) Done() <-chan struct{} { return nil } +func (*uncancel) Err() error { return nil } + +// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values. +// Also returns a new cancel function. +// This is useful to keep a background task running while the initial request is finished. +func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) { + return context.WithCancel(&uncancel{ctx}) +} diff --git a/uncancel_context_test.go b/uncancel_context_test.go new file mode 100644 index 0000000..086da78 --- /dev/null +++ b/uncancel_context_test.go @@ -0,0 +1,91 @@ +package fsm + +import ( + "context" + "testing" +) + +func TestUncancel(t *testing.T) { + t.Run("create a new context", func(t *testing.T) { + t.Run("and cancel it", func(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "key1", "value1") + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and uncancel it", func(t *testing.T) { + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("now it should still contain the values", func(t *testing.T) { + if ctx.Value("key1") != "value1" { + t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1")) + } + }) + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + }) + }) + }) + t.Run("and uncancel it", func(t *testing.T) { + ctx := context.Background() + parent := ctx + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and ensure the parent is not affected", func(t *testing.T) { + if parent.Err() != nil { + t.Errorf("expected parent context error to be nil, got %v", ctx.Err()) + } + select { + case <-parent.Done(): + t.Fail() + default: + } + }) + }) + }) + }) +}