Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Allow async state transition to be canceled
Browse files Browse the repository at this point in the history
This adds a context and cancelation facility to the type `AsyncError`.
Async state transitions can now be canceled by calling `CancelTransition`
on the AsyncError returned by `fsm.Event`. The context on that error can
also be handed off as described in looplab#77 (comment).
  • Loading branch information
annismckenzie committed Dec 12, 2021
1 parent 74cac90 commit 7db91e5
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 26 deletions.
8 changes: 7 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package fsm

import "errors"
import (
"context"
"errors"
)

// InvalidEventError is returned by FSM.Event() when the event cannot be called
// in the current state.
Expand Down Expand Up @@ -84,6 +87,9 @@ func (e CanceledError) Error() string {
// asynchronous state transition.
type AsyncError struct {
Err error

Ctx context.Context
CancelTransition func()
}

func (e AsyncError) Error() string {
Expand Down
58 changes: 36 additions & 22 deletions fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,9 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
return UnknownEventError{event}
}

e := &Event{f, event, f.current, dst, nil, args, false, false}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
e.cancelFunc = cancel
e := &Event{f, event, f.current, dst, nil, args, false, false, cancel}

if err := f.beforeEventCallbacks(ctx, e); err != nil {
return err
Expand All @@ -332,8 +331,15 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
}

// Setup the transition, call it later.
transitionFunc := func(async bool) func() {
transitionFunc := func(ctx context.Context, async bool) func() {
return func() {
if ctx.Err() != nil {
if e.Err == nil {
e.Err = ctx.Err()
}
return
}

f.stateMu.Lock()
f.current = dst
f.stateMu.Unlock()
Expand All @@ -352,13 +358,21 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
}
}

f.transition = transitionFunc(false)
f.transition = transitionFunc(ctx, false)

if err := f.leaveStateCallbacks(ctx, e); err != nil {
if _, ok := err.(CanceledError); ok {
f.transition = nil
} else if _, ok := err.(AsyncError); ok {
f.transition = transitionFunc(true)
} else if asyncError, ok := err.(AsyncError); ok {
// setup a new context in order for async state transitions to work correctly
// this "uncancels" the original context which ignores its cancelation
// but keeps the values of the original context available to callers
ctx, cancel := uncancelContext(ctx)
e.cancelFunc = cancel
asyncError.Ctx = ctx
asyncError.CancelTransition = cancel
f.transition = transitionFunc(ctx, true)
return asyncError
}
return err
}
Expand Down Expand Up @@ -404,63 +418,63 @@ func (t transitionerStruct) transition(f *FSM) error {

// beforeEventCallbacks calls the before_ callbacks, first the named then the
// general version.
fn(e)
if e.canceled {
return CanceledError{e.Err}
func (f *FSM) beforeEventCallbacks(ctx context.Context, e *Event) error {
if fn, ok := f.callbacks[cKey{e.Event, callbackTypeBeforeEvent}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
}
}
fn(e)
if fn, ok := f.callbacks[cKey{"", callbackTypeBeforeEvents}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
if fn, ok := f.callbacks[cKey{"", callbackTypeBeforeEvents}]; ok {
}
}
return nil
}

// leaveStateCallbacks calls the leave_ callbacks, first the named then the
// general version.
fn(e)
func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error {
if fn, ok := f.callbacks[cKey{f.current, callbackTypeLeaveState}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error {
if fn, ok := f.callbacks[cKey{f.current, callbackTypeLeaveState}]; ok {
return AsyncError{Err: e.Err}
}
}
fn(e)
if fn, ok := f.callbacks[cKey{"", callbackTypeLeaveStates}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
if fn, ok := f.callbacks[cKey{"", callbackTypeLeaveStates}]; ok {
return AsyncError{Err: e.Err}
}
}
return nil
}

// enterStateCallbacks calls the enter_ callbacks, first the named then the
// general version.
fn(e)
func (f *FSM) enterStateCallbacks(ctx context.Context, e *Event) {
if fn, ok := f.callbacks[cKey{f.current, callbackTypeEnterState}]; ok {
fn(ctx, e)
}
fn(e)
if fn, ok := f.callbacks[cKey{"", callbackTypeEnterStates}]; ok {
fn(ctx, e)
}
}

// afterEventCallbacks calls the after_ callbacks, first the named then the
// general version.
fn(e)
func (f *FSM) afterEventCallbacks(ctx context.Context, e *Event) {
if fn, ok := f.callbacks[cKey{e.Event, callbackTypeAfterEvent}]; ok {
fn(ctx, e)
}
fn(e)
if fn, ok := f.callbacks[cKey{"", callbackTypeAfterEvents}]; ok {
fn(ctx, e)
}
}

Expand Down
49 changes: 46 additions & 3 deletions fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) {
}
}

func TestCancelAsyncTransition(t *testing.T) {
fsm := NewFSM(
"start",
Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
},
Callbacks{
"leave_start": func(_ context.Context, e *Event) {
e.Async()
},
},
)
err := fsm.Event(context.Background(), "run")
asyncError, ok := err.(AsyncError)
if !ok {
t.Errorf("expected error to be 'AsyncError', got %v", err)
}
var asyncStateTransitionWasCanceled bool
go func() {
<-asyncError.Ctx.Done()
asyncStateTransitionWasCanceled = true
}()
asyncError.CancelTransition()
time.Sleep(20 * time.Millisecond)

if err = fsm.Transition(); err != nil {
t.Errorf("expected no error, got %v", err)
}
if !asyncStateTransitionWasCanceled {
t.Error("expected async state transition cancelation to have propagated")
}
if fsm.Current() != "start" {
t.Error("expected state to be 'start'")
}
}

func TestCallbackNoError(t *testing.T) {
fsm := NewFSM(
"start",
Expand Down Expand Up @@ -713,14 +749,21 @@ func TestContextInCallbacks(t *testing.T) {
<-ctx.Done()
enterEndAsyncWorkDone = true
}()
e.Err = e.Transition(ctx, "finish")

<-ctx.Done()
if err := e.Transition(ctx, "finish"); err != nil {
e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err)
}
},
},
)

ctx, cancel := context.WithCancel(context.Background())
cancel()
if err := fsm.Event(ctx, "run"); err != context.Canceled {
go func() {
cancel()
}()
err := fsm.Event(ctx, "run")
if !errors.Is(err, context.Canceled) {
t.Errorf("expected 'context canceled' error, got %v", err)
}
time.Sleep(20 * time.Millisecond)
Expand Down
21 changes: 21 additions & 0 deletions uncancel_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package fsm

import (
"context"
"time"
)

type uncancel struct {
context.Context
}

func (*uncancel) Deadline() (deadline time.Time, ok bool) { return }
func (*uncancel) Done() <-chan struct{} { return nil }
func (*uncancel) Err() error { return nil }

// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values.
// Also returns a new cancel function.
// This is useful to keep a background task running while the initial request is finished.
func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) {
return context.WithCancel(&uncancel{ctx})
}
91 changes: 91 additions & 0 deletions uncancel_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package fsm

import (
"context"
"testing"
)

func TestUncancel(t *testing.T) {
t.Run("create a new context", func(t *testing.T) {
t.Run("and cancel it", func(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, "key1", "value1")
ctx, cancelFunc := context.WithCancel(ctx)
cancelFunc()

if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}

t.Run("and uncancel it", func(t *testing.T) {
ctx, newCancelFunc := uncancelContext(ctx)
if ctx.Err() != nil {
t.Errorf("expected context error to be nil, got %v", ctx.Err())
}
select {
case <-ctx.Done():
t.Fail()
default:
}

t.Run("now it should still contain the values", func(t *testing.T) {
if ctx.Value("key1") != "value1" {
t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1"))
}
})
t.Run("and cancel the child", func(t *testing.T) {
newCancelFunc()
if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}
})
})
})
t.Run("and uncancel it", func(t *testing.T) {
ctx := context.Background()
parent := ctx
ctx, newCancelFunc := uncancelContext(ctx)
if ctx.Err() != nil {
t.Errorf("expected context error to be nil, got %v", ctx.Err())
}
select {
case <-ctx.Done():
t.Fail()
default:
}

t.Run("and cancel the child", func(t *testing.T) {
newCancelFunc()
if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}

t.Run("and ensure the parent is not affected", func(t *testing.T) {
if parent.Err() != nil {
t.Errorf("expected parent context error to be nil, got %v", ctx.Err())
}
select {
case <-parent.Done():
t.Fail()
default:
}
})
})
})
})
}

0 comments on commit 7db91e5

Please sign in to comment.