From 1b1ca2b71b09d4c14e323304e10583bf01b04dce Mon Sep 17 00:00:00 2001 From: Daniel Lohse Date: Sun, 12 Dec 2021 16:28:10 +0100 Subject: [PATCH 1/2] Allow state transitions in callbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This adds the possibility of "starting" a state machine and have it execute multiple state transitions in succession, given that no errors occur. The equivalent code without this change: ```go var errTransition error for errTransition == nil { transitions := request.FSM.AvailableTransitions() if len(transitions) == 0 { break } if len(transitions) > 1 { errTransition = errors.New("only 1 transition should be available") } errTransition = request.FSM.Event(transitions[0]) } if errTransition != nil { fmt.Println(errTransition) } ``` Arguably, that’s bad because of several reasons: 1. The state machine is used like a puppet. 2. The state transitions that make up the "happy path" are encoded outside the state machine. 3. The code really isn’t good. 4. There’s no way to intervene or make different decisions on which state to transition to next (reinforces bullet point 2). 5. There’s no way to add proper error handling. It is possible to fix a certain number of those problems but not all of them, especially 2 and 4 but also 1. The added test is green and uses both an enter state and an after event callback. No other test case was touched in any way (besides enhancing the context one that was added in the previous commit). --- fsm.go | 49 ++++++++++++++++++++++++++++++++++++++++++------- fsm_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 8 deletions(-) diff --git a/fsm.go b/fsm.go index 8ab2726..8379e95 100644 --- a/fsm.go +++ b/fsm.go @@ -289,7 +289,16 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) { // internal bug. func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) error { f.eventMu.Lock() - defer f.eventMu.Unlock() + // in order to always unlock the event mutex, the defer is added + // in case the state transition goes through and enter/after callbacks + // are called; because these must be able to trigger new state + // transitions, it is explicitly unlocked in the code below + var unlocked bool + defer func() { + if !unlocked { + f.eventMu.Unlock() + } + }() f.stateMu.RLock() defer f.stateMu.RUnlock() @@ -323,18 +332,44 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro } // Setup the transition, call it later. - f.transition = func() { - f.stateMu.Lock() - f.current = dst - f.stateMu.Unlock() + transitionFunc := func(ctx context.Context, async bool) func() { + return func() { + if ctx.Err() != nil { + if e.Err == nil { + e.Err = ctx.Err() + } + return + } - f.enterStateCallbacks(ctx, e) - f.afterEventCallbacks(ctx, e) + f.stateMu.Lock() + f.current = dst + f.stateMu.Unlock() + + // at this point, we unlock the event mutex in order to allow + // enter state callbacks to trigger another transition + // for aynchronous state transitions this doesn't happen because + // the event mutex has already been unlocked + if !async { + f.eventMu.Unlock() + unlocked = true + } + f.transition = nil // treat the state transition as done + f.enterStateCallbacks(ctx, e) + f.afterEventCallbacks(ctx, e) + } } + f.transition = transitionFunc(ctx, false) + if err = f.leaveStateCallbacks(ctx, e); err != nil { if _, ok := err.(CanceledError); ok { f.transition = nil + } else if asyncError, ok := err.(AsyncError); ok { + // setup a new context in order for async state transitions to work correctly + ctx, cancel := context.WithCancel(context.Background()) + e.cancelFunc = cancel + f.transition = transitionFunc(ctx, true) + return asyncError } return err } diff --git a/fsm_test.go b/fsm_test.go index c5e90bc..a3d1a44 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -16,6 +16,7 @@ package fsm import ( "context" + "errors" "fmt" "sort" "sync" @@ -695,6 +696,47 @@ func TestDoubleTransition(t *testing.T) { wg.Wait() } +func TestTransitionInCallbacks(t *testing.T) { + var fsm *FSM + var afterFinishCalled bool + fsm = NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + {Name: "finish", Src: []string{"end"}, Dst: "finished"}, + {Name: "reset", Src: []string{"end", "finished"}, Dst: "start"}, + }, + Callbacks{ + "enter_end": func(ctx context.Context, e *Event) { + if err := e.FSM.Event(ctx, "finish"); err != nil { + fmt.Println(err) + } + }, + "after_finish": func(ctx context.Context, e *Event) { + afterFinishCalled = true + if e.Src != "end" { + panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src)) + } + if err := e.FSM.Event(ctx, "reset"); err != nil { + fmt.Println(err) + } + }, + }, + ) + + if err := fsm.Event(context.Background(), "run"); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !afterFinishCalled { + t.Error("expected after_finish callback to have been executed but it wasn't") + } + + currentState := fsm.Current() + if currentState != "start" { + t.Errorf("expected state to be 'start', was '%s'", currentState) + } +} + func TestContextInCallbacks(t *testing.T) { var fsm *FSM var enterEndAsyncWorkDone bool @@ -711,6 +753,11 @@ func TestContextInCallbacks(t *testing.T) { <-ctx.Done() enterEndAsyncWorkDone = true }() + + <-ctx.Done() + if err := e.FSM.Event(ctx, "finish"); err != nil { + e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err) + } }, }, ) @@ -719,7 +766,10 @@ func TestContextInCallbacks(t *testing.T) { go func() { cancel() }() - fsm.Event(ctx, "run") + err := fsm.Event(ctx, "run") + if !errors.Is(err, context.Canceled) { + t.Errorf("expected 'context canceled' error, got %v", err) + } time.Sleep(20 * time.Millisecond) if !enterEndAsyncWorkDone { From 94fb353ef273319b14928aa3a04ab4e01a58d6da Mon Sep 17 00:00:00 2001 From: Daniel Lohse Date: Sun, 12 Dec 2021 16:30:28 +0100 Subject: [PATCH 2/2] Allow async state transition to be canceled This adds a context and cancelation facility to the type `AsyncError`. Async state transitions can now be canceled by calling `CancelTransition` on the AsyncError returned by `fsm.Event`. The context on that error can also be handed off as described in https://github.com/looplab/fsm/pull/77#issuecomment-930173282. --- errors.go | 7 ++++ fsm.go | 10 +++-- fsm_test.go | 36 ++++++++++++++++ uncancel_context.go | 21 ++++++++++ uncancel_context_test.go | 91 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 uncancel_context.go create mode 100644 uncancel_context_test.go diff --git a/errors.go b/errors.go index 9c32a49..add5dbe 100644 --- a/errors.go +++ b/errors.go @@ -14,6 +14,10 @@ package fsm +import ( + "context" +) + // InvalidEventError is returned by FSM.Event() when the event cannot be called // in the current state. type InvalidEventError struct { @@ -82,6 +86,9 @@ func (e CanceledError) Error() string { // asynchronous state transition. type AsyncError struct { Err error + + Ctx context.Context + CancelTransition func() } func (e AsyncError) Error() string { diff --git a/fsm.go b/fsm.go index 8379e95..c5d9efd 100644 --- a/fsm.go +++ b/fsm.go @@ -366,8 +366,12 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro f.transition = nil } else if asyncError, ok := err.(AsyncError); ok { // setup a new context in order for async state transitions to work correctly - ctx, cancel := context.WithCancel(context.Background()) + // this "uncancels" the original context which ignores its cancelation + // but keeps the values of the original context available to callers + ctx, cancel := uncancelContext(ctx) e.cancelFunc = cancel + asyncError.Ctx = ctx + asyncError.CancelTransition = cancel f.transition = transitionFunc(ctx, true) return asyncError } @@ -440,7 +444,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok { @@ -448,7 +452,7 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error { if e.canceled { return CanceledError{e.Err} } else if e.async { - return AsyncError{e.Err} + return AsyncError{Err: e.Err} } } return nil diff --git a/fsm_test.go b/fsm_test.go index a3d1a44..985b0ea 100644 --- a/fsm_test.go +++ b/fsm_test.go @@ -527,6 +527,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) { } } +func TestCancelAsyncTransition(t *testing.T) { + fsm := NewFSM( + "start", + Events{ + {Name: "run", Src: []string{"start"}, Dst: "end"}, + }, + Callbacks{ + "leave_start": func(_ context.Context, e *Event) { + e.Async() + }, + }, + ) + err := fsm.Event(context.Background(), "run") + asyncError, ok := err.(AsyncError) + if !ok { + t.Errorf("expected error to be 'AsyncError', got %v", err) + } + var asyncStateTransitionWasCanceled bool + go func() { + <-asyncError.Ctx.Done() + asyncStateTransitionWasCanceled = true + }() + asyncError.CancelTransition() + time.Sleep(20 * time.Millisecond) + + if err = fsm.Transition(); err != nil { + t.Errorf("expected no error, got %v", err) + } + if !asyncStateTransitionWasCanceled { + t.Error("expected async state transition cancelation to have propagated") + } + if fsm.Current() != "start" { + t.Error("expected state to be 'start'") + } +} + func TestCallbackNoError(t *testing.T) { fsm := NewFSM( "start", diff --git a/uncancel_context.go b/uncancel_context.go new file mode 100644 index 0000000..eabdeac --- /dev/null +++ b/uncancel_context.go @@ -0,0 +1,21 @@ +package fsm + +import ( + "context" + "time" +) + +type uncancel struct { + context.Context +} + +func (*uncancel) Deadline() (deadline time.Time, ok bool) { return } +func (*uncancel) Done() <-chan struct{} { return nil } +func (*uncancel) Err() error { return nil } + +// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values. +// Also returns a new cancel function. +// This is useful to keep a background task running while the initial request is finished. +func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) { + return context.WithCancel(&uncancel{ctx}) +} diff --git a/uncancel_context_test.go b/uncancel_context_test.go new file mode 100644 index 0000000..086da78 --- /dev/null +++ b/uncancel_context_test.go @@ -0,0 +1,91 @@ +package fsm + +import ( + "context" + "testing" +) + +func TestUncancel(t *testing.T) { + t.Run("create a new context", func(t *testing.T) { + t.Run("and cancel it", func(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "key1", "value1") + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and uncancel it", func(t *testing.T) { + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("now it should still contain the values", func(t *testing.T) { + if ctx.Value("key1") != "value1" { + t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1")) + } + }) + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + }) + }) + }) + t.Run("and uncancel it", func(t *testing.T) { + ctx := context.Background() + parent := ctx + ctx, newCancelFunc := uncancelContext(ctx) + if ctx.Err() != nil { + t.Errorf("expected context error to be nil, got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + t.Fail() + default: + } + + t.Run("and cancel the child", func(t *testing.T) { + newCancelFunc() + if ctx.Err() != context.Canceled { + t.Errorf("expected context error 'context canceled', got %v", ctx.Err()) + } + select { + case <-ctx.Done(): + default: + t.Error("expected context to be done but it wasn't") + } + + t.Run("and ensure the parent is not affected", func(t *testing.T) { + if parent.Err() != nil { + t.Errorf("expected parent context error to be nil, got %v", ctx.Err()) + } + select { + case <-parent.Done(): + t.Fail() + default: + } + }) + }) + }) + }) +}