Skip to content

Commit

Permalink
Continue to simplify the builder helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 7, 2024
1 parent d577f71 commit ec6e73e
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 139 deletions.
12 changes: 6 additions & 6 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestClient(t *testing.T) {
t.Fatal(err)
}

call := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11))
call := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11)))

_, err = client.Dispatch(context.Background(), call)
if err != nil {
Expand All @@ -42,7 +42,7 @@ func TestClientEnvConfig(t *testing.T) {
t.Fatal(err)
}

call := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11))
call := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11)))

_, err = client.Dispatch(context.Background(), call)
if err != nil {
Expand All @@ -64,10 +64,10 @@ func TestClientBatch(t *testing.T) {
t.Fatal(err)
}

call1 := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11))
call2 := dispatch.NewCall("http://example.com", "function2", dispatch.Int(22))
call3 := dispatch.NewCall("http://example.com", "function3", dispatch.Int(33))
call4 := dispatch.NewCall("http://example2.com", "function4", dispatch.Int(44))
call1 := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11)))
call2 := dispatch.NewCall("http://example.com", "function2", dispatch.Input(dispatch.Int(22)))
call3 := dispatch.NewCall("http://example.com", "function3", dispatch.Input(dispatch.Int(33)))
call4 := dispatch.NewCall("http://example2.com", "function4", dispatch.Input(dispatch.Int(44)))

batch := client.Batch()
batch.Add(call1, call2)
Expand Down
2 changes: 1 addition & 1 deletion dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (d *dispatchFunctionServiceHandler) Run(ctx context.Context, req *connect.R
var res Response
fn := d.dispatch.lookupFunction(req.Msg.Function)
if fn == nil {
res = NewResponseWithErrorf("%w: function %q not found", ErrNotFound, req.Msg.Function)
res = NewResponseErrorf("%w: function %q not found", ErrNotFound, req.Msg.Function)
} else {
res = fn.Run(ctx, Request{req.Msg})
}
Expand Down
11 changes: 7 additions & 4 deletions dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ func TestDispatchEndpoint(t *testing.T) {
endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req dispatch.Request) dispatch.Response {
input, ok := req.Input()
if !ok {
return dispatch.NewResponseWithErrorf("%w: unexpected request: %v", dispatch.ErrInvalidArgument, req)
return dispatch.NewResponseErrorf("%w: unexpected request: %v", dispatch.ErrInvalidArgument, req)
}
return dispatch.NewResponseWithOutput(input)
return dispatch.NewResponse(dispatch.OKStatus, dispatch.Output(input))
}))

// Send a request for the identity function, and check that the
Expand Down Expand Up @@ -101,7 +101,8 @@ func TestDispatchCall(t *testing.T) {
recorder.Assert(t, dispatchtest.DispatchRequest{
ApiKey: "foobar",
Calls: []dispatch.Call{
dispatch.NewCall("http://example.com", "function1", dispatch.Int(11),
dispatch.NewCall("http://example.com", "function1",
dispatch.Input(dispatch.Int(11)),
dispatch.Expiration(10*time.Second)),
},
})
Expand Down Expand Up @@ -133,7 +134,9 @@ func TestDispatchCallEnvConfig(t *testing.T) {
recorder.Assert(t, dispatchtest.DispatchRequest{
ApiKey: "foobar",
Calls: []dispatch.Call{
dispatch.NewCall("http://example.com", "function2", dispatch.String("foo"), dispatch.Version("xyzzy")),
dispatch.NewCall("http://example.com", "function2",
dispatch.Input(dispatch.String("foo")),
dispatch.Version("xyzzy")),
},
})
}
Expand Down
54 changes: 29 additions & 25 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package dispatch
import (
"context"
"fmt"
"slices"

"github.com/stealthrocket/coroutine"
"google.golang.org/protobuf/proto"
Expand All @@ -25,48 +26,48 @@ type Function interface {
}

// NewFunction creates a Dispatch function.
func NewFunction[Input, Output proto.Message](name string, fn func(context.Context, Input) (Output, error)) *GenericFunction[Input, Output] {
return &GenericFunction[Input, Output]{name: name, fn: fn}
func NewFunction[I, O proto.Message](name string, fn func(context.Context, I) (O, error)) *GenericFunction[I, O] {
return &GenericFunction[I, O]{name: name, fn: fn}
}

// GenericFunction is a Dispatch function that accepts arbitrary input
// and returns arbitrary output.
type GenericFunction[Input, Output proto.Message] struct {
type GenericFunction[I, O proto.Message] struct {
name string
fn func(ctx context.Context, input Input) (Output, error)
fn func(ctx context.Context, input I) (O, error)

endpoint *Dispatch
}

// Name is the name of the function.
func (f *GenericFunction[Input, Output]) Name() string {
func (f *GenericFunction[I, O]) Name() string {
return f.name
}

// Run runs the function.
func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) Response {
func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response {
var coro coroutine.Coroutine[any, any]
var zero Input
var zero I

if boxedInput, ok := req.Input(); ok {
message, err := boxedInput.Proto()
if err != nil {
return NewResponseWithErrorf("%w: invalid input: %v", ErrInvalidArgument, err)
return NewResponseErrorf("%w: invalid input: %v", ErrInvalidArgument, err)
}
input, ok := message.(Input)
input, ok := message.(I)
if !ok {
return NewResponseWithErrorf("%w: invalid input type: %T", ErrInvalidArgument, message)
return NewResponseErrorf("%w: invalid input type: %T", ErrInvalidArgument, message)
}
coro = coroutine.NewWithReturn[any, any](f.entrypoint(input))

} else if pollResult, ok := req.PollResult(); ok {
coro = coroutine.NewWithReturn[any, any](f.entrypoint(zero))
if err := coro.Context().Unmarshal(pollResult.CoroutineState()); err != nil {
return NewResponseWithErrorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err)
return NewResponseErrorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err)
}

} else {
return NewResponseWithErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req)
return NewResponseErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req)
}

// When running in volatile mode, we cannot snapshot the coroutine state
Expand All @@ -79,20 +80,20 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) R
return nil
})
if canceled {
return NewResponseWithError(context.Cause(ctx))
return NewResponseError(context.Cause(ctx))
}
}

var res Response
if coro.Next() {
coroutineState, err := coro.Context().Marshal()
if err != nil {
return NewResponseWithErrorf("%w: cannot serialize coroutine: %v", ErrPermanent, err)
return NewResponseErrorf("%w: cannot serialize coroutine: %v", ErrPermanent, err)
}
switch yield := coro.Recv().(type) {
// TODO
default:
res = NewResponseWithErrorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield)
res = NewResponseErrorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield)
}
// TODO
_ = coroutineState
Expand All @@ -101,38 +102,40 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) R
case proto.Message:
output, err := NewAny(ret)
if err != nil {
res = NewResponseWithErrorf("%w: cannot serialize return value: %v", ErrInvalidResponse, err)
res = NewResponseErrorf("%w: cannot serialize return value: %v", ErrInvalidResponse, err)
} else {
res = NewResponseWithOutput(output)
// TODO: automatically derive a status from the ret value
res = NewResponse(StatusOf(ret), Output(output))
}
case error:
res = NewResponseWithError(ret)
res = NewResponseError(ret)
default:
res = NewResponseWithErrorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret)
res = NewResponseErrorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret)
}
}

return res
}

func (f *GenericFunction[Input, Output]) bind(endpoint *Dispatch) {
func (f *GenericFunction[I, O]) bind(endpoint *Dispatch) {
f.endpoint = endpoint
}

// NewCall creates a Call for the function.
func (f *GenericFunction[Input, Output]) NewCall(input Input, opts ...CallOption) (Call, error) {
func (f *GenericFunction[I, O]) NewCall(input I, opts ...CallOption) (Call, error) {
if f.endpoint == nil {
return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint")
}
anyInput, err := NewAny(input)
if err != nil {
return Call{}, fmt.Errorf("cannot serialize input: %v", err)
}
return NewCall(f.endpoint.URL(), f.name, anyInput, opts...), nil
opts = append(slices.Clip(opts), Input(anyInput))
return NewCall(f.endpoint.URL(), f.name, opts...), nil
}

// Dispatch dispatches a call to the function.
func (f *GenericFunction[Input, Output]) Dispatch(ctx context.Context, input Input, opts ...CallOption) (ID, error) {
func (f *GenericFunction[I, O]) Dispatch(ctx context.Context, input I, opts ...CallOption) (ID, error) {
call, err := f.NewCall(input, opts...)
if err != nil {
return "", err
Expand All @@ -145,7 +148,7 @@ func (f *GenericFunction[Input, Output]) Dispatch(ctx context.Context, input Inp
}

//go:noinline
func (f *GenericFunction[Input, Output]) entrypoint(input Input) func() any {
func (f *GenericFunction[I, O]) entrypoint(input I) func() any {
return func() any {
// The context that gets passed as argument here should be recreated
// each time the coroutine is resumed, ideally inheriting from the
Expand Down Expand Up @@ -193,7 +196,8 @@ func (f *PrimitiveFunction) NewCall(input Any, opts ...CallOption) (Call, error)
if f.endpoint == nil {
return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint")
}
return NewCall(f.endpoint.URL(), f.name, input, opts...), nil
opts = append(slices.Clip(opts), Input(input))
return NewCall(f.endpoint.URL(), f.name, opts...), nil
}

// Dispatch dispatches a call to the function.
Expand Down
Loading

0 comments on commit ec6e73e

Please sign in to comment.