Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Allow transitions in callbacks #1

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

package fsm

import (
"context"
)

// InvalidEventError is returned by FSM.Event() when the event cannot be called
// in the current state.
type InvalidEventError struct {
Expand Down Expand Up @@ -82,6 +86,9 @@ func (e CanceledError) Error() string {
// asynchronous state transition.
type AsyncError struct {
Err error

Ctx context.Context
CancelTransition func()
}

func (e AsyncError) Error() string {
Expand Down
57 changes: 48 additions & 9 deletions fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,16 @@ func (f *FSM) SetMetadata(key string, dataValue interface{}) {
// internal bug.
func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) error {
f.eventMu.Lock()
defer f.eventMu.Unlock()
// in order to always unlock the event mutex, the defer is added
// in case the state transition goes through and enter/after callbacks
// are called; because these must be able to trigger new state
// transitions, it is explicitly unlocked in the code below
var unlocked bool
defer func() {
if !unlocked {
f.eventMu.Unlock()
}
}()

f.stateMu.RLock()
defer f.stateMu.RUnlock()
Expand Down Expand Up @@ -323,18 +332,48 @@ func (f *FSM) Event(ctx context.Context, event string, args ...interface{}) erro
}

// Setup the transition, call it later.
f.transition = func() {
f.stateMu.Lock()
f.current = dst
f.stateMu.Unlock()
transitionFunc := func(ctx context.Context, async bool) func() {
return func() {
if ctx.Err() != nil {
if e.Err == nil {
e.Err = ctx.Err()
}
return
}

f.enterStateCallbacks(ctx, e)
f.afterEventCallbacks(ctx, e)
f.stateMu.Lock()
f.current = dst
f.stateMu.Unlock()

// at this point, we unlock the event mutex in order to allow
// enter state callbacks to trigger another transition
// for aynchronous state transitions this doesn't happen because
// the event mutex has already been unlocked
if !async {
f.eventMu.Unlock()
unlocked = true
}
f.transition = nil // treat the state transition as done
f.enterStateCallbacks(ctx, e)
f.afterEventCallbacks(ctx, e)
}
}

f.transition = transitionFunc(ctx, false)

if err = f.leaveStateCallbacks(ctx, e); err != nil {
if _, ok := err.(CanceledError); ok {
f.transition = nil
} else if asyncError, ok := err.(AsyncError); ok {
// setup a new context in order for async state transitions to work correctly
// this "uncancels" the original context which ignores its cancelation
// but keeps the values of the original context available to callers
ctx, cancel := uncancelContext(ctx)
e.cancelFunc = cancel
asyncError.Ctx = ctx
asyncError.CancelTransition = cancel
f.transition = transitionFunc(ctx, true)
return asyncError
}
return err
}
Expand Down Expand Up @@ -405,15 +444,15 @@ func (f *FSM) leaveStateCallbacks(ctx context.Context, e *Event) error {
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
return AsyncError{Err: e.Err}
}
}
if fn, ok := f.callbacks[cKey{"", callbackLeaveState}]; ok {
fn(ctx, e)
if e.canceled {
return CanceledError{e.Err}
} else if e.async {
return AsyncError{e.Err}
return AsyncError{Err: e.Err}
}
}
return nil
Expand Down
88 changes: 87 additions & 1 deletion fsm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package fsm

import (
"context"
"errors"
"fmt"
"sort"
"sync"
Expand Down Expand Up @@ -526,6 +527,42 @@ func TestAsyncTransitionNotInProgress(t *testing.T) {
}
}

func TestCancelAsyncTransition(t *testing.T) {
fsm := NewFSM(
"start",
Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
},
Callbacks{
"leave_start": func(_ context.Context, e *Event) {
e.Async()
},
},
)
err := fsm.Event(context.Background(), "run")
asyncError, ok := err.(AsyncError)
if !ok {
t.Errorf("expected error to be 'AsyncError', got %v", err)
}
var asyncStateTransitionWasCanceled bool
go func() {
<-asyncError.Ctx.Done()
asyncStateTransitionWasCanceled = true
}()
asyncError.CancelTransition()
time.Sleep(20 * time.Millisecond)

if err = fsm.Transition(); err != nil {
t.Errorf("expected no error, got %v", err)
}
if !asyncStateTransitionWasCanceled {
t.Error("expected async state transition cancelation to have propagated")
}
if fsm.Current() != "start" {
t.Error("expected state to be 'start'")
}
}

func TestCallbackNoError(t *testing.T) {
fsm := NewFSM(
"start",
Expand Down Expand Up @@ -695,6 +732,47 @@ func TestDoubleTransition(t *testing.T) {
wg.Wait()
}

func TestTransitionInCallbacks(t *testing.T) {
var fsm *FSM
var afterFinishCalled bool
fsm = NewFSM(
"start",
Events{
{Name: "run", Src: []string{"start"}, Dst: "end"},
{Name: "finish", Src: []string{"end"}, Dst: "finished"},
{Name: "reset", Src: []string{"end", "finished"}, Dst: "start"},
},
Callbacks{
"enter_end": func(ctx context.Context, e *Event) {
if err := e.FSM.Event(ctx, "finish"); err != nil {
fmt.Println(err)
}
},
"after_finish": func(ctx context.Context, e *Event) {
afterFinishCalled = true
if e.Src != "end" {
panic(fmt.Sprintf("source should have been 'end' but was '%s'", e.Src))
}
if err := e.FSM.Event(ctx, "reset"); err != nil {
fmt.Println(err)
}
},
},
)

if err := fsm.Event(context.Background(), "run"); err != nil {
t.Errorf("expected no error, got %v", err)
}
if !afterFinishCalled {
t.Error("expected after_finish callback to have been executed but it wasn't")
}

currentState := fsm.Current()
if currentState != "start" {
t.Errorf("expected state to be 'start', was '%s'", currentState)
}
}

func TestContextInCallbacks(t *testing.T) {
var fsm *FSM
var enterEndAsyncWorkDone bool
Expand All @@ -711,6 +789,11 @@ func TestContextInCallbacks(t *testing.T) {
<-ctx.Done()
enterEndAsyncWorkDone = true
}()

<-ctx.Done()
if err := e.FSM.Event(ctx, "finish"); err != nil {
e.Err = fmt.Errorf("transitioning to the finished state failed: %w", err)
}
},
},
)
Expand All @@ -719,7 +802,10 @@ func TestContextInCallbacks(t *testing.T) {
go func() {
cancel()
}()
fsm.Event(ctx, "run")
err := fsm.Event(ctx, "run")
if !errors.Is(err, context.Canceled) {
t.Errorf("expected 'context canceled' error, got %v", err)
}
time.Sleep(20 * time.Millisecond)

if !enterEndAsyncWorkDone {
Expand Down
21 changes: 21 additions & 0 deletions uncancel_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package fsm

import (
"context"
"time"
)

type uncancel struct {
context.Context
}

func (*uncancel) Deadline() (deadline time.Time, ok bool) { return }
func (*uncancel) Done() <-chan struct{} { return nil }
func (*uncancel) Err() error { return nil }

// uncancelContext returns a context which ignores the cancellation of the parent and only keeps the values.
// Also returns a new cancel function.
// This is useful to keep a background task running while the initial request is finished.
func uncancelContext(ctx context.Context) (context.Context, context.CancelFunc) {
return context.WithCancel(&uncancel{ctx})
}
91 changes: 91 additions & 0 deletions uncancel_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package fsm

import (
"context"
"testing"
)

func TestUncancel(t *testing.T) {
t.Run("create a new context", func(t *testing.T) {
t.Run("and cancel it", func(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, "key1", "value1")
ctx, cancelFunc := context.WithCancel(ctx)
cancelFunc()

if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}

t.Run("and uncancel it", func(t *testing.T) {
ctx, newCancelFunc := uncancelContext(ctx)
if ctx.Err() != nil {
t.Errorf("expected context error to be nil, got %v", ctx.Err())
}
select {
case <-ctx.Done():
t.Fail()
default:
}

t.Run("now it should still contain the values", func(t *testing.T) {
if ctx.Value("key1") != "value1" {
t.Errorf("expected context value of key 'key1' to be 'value1', got %v", ctx.Value("key1"))
}
})
t.Run("and cancel the child", func(t *testing.T) {
newCancelFunc()
if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}
})
})
})
t.Run("and uncancel it", func(t *testing.T) {
ctx := context.Background()
parent := ctx
ctx, newCancelFunc := uncancelContext(ctx)
if ctx.Err() != nil {
t.Errorf("expected context error to be nil, got %v", ctx.Err())
}
select {
case <-ctx.Done():
t.Fail()
default:
}

t.Run("and cancel the child", func(t *testing.T) {
newCancelFunc()
if ctx.Err() != context.Canceled {
t.Errorf("expected context error 'context canceled', got %v", ctx.Err())
}
select {
case <-ctx.Done():
default:
t.Error("expected context to be done but it wasn't")
}

t.Run("and ensure the parent is not affected", func(t *testing.T) {
if parent.Err() != nil {
t.Errorf("expected parent context error to be nil, got %v", ctx.Err())
}
select {
case <-parent.Done():
t.Fail()
default:
}
})
})
})
})
}