diff --git a/client_test.go b/client_test.go index 79829c5..b3897aa 100644 --- a/client_test.go +++ b/client_test.go @@ -17,7 +17,7 @@ func TestClient(t *testing.T) { t.Fatal(err) } - call := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11)) + call := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11))) _, err = client.Dispatch(context.Background(), call) if err != nil { @@ -42,7 +42,7 @@ func TestClientEnvConfig(t *testing.T) { t.Fatal(err) } - call := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11)) + call := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11))) _, err = client.Dispatch(context.Background(), call) if err != nil { @@ -64,10 +64,10 @@ func TestClientBatch(t *testing.T) { t.Fatal(err) } - call1 := dispatch.NewCall("http://example.com", "function1", dispatch.Int(11)) - call2 := dispatch.NewCall("http://example.com", "function2", dispatch.Int(22)) - call3 := dispatch.NewCall("http://example.com", "function3", dispatch.Int(33)) - call4 := dispatch.NewCall("http://example2.com", "function4", dispatch.Int(44)) + call1 := dispatch.NewCall("http://example.com", "function1", dispatch.Input(dispatch.Int(11))) + call2 := dispatch.NewCall("http://example.com", "function2", dispatch.Input(dispatch.Int(22))) + call3 := dispatch.NewCall("http://example.com", "function3", dispatch.Input(dispatch.Int(33))) + call4 := dispatch.NewCall("http://example2.com", "function4", dispatch.Input(dispatch.Int(44))) batch := client.Batch() batch.Add(call1, call2) diff --git a/dispatch.go b/dispatch.go index 720f8f3..f5c9745 100644 --- a/dispatch.go +++ b/dispatch.go @@ -207,7 +207,7 @@ func (d *dispatchFunctionServiceHandler) Run(ctx context.Context, req *connect.R var res Response fn := d.dispatch.lookupFunction(req.Msg.Function) if fn == nil { - res = NewResponseWithErrorf("%w: function %q not found", ErrNotFound, req.Msg.Function) + res = NewResponseErrorf("%w: function %q not found", ErrNotFound, req.Msg.Function) } else { res = fn.Run(ctx, Request{req.Msg}) } diff --git a/dispatch_test.go b/dispatch_test.go index c6c02dc..4e28e36 100644 --- a/dispatch_test.go +++ b/dispatch_test.go @@ -28,9 +28,9 @@ func TestDispatchEndpoint(t *testing.T) { endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req dispatch.Request) dispatch.Response { input, ok := req.Input() if !ok { - return dispatch.NewResponseWithErrorf("%w: unexpected request: %v", dispatch.ErrInvalidArgument, req) + return dispatch.NewResponseErrorf("%w: unexpected request: %v", dispatch.ErrInvalidArgument, req) } - return dispatch.NewResponseWithOutput(input) + return dispatch.NewResponse(dispatch.OKStatus, dispatch.Output(input)) })) // Send a request for the identity function, and check that the @@ -101,7 +101,8 @@ func TestDispatchCall(t *testing.T) { recorder.Assert(t, dispatchtest.DispatchRequest{ ApiKey: "foobar", Calls: []dispatch.Call{ - dispatch.NewCall("http://example.com", "function1", dispatch.Int(11), + dispatch.NewCall("http://example.com", "function1", + dispatch.Input(dispatch.Int(11)), dispatch.Expiration(10*time.Second)), }, }) @@ -133,7 +134,9 @@ func TestDispatchCallEnvConfig(t *testing.T) { recorder.Assert(t, dispatchtest.DispatchRequest{ ApiKey: "foobar", Calls: []dispatch.Call{ - dispatch.NewCall("http://example.com", "function2", dispatch.String("foo"), dispatch.Version("xyzzy")), + dispatch.NewCall("http://example.com", "function2", + dispatch.Input(dispatch.String("foo")), + dispatch.Version("xyzzy")), }, }) } diff --git a/function.go b/function.go index 60f6d2f..8920eb7 100644 --- a/function.go +++ b/function.go @@ -5,6 +5,7 @@ package dispatch import ( "context" "fmt" + "slices" "github.com/stealthrocket/coroutine" "google.golang.org/protobuf/proto" @@ -25,48 +26,48 @@ type Function interface { } // NewFunction creates a Dispatch function. -func NewFunction[Input, Output proto.Message](name string, fn func(context.Context, Input) (Output, error)) *GenericFunction[Input, Output] { - return &GenericFunction[Input, Output]{name: name, fn: fn} +func NewFunction[I, O proto.Message](name string, fn func(context.Context, I) (O, error)) *GenericFunction[I, O] { + return &GenericFunction[I, O]{name: name, fn: fn} } // GenericFunction is a Dispatch function that accepts arbitrary input // and returns arbitrary output. -type GenericFunction[Input, Output proto.Message] struct { +type GenericFunction[I, O proto.Message] struct { name string - fn func(ctx context.Context, input Input) (Output, error) + fn func(ctx context.Context, input I) (O, error) endpoint *Dispatch } // Name is the name of the function. -func (f *GenericFunction[Input, Output]) Name() string { +func (f *GenericFunction[I, O]) Name() string { return f.name } // Run runs the function. -func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) Response { +func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response { var coro coroutine.Coroutine[any, any] - var zero Input + var zero I if boxedInput, ok := req.Input(); ok { message, err := boxedInput.Proto() if err != nil { - return NewResponseWithErrorf("%w: invalid input: %v", ErrInvalidArgument, err) + return NewResponseErrorf("%w: invalid input: %v", ErrInvalidArgument, err) } - input, ok := message.(Input) + input, ok := message.(I) if !ok { - return NewResponseWithErrorf("%w: invalid input type: %T", ErrInvalidArgument, message) + return NewResponseErrorf("%w: invalid input type: %T", ErrInvalidArgument, message) } coro = coroutine.NewWithReturn[any, any](f.entrypoint(input)) } else if pollResult, ok := req.PollResult(); ok { coro = coroutine.NewWithReturn[any, any](f.entrypoint(zero)) if err := coro.Context().Unmarshal(pollResult.CoroutineState()); err != nil { - return NewResponseWithErrorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err) + return NewResponseErrorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err) } } else { - return NewResponseWithErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req) + return NewResponseErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req) } // When running in volatile mode, we cannot snapshot the coroutine state @@ -79,7 +80,7 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) R return nil }) if canceled { - return NewResponseWithError(context.Cause(ctx)) + return NewResponseError(context.Cause(ctx)) } } @@ -87,12 +88,12 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) R if coro.Next() { coroutineState, err := coro.Context().Marshal() if err != nil { - return NewResponseWithErrorf("%w: cannot serialize coroutine: %v", ErrPermanent, err) + return NewResponseErrorf("%w: cannot serialize coroutine: %v", ErrPermanent, err) } switch yield := coro.Recv().(type) { // TODO default: - res = NewResponseWithErrorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield) + res = NewResponseErrorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield) } // TODO _ = coroutineState @@ -101,26 +102,27 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req Request) R case proto.Message: output, err := NewAny(ret) if err != nil { - res = NewResponseWithErrorf("%w: cannot serialize return value: %v", ErrInvalidResponse, err) + res = NewResponseErrorf("%w: cannot serialize return value: %v", ErrInvalidResponse, err) } else { - res = NewResponseWithOutput(output) + // TODO: automatically derive a status from the ret value + res = NewResponse(StatusOf(ret), Output(output)) } case error: - res = NewResponseWithError(ret) + res = NewResponseError(ret) default: - res = NewResponseWithErrorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret) + res = NewResponseErrorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret) } } return res } -func (f *GenericFunction[Input, Output]) bind(endpoint *Dispatch) { +func (f *GenericFunction[I, O]) bind(endpoint *Dispatch) { f.endpoint = endpoint } // NewCall creates a Call for the function. -func (f *GenericFunction[Input, Output]) NewCall(input Input, opts ...CallOption) (Call, error) { +func (f *GenericFunction[I, O]) NewCall(input I, opts ...CallOption) (Call, error) { if f.endpoint == nil { return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint") } @@ -128,11 +130,12 @@ func (f *GenericFunction[Input, Output]) NewCall(input Input, opts ...CallOption if err != nil { return Call{}, fmt.Errorf("cannot serialize input: %v", err) } - return NewCall(f.endpoint.URL(), f.name, anyInput, opts...), nil + opts = append(slices.Clip(opts), Input(anyInput)) + return NewCall(f.endpoint.URL(), f.name, opts...), nil } // Dispatch dispatches a call to the function. -func (f *GenericFunction[Input, Output]) Dispatch(ctx context.Context, input Input, opts ...CallOption) (ID, error) { +func (f *GenericFunction[I, O]) Dispatch(ctx context.Context, input I, opts ...CallOption) (ID, error) { call, err := f.NewCall(input, opts...) if err != nil { return "", err @@ -145,7 +148,7 @@ func (f *GenericFunction[Input, Output]) Dispatch(ctx context.Context, input Inp } //go:noinline -func (f *GenericFunction[Input, Output]) entrypoint(input Input) func() any { +func (f *GenericFunction[I, O]) entrypoint(input I) func() any { return func() any { // The context that gets passed as argument here should be recreated // each time the coroutine is resumed, ideally inheriting from the @@ -193,7 +196,8 @@ func (f *PrimitiveFunction) NewCall(input Any, opts ...CallOption) (Call, error) if f.endpoint == nil { return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint") } - return NewCall(f.endpoint.URL(), f.name, input, opts...), nil + opts = append(slices.Clip(opts), Input(input)) + return NewCall(f.endpoint.URL(), f.name, opts...), nil } // Dispatch dispatches a call to the function. diff --git a/proto.go b/proto.go index bd5b9c2..c40f596 100644 --- a/proto.go +++ b/proto.go @@ -17,11 +17,10 @@ type Call struct { } // NewCall creates a Call. -func NewCall(endpoint, function string, input Any, opts ...CallOption) Call { +func NewCall(endpoint, function string, opts ...CallOption) Call { call := Call{&sdkv1.Call{ Endpoint: endpoint, Function: function, - Input: input.proto, }} for _, opt := range opts { opt.configureCall(&call) @@ -36,6 +35,22 @@ type callOptionFunc func(*Call) func (fn callOptionFunc) configureCall(c *Call) { fn(c) } +// Input sets the output from a function call or Response. +func Input(input Any) interface { + CallOption + RequestOption +} { + return inputOption(input) +} + +type inputOption Any + +func (i inputOption) configureCall(r *Call) { r.proto.Input = i.proto } + +func (i inputOption) configureRequest(r *Request) { + r.proto.Directive = &sdkv1.RunRequest_Input{Input: i.proto} +} + // Expiration sets a function call expiration. func Expiration(expiration time.Duration) CallOption { return callOptionFunc(func(c *Call) { c.proto.Expiration = durationpb.New(expiration) }) @@ -119,14 +134,18 @@ func NewCallResult(opts ...CallResultOption) CallResult { // CallResultOption configures a CallResult. type CallResultOption interface{ configureCallResult(*CallResult) } -type callResultOptionFunc func(*CallResult) +// Output sets the output from a function call or Response. +func Output(output Any) interface { + CallResultOption + ResponseOption +} { + return outputOption(output) +} -func (fn callResultOptionFunc) configureCallResult(r *CallResult) { fn(r) } +type outputOption Any -// Output sets the output from the function call. -func Output(output Any) CallResultOption { - return callResultOptionFunc(func(result *CallResult) { result.proto.Output = output.proto }) -} +func (o outputOption) configureCallResult(r *CallResult) { r.proto.Output = o.proto } +func (o outputOption) configureResponse(r *Response) { ensureResponseExitResult(r).Output = o.proto } // DispatchID sets the opaque identifier for the function call. func DispatchID(id ID) interface { @@ -188,8 +207,14 @@ type Error struct { proto *sdkv1.Error } -// NewError creates an Error. -func NewError(typ, message string, opts ...ErrorOption) Error { +// NewError creates an Error from a Go error. +func NewError(err error) Error { + // TODO: use ErrorValue / Traceback + return NewErrorMessage(errorTypeOf(err), err.Error()) +} + +// NewErrorMessage creates an Error. +func NewErrorMessage(typ, message string, opts ...ErrorOption) Error { err := Error{&sdkv1.Error{ Type: typ, Message: message, @@ -200,12 +225,6 @@ func NewError(typ, message string, opts ...ErrorOption) Error { return err } -// FromError creates an Error from a Go error. -func FromError(err error) Error { - // TODO: use ErrorValue / Traceback - return NewError(errorTypeOf(err), err.Error()) -} - // ErrorOption configures an Error. type ErrorOption func(*Error) @@ -268,6 +287,14 @@ func (e Error) configurePollResult(p *PollResult) { p.proto.Error = e.proto } +func (e Error) configureExit(x *Exit) { + x.proto.Result = &sdkv1.CallResult{Error: e.proto} +} + +func (e Error) configureResponse(r *Response) { + ensureResponseExitResult(r).Error = e.proto +} + // Exit is a directive that terminates a function call. type Exit struct { proto *sdkv1.Exit @@ -338,6 +365,10 @@ func (e Exit) Equal(other Exit) bool { return proto.Equal(e.proto, other.proto) } +func (e Exit) configureResponse(r *Response) { + r.proto.Directive = &sdkv1.RunResponse_Exit{Exit: e.proto} +} + // Poll is a general purpose directive used to spawn // function calls and wait for their results, and/or // to implement sleep/timer functionality. @@ -438,6 +469,10 @@ func (p Poll) Equal(other Poll) bool { return proto.Equal(p.proto, other.proto) } +func (p Poll) configureResponse(r *Response) { + r.proto.Directive = &sdkv1.RunResponse_Poll{Poll: p.proto} +} + // PollResult is the result of a poll operation. type PollResult struct { proto *sdkv1.PollResult @@ -516,34 +551,16 @@ type Request struct { } // NewRequest creates a Request. -func NewRequest(function string, directive RequestDirective, opts ...RequestOption) Request { +func NewRequest(function string, opts ...RequestOption) Request { request := Request{&sdkv1.RunRequest{ Function: function, }} for _, opt := range opts { opt.configureRequest(&request) } - switch d := directive.(type) { - case Input: - request.proto.Directive = &sdkv1.RunRequest_Input{Input: Any(d).proto} - case PollResult: - request.proto.Directive = &sdkv1.RunRequest_PollResult{PollResult: d.proto} - default: - panic("invalid request directive") - } return request } -// RequestDirective is a request directive, either Input or PollResult. -type RequestDirective interface{ requestDirective() } - -func (Input) requestDirective() {} -func (PollResult) requestDirective() {} - -// Input is a directive to start execution of a function -// with an input value. -type Input Any - // RequestOption configures a Request. type RequestOption interface{ configureRequest(*Request) } @@ -576,18 +593,6 @@ func (r Request) Function() string { return r.proto.GetFunction() } -// RequestDirective is the RequestDirective, either Input or PollResult. -func (r Request) Directive() RequestDirective { - switch d := r.proto.GetDirective().(type) { - case *sdkv1.RunRequest_Input: - return Input(Any{d.Input}) - case *sdkv1.RunRequest_PollResult: - return PollResult{d.PollResult} - default: - return nil - } -} - // Input is input to the function, along with a boolean // flag that indicates whether the request carries a directive // to start the function with the input. @@ -665,73 +670,39 @@ type Response struct { proto *sdkv1.RunResponse } +// ResponseOption configures a Response. +type ResponseOption interface{ configureResponse(*Response) } + // NewResponse creates a Response. -func NewResponse(status Status, directive ResponseDirective) Response { +func NewResponse(status Status, opts ...ResponseOption) Response { response := Response{&sdkv1.RunResponse{ Status: sdkv1.Status(status), }} - switch d := directive.(type) { - case Exit: - response.proto.Directive = &sdkv1.RunResponse_Exit{Exit: d.proto} - case Poll: - response.proto.Directive = &sdkv1.RunResponse_Poll{Poll: d.proto} - default: - response.proto.Directive = &sdkv1.RunResponse_Exit{Exit: &sdkv1.Exit{Result: &sdkv1.CallResult{}}} + for _, opt := range opts { + opt.configureResponse(&response) } - return response -} - -// NewResponseWithOutput creates a Response from the specified output value. -func NewResponseWithOutput(output Any) Response { - result := NewCallResult(Output(output)) - - // FIXME: the interface{ Status() Status } implementation - // is lost earlier when an any is converted to Any. Do - // the conversion here, so that the original object (and status) - // is available. - status := StatusOf(output) - if status == UnspecifiedStatus { - status = OKStatus + if response.proto.Directive == nil { + ensureResponseExitResult(&response) } - - return NewResponse(status, NewExit(result)) + return response } -// NewResponseWithError creates a Response from the specified error. -func NewResponseWithError(err error) Response { - result := NewCallResult(FromError(err)) - return NewResponse(ErrorStatus(err), NewExit(result)) +// NewResponseError creates a Response from the specified error. +func NewResponseError(err error) Response { + return NewResponse(ErrorStatus(err), NewError(err)) } -// NewResponseWithErrorf creates a Response from the specified error message +// NewResponseErrorf creates a Response from the specified error message // and args. -func NewResponseWithErrorf(msg string, args ...any) Response { - return NewResponseWithError(fmt.Errorf(msg, args...)) +func NewResponseErrorf(msg string, args ...any) Response { + return NewResponseError(fmt.Errorf(msg, args...)) } -// ResponseDirective is either Exit or Poll. -type ResponseDirective interface{ responseDirective() } - -func (Poll) responseDirective() {} -func (Exit) responseDirective() {} - // Status is the response status. func (r Response) Status() Status { return Status(r.proto.GetStatus()) } -// Directive is the response directive, either Exit or Poll. -func (r Response) Directive() ResponseDirective { - switch d := r.proto.GetDirective().(type) { - case *sdkv1.RunResponse_Exit: - return Exit{d.Exit} - case *sdkv1.RunResponse_Poll: - return Poll{d.Poll} - default: - return nil - } -} - // Exit is the exit directive on the response. func (r Response) Exit() (Exit, bool) { proto := r.proto.GetExit() @@ -777,6 +748,22 @@ func (r Response) Marshal() ([]byte, error) { return proto.Marshal(r.proto) } +func ensureResponseExitResult(r *Response) *sdkv1.CallResult { + var d *sdkv1.RunResponse_Exit + d, ok := r.proto.Directive.(*sdkv1.RunResponse_Exit) + if !ok { + d = &sdkv1.RunResponse_Exit{} + r.proto.Directive = d + } + if d.Exit == nil { + d.Exit = &sdkv1.Exit{} + } + if d.Exit.Result == nil { + d.Exit.Result = &sdkv1.CallResult{} + } + return d.Exit.Result +} + // These are hooks used by the dispatchlambda and dispatchtest // package that let us avoid exposing proto messages. Exposing // the underlying proto messages complicates the API and opens diff --git a/proto_test.go b/proto_test.go index f910ea7..8bab51c 100644 --- a/proto_test.go +++ b/proto_test.go @@ -13,7 +13,7 @@ import ( func TestCall(t *testing.T) { t.Run("with no options", func(t *testing.T) { - call := NewCall("endpoint1", "function2", Int(11)) + call := NewCall("endpoint1", "function2", Input(Int(11))) if got := call.Endpoint(); got != "endpoint1" { t.Errorf("unexpected call endpoint: %v", got) @@ -46,8 +46,11 @@ func TestCall(t *testing.T) { }) t.Run("with options", func(t *testing.T) { - call := NewCall("endpoint1", "function2", Int(11), - CorrelationID(1234), Expiration(10*time.Second), Version("xyzzy")) + call := NewCall("endpoint1", "function2", + Input(Int(11)), + CorrelationID(1234), + Expiration(10*time.Second), + Version("xyzzy")) if got := call.Endpoint(); got != "endpoint1" { t.Errorf("unexpected call endpoint: %v", got) diff --git a/status.go b/status.go index a0a084a..0c1a3ac 100644 --- a/status.go +++ b/status.go @@ -68,14 +68,15 @@ func (s Status) GoString() string { // The object can provide a status by implementing // interface{ Status() Status }. func StatusOf(v any) Status { - if s, ok := v.(status); ok { - return s.Status() - } if e, ok := v.(error); ok { var s status if errors.As(e, &s) { return s.Status() } + return ErrorStatus(e) + } + if s, ok := v.(status); ok { + return s.Status() } - return UnspecifiedStatus + return OKStatus }