From 288c7e134d9fe2cf783eef71dd95929154f80f18 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 5 Jun 2024 13:35:15 +1000 Subject: [PATCH] Wrap RunResponse to provide a higher level interface --- dispatch.go | 6 +-- dispatch_test.go | 46 +++++++---------- dispatchlambda/lambda.go | 2 +- dispatchtest/endpoint.go | 19 +++++-- error.go | 19 ------- function.go | 45 ++++++----------- function_test.go | 55 ++++++++++----------- proto.go | 104 ++++++++++++++++++++++++++++++++++++--- 8 files changed, 177 insertions(+), 119 deletions(-) diff --git a/dispatch.go b/dispatch.go index c1c2bb6..fa70098 100644 --- a/dispatch.go +++ b/dispatch.go @@ -214,14 +214,14 @@ type dispatchFunctionServiceHandler struct { } func (d *dispatchFunctionServiceHandler) Run(ctx context.Context, req *connect.Request[sdkv1.RunRequest]) (*connect.Response[sdkv1.RunResponse], error) { - var res *sdkv1.RunResponse + var res Response fn := d.dispatch.lookupFunction(req.Msg.Function) if fn == nil { - res = ErrorResponse(fmt.Errorf("%w: function %q not found", ErrNotFound, req.Msg.Function)) + res = NewErrorfResponse("%w: function %q not found", ErrNotFound, req.Msg.Function) } else { res = fn.Run(ctx, req.Msg) } - return connect.NewResponse(res), nil + return connect.NewResponse(res.proto), nil } // ListenAndServe serves the Dispatch endpoint on the specified address. diff --git a/dispatch_test.go b/dispatch_test.go index 50500f5..db35dd7 100644 --- a/dispatch_test.go +++ b/dispatch_test.go @@ -2,7 +2,6 @@ package dispatch_test import ( "context" - "fmt" "testing" "time" @@ -28,23 +27,16 @@ func TestDispatchEndpoint(t *testing.T) { t.Fatal(err) } - endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { - var input *anypb.Any + endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response { switch d := req.Directive.(type) { case *sdkv1.RunRequest_Input: - input = d.Input + input, err := d.Input.UnmarshalNew() + if err != nil { + return dispatch.NewErrorResponse(err) + } + return dispatch.NewOutputResponse(input) default: - return dispatch.ErrorResponse(fmt.Errorf("%w: unexpected run directive: %T", dispatch.ErrInvalidArgument, d)) - } - return &sdkv1.RunResponse{ - Status: sdkv1.Status_STATUS_OK, - Directive: &sdkv1.RunResponse_Exit{ - Exit: &sdkv1.Exit{ - Result: &sdkv1.CallResult{ - Output: input, - }, - }, - }, + return dispatch.NewErrorfResponse("%w: unexpected run directive: %T", dispatch.ErrInvalidArgument, d) } })) @@ -62,17 +54,15 @@ func TestDispatchEndpoint(t *testing.T) { if err != nil { t.Fatal(err) } - if res.Status != sdkv1.Status_STATUS_OK { - t.Fatalf("unexpected response status: %v", res.Status) + if res.Status() != dispatch.OKStatus { + t.Fatalf("unexpected response status: %v", res.Status()) + } + output, err := res.Output() + if err != nil { + t.Fatalf("invalid response: %v (%v)", res, err) } - if d, ok := res.Directive.(*sdkv1.RunResponse_Exit); !ok { - t.Errorf("unexpected response directive: %T", res.Directive) - } else if output := d.Exit.GetResult().GetOutput(); output == nil { - t.Error("exit directive result or output was nil") - } else if message, err := output.UnmarshalNew(); err != nil { + if v, ok := output.(*wrapperspb.Int32Value); !ok || v.Value != inputValue { t.Errorf("exit directive result or output was invalid: %v", output) - } else if v, ok := message.(*wrapperspb.Int32Value); !ok || v.Value != inputValue { - t.Errorf("exit directive result or output was invalid: %v", v) } // Try running a function that has not been registered. @@ -80,8 +70,8 @@ func TestDispatchEndpoint(t *testing.T) { if err != nil { t.Fatal(err) } - if res.Status != sdkv1.Status_STATUS_NOT_FOUND { - t.Fatalf("unexpected response status: %v", res.Status) + if res.Status() != dispatch.NotFoundStatus { + t.Fatalf("unexpected response status: %v", res.Status()) } // Try with a client that does not sign requests. The Dispatch @@ -110,7 +100,7 @@ func TestDispatchCall(t *testing.T) { t.Fatal(err) } - fn := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { + fn := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response { panic("not implemented") }) endpoint.Register(fn) @@ -181,7 +171,7 @@ func TestDispatchCallsBatch(t *testing.T) { t.Fatal(err) } - fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { + fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response { panic("not implemented") }) fn2 := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { diff --git a/dispatchlambda/lambda.go b/dispatchlambda/lambda.go index 2d30df0..0441a35 100644 --- a/dispatchlambda/lambda.go +++ b/dispatchlambda/lambda.go @@ -52,7 +52,7 @@ func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) { res := h.function.Run(ctx, req) - rawResponse, err := proto.Marshal(res) + rawResponse, err := res.Marshal() if err != nil { return nil, err } diff --git a/dispatchtest/endpoint.go b/dispatchtest/endpoint.go index 9ebd84c..9a16988 100644 --- a/dispatchtest/endpoint.go +++ b/dispatchtest/endpoint.go @@ -14,6 +14,7 @@ import ( "connectrpc.com/validate" "github.com/dispatchrun/dispatch-go" "github.com/dispatchrun/dispatch-go/internal/auth" + "google.golang.org/protobuf/proto" ) // NewEndpoint creates a Dispatch endpoint, like dispatch.New. @@ -111,10 +112,22 @@ func WithSigningKey(signingKey string) EndpointClientOption { } // Run sends a RunRequest and returns a RunResponse. -func (c *EndpointClient) Run(ctx context.Context, req *sdkv1.RunRequest) (*sdkv1.RunResponse, error) { +func (c *EndpointClient) Run(ctx context.Context, req *sdkv1.RunRequest) (dispatch.Response, error) { res, err := c.client.Run(ctx, connect.NewRequest(req)) if err != nil { - return nil, err + return dispatch.Response{}, err + } + return wrapResponse(res.Msg), nil +} + +func wrapResponse(r *sdkv1.RunResponse) dispatch.Response { + b, err := proto.Marshal(r) + if err != nil { + panic(err) + } + response, err := dispatch.UnmarshalResponse(b) + if err != nil { + panic(err) } - return res.Msg, nil + return response } diff --git a/error.go b/error.go index 5d0e95c..357a8e0 100644 --- a/error.go +++ b/error.go @@ -4,8 +4,6 @@ import ( "errors" "reflect" "strings" - - sdkv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/v1" ) var ( @@ -131,20 +129,3 @@ type temporary interface { type timeout interface { Timeout() bool } - -// ErrorResponse creates a RunResponse for the specified error. -func ErrorResponse(err error) *sdkv1.RunResponse { - return &sdkv1.RunResponse{ - Status: errorStatusOf(err).proto(), - Directive: &sdkv1.RunResponse_Exit{ - Exit: &sdkv1.Exit{ - Result: &sdkv1.CallResult{ - Error: &sdkv1.Error{ - Type: errorTypeOf(err), - Message: err.Error(), - }, - }, - }, - }, - } -} diff --git a/function.go b/function.go index 7b72311..0d69ff2 100644 --- a/function.go +++ b/function.go @@ -10,7 +10,6 @@ import ( "github.com/stealthrocket/coroutine" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" ) // Function is a Dispatch function. @@ -19,7 +18,7 @@ type Function interface { Name() string // Run runs the function. - Run(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse + Run(context.Context, *sdkv1.RunRequest) Response // bind is an internal hook for binding a function to // a Dispatch endpoint, allowing the NewCall and Dispatch @@ -47,7 +46,7 @@ func (f *GenericFunction[Input, Output]) Name() string { } // Run runs the function. -func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { +func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.RunRequest) Response { var coro coroutine.Coroutine[any, any] var zero Input @@ -55,7 +54,7 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.Run case *sdkv1.RunRequest_PollResult: coro = coroutine.NewWithReturn[any, any](f.entrypoint(zero)) if err := coro.Context().Unmarshal(c.PollResult.GetCoroutineState()); err != nil { - return ErrorResponse(fmt.Errorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err)) + return NewErrorfResponse("%w: invalid coroutine state: %v", ErrIncompatibleState, err) } case *sdkv1.RunRequest_Input: var input Input @@ -66,19 +65,17 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.Run RecursionLimit: protowire.DefaultRecursionLimit, } if err := options.Unmarshal(c.Input.Value, message.Interface()); err != nil { - return ErrorResponse(fmt.Errorf("%w: invalid function input: %v", ErrInvalidArgument, err)) + return NewErrorfResponse("%w: invalid function input: %v", ErrInvalidArgument, err) } input = message.Interface().(Input) } coro = coroutine.NewWithReturn[any, any](f.entrypoint(input)) default: - return ErrorResponse(fmt.Errorf("%w: unsupported coroutine directive: %T", ErrInvalidArgument, c)) + return NewErrorfResponse("%w: unsupported coroutine directive: %T", ErrInvalidArgument, c) } - res := &sdkv1.RunResponse{ - Status: sdkv1.Status_STATUS_OK, - } + var res Response // When running in volatile mode, we cannot snapshot the coroutine state // and return it to the caller. Instead, we run the coroutine to completion @@ -90,40 +87,30 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.Run return nil }) if canceled { - return ErrorResponse(context.Cause(ctx)) + return NewErrorResponse(context.Cause(ctx)) } } if coro.Next() { coroutineState, err := coro.Context().Marshal() if err != nil { - return ErrorResponse(fmt.Errorf("%w: cannot serialize coroutine: %v", ErrPermanent, err)) + return NewErrorfResponse("%w: cannot serialize coroutine: %v", ErrPermanent, err) } switch yield := coro.Recv().(type) { // TODO default: - res = ErrorResponse(fmt.Errorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield)) + res = NewErrorfResponse("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield) } // TODO _ = coroutineState } else { switch ret := coro.Result().(type) { case proto.Message: - output, _ := anypb.New(ret) - if status := statusOf(ret); status != UnspecifiedStatus { - res.Status = status.proto() - } - res.Directive = &sdkv1.RunResponse_Exit{ - Exit: &sdkv1.Exit{ - Result: &sdkv1.CallResult{ - Output: output, - }, - }, - } + res = NewOutputResponse(ret) case error: - res = ErrorResponse(ret) + res = NewErrorResponse(ret) default: - res = ErrorResponse(fmt.Errorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret)) + res = NewErrorfResponse("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret) } } @@ -172,15 +159,15 @@ func (f *GenericFunction[Input, Output]) entrypoint(input Input) func() any { } // NewPrimitiveFunction creates a PrimitiveFunction. -func NewPrimitiveFunction(name string, fn func(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse) *PrimitiveFunction { +func NewPrimitiveFunction(name string, fn func(context.Context, *sdkv1.RunRequest) Response) *PrimitiveFunction { return &PrimitiveFunction{name: name, fn: fn} } // PrimitiveFunction is a function that's close to the underlying -// Dispatch protocol, accepting a RunRequest and returning a RunResponse. +// Dispatch protocol, accepting a Request and returning a Response. type PrimitiveFunction struct { name string - fn func(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse + fn func(context.Context, *sdkv1.RunRequest) Response endpoint *Dispatch } @@ -191,7 +178,7 @@ func (f *PrimitiveFunction) Name() string { } // Run runs the function. -func (f *PrimitiveFunction) Run(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { +func (f *PrimitiveFunction) Run(ctx context.Context, req *sdkv1.RunRequest) Response { return f.fn(ctx, req) } diff --git a/function_test.go b/function_test.go index db289b6..314c420 100644 --- a/function_test.go +++ b/function_test.go @@ -17,8 +17,12 @@ func TestFunctionRunInvalidCoroutineType(t *testing.T) { }) res := fn.Run(context.Background(), &sdkv1.RunRequest{}) - if err := res.GetExit().GetResult().GetError(); err == nil || err.Message != "InvalidArgument: unsupported coroutine directive: " { - t.Fatalf("unexpected error: %#v", err) + error, ok := res.Error() + if !ok { + t.Fatalf("invalid response: %v", res) + } + if error.Message() != "InvalidArgument: unsupported coroutine directive: " { + t.Errorf("unexpected error: %v", error) } } @@ -40,17 +44,15 @@ func TestFunctionRunError(t *testing.T) { }, }) - switch coro := res.Directive.(type) { - case *sdkv1.RunResponse_Exit: - err := coro.Exit.GetResult().GetError() - if err.Type != "errorString" { - t.Fatalf("unexpected coroutine error type: %s", err.Type) - } - if err.Message != "oops" { - t.Fatalf("unexpected coroutine error message: %s", err.Message) - } - default: - t.Fatalf("unexpected coroutine response type: %T", coro) + error, ok := res.Error() + if !ok { + t.Fatalf("invalid response: %v", res) + } + if error.Type() != "errorString" { + t.Errorf("unexpected coroutine error type: %s", error.Type()) + } + if error.Message() != "oops" { + t.Errorf("unexpected coroutine error message: %s", error.Message()) } } @@ -70,26 +72,19 @@ func TestFunctionRunResult(t *testing.T) { }, }) - switch coro := res.Directive.(type) { - case *sdkv1.RunResponse_Exit: - out := coro.Exit.GetResult().GetOutput() - if out.TypeUrl != "type.googleapis.com/google.protobuf.StringValue" { - t.Fatalf("unexpected coroutine output type: %s", out.TypeUrl) - } - var output wrapperspb.StringValue - if err := out.UnmarshalTo(&output); err != nil { - t.Fatal(err) - } - if output.Value != "world" { - t.Fatalf("unexpected coroutine output value: %s", output.Value) - } - default: - t.Fatalf("unexpected coroutine response type: %T", coro) + output, err := res.Output() + if err != nil { + t.Fatalf("invalid response: %v (%v)", res, err) + } + if str, ok := output.(*wrapperspb.StringValue); !ok { + t.Fatalf("unexpected output: %T (%v)", output, output) + } else if str.Value != "world" { + t.Errorf("unexpected output: %s", str.Value) } } func TestPrimitiveFunctionNewCallAndDispatchWithoutEndpoint(t *testing.T) { - fn := dispatch.NewPrimitiveFunction("foo", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { + fn := dispatch.NewPrimitiveFunction("foo", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response { panic("not implemented") }) @@ -131,7 +126,7 @@ func TestPrimitiveFunctionDispatchWithoutClient(t *testing.T) { t.Fatal(err) } - fn := dispatch.NewPrimitiveFunction("foo", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse { + fn := dispatch.NewPrimitiveFunction("foo", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response { panic("not implemented") }) endpoint.Register(fn) diff --git a/proto.go b/proto.go index 0263215..0be1a18 100644 --- a/proto.go +++ b/proto.go @@ -148,6 +148,15 @@ func NewCallResult(opts ...CallResultOption) (CallResult, error) { return result, nil } +// NewErrorCallResult constructs a CallResult from a Go error. +func NewErrorCallResult(err error) CallResult { + result, err := NewCallResult(WithCallResultError(NewGoError(err))) + if err != nil { + panic(err) // unreachable; no output was specified + } + return result +} + // CallResultOption configures a CallResult. type CallResultOption func(*CallResult) @@ -248,7 +257,7 @@ type Error struct { proto *sdkv1.Error } -// NewError creates an error. +// NewError creates an Error. func NewError(typ, message string, opts ...ErrorOption) Error { err := Error{&sdkv1.Error{ Type: typ, @@ -260,6 +269,11 @@ func NewError(typ, message string, opts ...ErrorOption) Error { return err } +// NewGoError creates an Error from a Go error. +func NewGoError(err error) Error { + return NewError(errorTypeOf(err), err.Error()) +} + // ErrorOption configures an Error. type ErrorOption func(*Error) @@ -351,14 +365,31 @@ func WithTailCall(tailCall Call) ExitOption { // Result is the function call result the exit directive carries. func (e Exit) Result() (CallResult, bool) { - r := e.proto.GetResult() - return CallResult{proto: r}, r != nil + proto := e.proto.GetResult() + return CallResult{proto: proto}, proto != nil +} + +// Error is the error from the function call result the +// exit directive carries. +func (e Exit) Error() (Error, bool) { + proto := e.proto.GetResult().GetError() + return Error{proto: proto}, proto != nil +} + +// Output is the output from the function call result the +// exit directive carries. +func (e Exit) Output() (proto.Message, error) { + output := e.proto.GetResult().GetOutput() + if output == nil { + return nil, nil + } + return output.UnmarshalNew() } // TailCall is the tail call the exit directive carries. func (e Exit) TailCall() (Call, bool) { - c := e.proto.GetTailCall() - return Call{proto: c}, c != nil + proto := e.proto.GetTailCall() + return Call{proto: proto}, proto != nil } // String is the string representation of the Exit directive. @@ -502,11 +533,40 @@ func NewResponse(status Status, directive ResponseDirective) Response { case Poll: response.proto.Directive = &sdkv1.RunResponse_Poll{Poll: d.proto} default: - panic("nil directive") + response.proto.Directive = &sdkv1.RunResponse_Exit{Exit: &sdkv1.Exit{Result: &sdkv1.CallResult{}}} } return response } +// NewOutputResponse creates a Response from the specified output value. +func NewOutputResponse(output proto.Message) Response { + result, err := NewCallResult(WithCallResultOutput(output)) + if err != nil { + return NewErrorfResponse("cannot serialize output: %w", err) + } + exit := NewExit(WithExitResult(result)) + + status := statusOf(output) + if status == UnspecifiedStatus { + status = OKStatus + } + + return NewResponse(status, exit) +} + +// NewErrorResponse creates a Response for the specified error. +func NewErrorResponse(err error) Response { + result := NewErrorCallResult(err) + exit := NewExit(WithExitResult(result)) + return NewResponse(errorStatusOf(err), exit) +} + +// NewErrorfResponse creates a Response from the specified error message +// and args. +func NewErrorfResponse(msg string, args ...any) Response { + return NewErrorResponse(fmt.Errorf(msg, args...)) +} + // ResponseDirective is either Exit or Poll. type ResponseDirective interface { responseDirective() @@ -538,6 +598,24 @@ func (r Response) Exit() (Exit, bool) { return Exit{proto}, proto != nil } +// Error is the error from an exit directive. +func (r Response) Error() (Error, bool) { + exit, ok := r.Exit() + if !ok { + return Error{}, false + } + return exit.Error() +} + +// Output is the output from an exit directive. +func (r Response) Output() (proto.Message, error) { + exit, ok := r.Exit() + if !ok { + return nil, fmt.Errorf("not an exit directive") + } + return exit.Output() +} + // Poll is the poll directive on the response. func (r Response) Poll() (Poll, bool) { proto := r.proto.GetPoll() @@ -570,3 +648,17 @@ func (r Response) Equal(other Response) bool { } return true } + +// Marshal marshals the response. +func (r Response) Marshal() ([]byte, error) { + return proto.Marshal(r.proto) +} + +// UnmarshalResponse unmarshals a response. +func UnmarshalResponse(b []byte) (Response, error) { + var r sdkv1.RunResponse + if err := proto.Unmarshal(b, &r); err != nil { + return Response{}, err + } + return Response{&r}, nil +}