Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 18, 2024
1 parent 7f863e8 commit 4d69531
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 87 deletions.
144 changes: 62 additions & 82 deletions coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,79 +24,72 @@ func NewCoroutine[I, O any](name string, fn func(context.Context, I) (O, error))
type GenericCoroutine[I, O any] struct {
GenericFunction[I, O]

instances map[instanceID]coroutine.Coroutine[CoroR[O], CoroS]
nextID instanceID
instances map[coroutineID]dispatchCoroutine
nextID coroutineID
mu sync.Mutex
}

type instanceID = int
type coroutineID = int
type dispatchCoroutine = coroutine.Coroutine[Response, Request]

// Run runs the coroutine function.
func (c *GenericCoroutine[I, O]) Run(ctx context.Context, req Request) Response {
if name := req.Function(); name != c.name {
return NewResponseErrorf("%w: function %q received call for function %q", ErrInvalidArgument, c.name, name)
}

var id instanceID
var coro coroutine.Coroutine[CoroR[O], CoroS]
// Create or deserialize the coroutine (depending on the type of request).
id, coro, err := c.setUp(req)
if err != nil {
return NewResponseError(err)
}
defer c.tearDown(id, coro)

if _, ok := req.Input(); ok {
// Start a new coroutine if the request carries function input.
input, err := c.unpackInput(req)
if err != nil {
return NewResponseError(err)
}
id, coro = c.setup(input)
// Send results from Dispatch to the coroutine (if applicable).
coro.Send(req)

} else if pollResult, ok := req.PollResult(); ok {
// Otherwise, resume a coroutine that is suspended.
var err error
id, coro, err = c.deserialize(pollResult.CoroutineState())
if err != nil {
return NewResponseError(err)
}

// Send poll results to the coroutine.
coro.Send(CoroS{directive: pollResult})
// Run the coroutine until it yields or returns.
if !coro.Next() {
// The coroutine returned; return the result to Dispatch.
return coro.Result()
}

// Tidy up the coroutine when returning.
defer c.tearDown(id, coro)
// The coroutine yielded and is now paused.
yield := coro.Recv()

// Run the coroutine until it yields or returns.
if coro.Next() {
// The coroutine yielded and is now paused.
yield := coro.Recv()

// Serialize the coroutine, unless it's a terminal Exit directive.
var state Any
if _, terminal := yield.directive.(Exit); !terminal {
var err error
state, err = c.serialize(id, coro)
if err != nil {
return NewResponseError(err)
}
} else {
coro.Stop()
coro.Next()
}
// If the coroutine explicitly exited, stop it before yielding to Dispatch.
if _, exit := yield.Exit(); exit {
coro.Stop()
coro.Next()
return yield
}

// Yield to Dispatch with the directive.
return NewResponse(yield.status, yield.directive, CoroutineState(state))
// Serialize the coroutine state and yield to Dispatch with the directive.
state, err := c.serialize(id, coro)
if err != nil {
return NewResponseError(err)
}
return NewResponse(yield.Status(), yield, CoroutineState(state))
}

// The coroutine returned. Serialize the output / error.
result := coro.Result()
if result.err != nil {
// TODO: serialize the output too if present
return NewResponseError(result.err)
func (c *GenericCoroutine[I, O]) setUp(req Request) (id coroutineID, coro dispatchCoroutine, err error) {
// Start a new coroutine if the request carries function input.
// Otherwise, resume a coroutine that is suspended.
if _, ok := req.Input(); ok {
var input I
input, err = c.unpackInput(req)
if err == nil {
id, coro = c.create(input)
}
} else if pollResult, ok := req.PollResult(); ok {
id, coro, err = c.deserialize(pollResult.CoroutineState())
}
return c.packOutput(result.output)
return
}

func (c *GenericCoroutine[I, O]) setup(input I) (instanceID, coroutine.Coroutine[CoroR[O], CoroS]) {
var id instanceID
coro := coroutine.NewWithReturn[CoroR[O], CoroS](c.entrypoint(input))
func (c *GenericCoroutine[I, O]) create(input I) (coroutineID, dispatchCoroutine) {
var id coroutineID
coro := coroutine.NewWithReturn[Response, Request](c.entrypoint(input))

// In volatile mode, we need to create an "instance" of the coroutine that
// resides in memory.
Expand All @@ -109,15 +102,15 @@ func (c *GenericCoroutine[I, O]) setup(input I) (instanceID, coroutine.Coroutine
c.nextID++
id = c.nextID
if c.instances == nil {
c.instances = map[instanceID]coroutine.Coroutine[CoroR[O], CoroS]{}
c.instances = map[coroutineID]dispatchCoroutine{}
}
c.instances[id] = coro
}

return id, coro
}

func (c *GenericCoroutine[I, O]) tearDown(id instanceID, coro coroutine.Coroutine[CoroR[O], CoroS]) {
func (c *GenericCoroutine[I, O]) tearDown(id coroutineID, coro dispatchCoroutine) {
// Always tear down durable coroutines. They'll be rebuilt
// on the next call (if applicable) from their serialized state,
// possibly in a new location.
Expand All @@ -135,7 +128,7 @@ func (c *GenericCoroutine[I, O]) tearDown(id instanceID, coro coroutine.Coroutin
}
}

func (c *GenericCoroutine[I, O]) serialize(id instanceID, coro coroutine.Coroutine[CoroR[O], CoroS]) (Any, error) {
func (c *GenericCoroutine[I, O]) serialize(id coroutineID, coro dispatchCoroutine) (Any, error) {
// In volatile mode, serialize a reference to the coroutine instance.
if !coroutine.Durable {
return Int(id), nil
Expand All @@ -150,14 +143,14 @@ func (c *GenericCoroutine[I, O]) serialize(id instanceID, coro coroutine.Corouti
return state, nil
}

func (c *GenericCoroutine[I, O]) deserialize(state Any) (instanceID, coroutine.Coroutine[CoroR[O], CoroS], error) {
var id instanceID
var coro coroutine.Coroutine[CoroR[O], CoroS]
func (c *GenericCoroutine[I, O]) deserialize(state Any) (coroutineID, dispatchCoroutine, error) {
var id coroutineID
var coro dispatchCoroutine

// Deserialize durable coroutine state.
if coroutine.Durable {
var zero I
coro = coroutine.NewWithReturn[CoroR[O], CoroS](c.entrypoint(zero))
coro = coroutine.NewWithReturn[Response, Request](c.entrypoint(zero))
if state.TypeURL() != durableCoroutineStateTypeUrl {
return 0, coro, fmt.Errorf("%w: unexpected type URL: %q", ErrIncompatibleState, state.TypeURL())
} else if err := coro.Context().Unmarshal(state.Value()); err != nil {
Expand Down Expand Up @@ -201,42 +194,29 @@ func (c *GenericCoroutine[I, O]) Close() error {
return nil
}

func (c *GenericCoroutine[I, O]) entrypoint(input I) func() CoroR[O] {
return func() CoroR[O] {
func (c *GenericCoroutine[I, O]) entrypoint(input I) func() Response {
return func() Response {
// The context that gets passed as argument here should be recreated
// each time the coroutine is resumed, ideally inheriting from the
// parent context passed to the Run method. This is difficult to
// do right in durable mode because we shouldn't capture the parent
// context in the coroutine state.
var r CoroR[O]
r.output, r.err = c.fn(context.TODO(), input)
return r
output, err := c.fn(context.TODO(), input)
if err != nil {
// TODO: include output if not nil
return NewResponseError(err)
}
return c.packOutput(output)
}
}

type CoroS struct {
directive RequestDirective
}

type CoroR[O any] struct {
status Status
directive ResponseDirective

output O
err error
}

// Yield yields control to Dispatch.
//
// The coroutine is paused, serialized and sent to Dispatch. The
// directive instructs Dispatch to perform an operation while
// the coroutine is suspended. Once the operation is complete,
// Dispatch yields control back to the coroutine, which is resumed
// from the point execution was suspended.
func Yield[O any](status Status, directive ResponseDirective) RequestDirective {
result := coroutine.Yield[CoroR[O], CoroS](CoroR[O]{
status: status,
directive: directive,
})
return result.directive
func Yield(res Response) Request {
return coroutine.Yield[Response, Request](res)
}
12 changes: 7 additions & 5 deletions coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func logMode(t *testing.T) {
}
}

func TestCoroutineReturnOnly(t *testing.T) {
func TestCoroutineReturn(t *testing.T) {
logMode(t)

coro := dispatch.NewCoroutine("stringify", func(ctx context.Context, in int) (string, error) {
Expand Down Expand Up @@ -57,17 +57,18 @@ func TestCoroutineReturnOnly(t *testing.T) {
}
}

func TestCoroutineExit(t *testing.T) {
func TestCoroutineYieldExitResponse(t *testing.T) {
logMode(t)

coro := dispatch.NewCoroutine("stringify", func(ctx context.Context, in int) (string, error) {
var res dispatch.Response
if in < 0 {
err := fmt.Errorf("%w: %d", dispatch.ErrInvalidArgument, in)
dispatch.Yield[string](dispatch.InvalidArgumentStatus, dispatch.NewExit(dispatch.NewError(err)))
res = dispatch.NewResponseErrorf("%w: %d", dispatch.ErrInvalidArgument, in)
} else {
output := dispatch.String(strconv.Itoa(in))
dispatch.Yield[string](dispatch.OKStatus, dispatch.NewExit(dispatch.Output(output)))
res = dispatch.NewResponse(dispatch.OKStatus, dispatch.Output(output))
}
dispatch.Yield(res)
panic("unreachable")
})
defer coro.Close()
Expand All @@ -76,6 +77,7 @@ func TestCoroutineExit(t *testing.T) {
if res.Status() != dispatch.OKStatus {
t.Errorf("unexpected status: %s", res.Status())
}
fmt.Println(res)
output, ok := res.Output()
if !ok {
t.Errorf("expected output, got: %s", res)
Expand Down
7 changes: 7 additions & 0 deletions proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,13 @@ func (r Response) Marshal() ([]byte, error) {
return proto.Marshal(r.proto)
}

func (r Response) configureResponse(other *Response) {
if other.proto != nil {
r.proto = proto.Clone(other.proto).(*sdkv1.RunResponse)
fmt.Println("CLONING", other, r)
}
}

func ensureResponseExitResult(r *Response) *sdkv1.CallResult {
var d *sdkv1.RunResponse_Exit
d, ok := r.proto.Directive.(*sdkv1.RunResponse_Exit)
Expand Down

0 comments on commit 4d69531

Please sign in to comment.