Skip to content

Commit

Permalink
Wrap RunResponse to provide a higher level interface
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 5, 2024
1 parent e08993c commit 288c7e1
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 119 deletions.
6 changes: 3 additions & 3 deletions dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,14 @@ type dispatchFunctionServiceHandler struct {
}

func (d *dispatchFunctionServiceHandler) Run(ctx context.Context, req *connect.Request[sdkv1.RunRequest]) (*connect.Response[sdkv1.RunResponse], error) {
var res *sdkv1.RunResponse
var res Response
fn := d.dispatch.lookupFunction(req.Msg.Function)
if fn == nil {
res = ErrorResponse(fmt.Errorf("%w: function %q not found", ErrNotFound, req.Msg.Function))
res = NewErrorfResponse("%w: function %q not found", ErrNotFound, req.Msg.Function)
} else {
res = fn.Run(ctx, req.Msg)
}
return connect.NewResponse(res), nil
return connect.NewResponse(res.proto), nil
}

// ListenAndServe serves the Dispatch endpoint on the specified address.
Expand Down
46 changes: 18 additions & 28 deletions dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dispatch_test

import (
"context"
"fmt"
"testing"
"time"

Expand All @@ -28,23 +27,16 @@ func TestDispatchEndpoint(t *testing.T) {
t.Fatal(err)
}

endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
var input *anypb.Any
endpoint.Register(dispatch.NewPrimitiveFunction("identity", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response {
switch d := req.Directive.(type) {
case *sdkv1.RunRequest_Input:
input = d.Input
input, err := d.Input.UnmarshalNew()
if err != nil {
return dispatch.NewErrorResponse(err)
}
return dispatch.NewOutputResponse(input)
default:
return dispatch.ErrorResponse(fmt.Errorf("%w: unexpected run directive: %T", dispatch.ErrInvalidArgument, d))
}
return &sdkv1.RunResponse{
Status: sdkv1.Status_STATUS_OK,
Directive: &sdkv1.RunResponse_Exit{
Exit: &sdkv1.Exit{
Result: &sdkv1.CallResult{
Output: input,
},
},
},
return dispatch.NewErrorfResponse("%w: unexpected run directive: %T", dispatch.ErrInvalidArgument, d)
}
}))

Expand All @@ -62,26 +54,24 @@ func TestDispatchEndpoint(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if res.Status != sdkv1.Status_STATUS_OK {
t.Fatalf("unexpected response status: %v", res.Status)
if res.Status() != dispatch.OKStatus {
t.Fatalf("unexpected response status: %v", res.Status())
}
output, err := res.Output()
if err != nil {
t.Fatalf("invalid response: %v (%v)", res, err)
}
if d, ok := res.Directive.(*sdkv1.RunResponse_Exit); !ok {
t.Errorf("unexpected response directive: %T", res.Directive)
} else if output := d.Exit.GetResult().GetOutput(); output == nil {
t.Error("exit directive result or output was nil")
} else if message, err := output.UnmarshalNew(); err != nil {
if v, ok := output.(*wrapperspb.Int32Value); !ok || v.Value != inputValue {
t.Errorf("exit directive result or output was invalid: %v", output)
} else if v, ok := message.(*wrapperspb.Int32Value); !ok || v.Value != inputValue {
t.Errorf("exit directive result or output was invalid: %v", v)
}

// Try running a function that has not been registered.
res, err = client.Run(context.Background(), &sdkv1.RunRequest{Function: "not_found"})
if err != nil {
t.Fatal(err)
}
if res.Status != sdkv1.Status_STATUS_NOT_FOUND {
t.Fatalf("unexpected response status: %v", res.Status)
if res.Status() != dispatch.NotFoundStatus {
t.Fatalf("unexpected response status: %v", res.Status())
}

// Try with a client that does not sign requests. The Dispatch
Expand Down Expand Up @@ -110,7 +100,7 @@ func TestDispatchCall(t *testing.T) {
t.Fatal(err)
}

fn := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
fn := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response {
panic("not implemented")
})
endpoint.Register(fn)
Expand Down Expand Up @@ -181,7 +171,7 @@ func TestDispatchCallsBatch(t *testing.T) {
t.Fatal(err)
}

fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) dispatch.Response {
panic("not implemented")
})
fn2 := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
Expand Down
2 changes: 1 addition & 1 deletion dispatchlambda/lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) {

res := h.function.Run(ctx, req)

rawResponse, err := proto.Marshal(res)
rawResponse, err := res.Marshal()
if err != nil {
return nil, err
}
Expand Down
19 changes: 16 additions & 3 deletions dispatchtest/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"connectrpc.com/validate"
"github.com/dispatchrun/dispatch-go"
"github.com/dispatchrun/dispatch-go/internal/auth"
"google.golang.org/protobuf/proto"
)

// NewEndpoint creates a Dispatch endpoint, like dispatch.New.
Expand Down Expand Up @@ -111,10 +112,22 @@ func WithSigningKey(signingKey string) EndpointClientOption {
}

// Run sends a RunRequest and returns a RunResponse.
func (c *EndpointClient) Run(ctx context.Context, req *sdkv1.RunRequest) (*sdkv1.RunResponse, error) {
func (c *EndpointClient) Run(ctx context.Context, req *sdkv1.RunRequest) (dispatch.Response, error) {
res, err := c.client.Run(ctx, connect.NewRequest(req))
if err != nil {
return nil, err
return dispatch.Response{}, err
}
return wrapResponse(res.Msg), nil
}

func wrapResponse(r *sdkv1.RunResponse) dispatch.Response {
b, err := proto.Marshal(r)
if err != nil {
panic(err)
}
response, err := dispatch.UnmarshalResponse(b)
if err != nil {
panic(err)
}
return res.Msg, nil
return response
}
19 changes: 0 additions & 19 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"errors"
"reflect"
"strings"

sdkv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/v1"
)

var (
Expand Down Expand Up @@ -131,20 +129,3 @@ type temporary interface {
type timeout interface {
Timeout() bool
}

// ErrorResponse creates a RunResponse for the specified error.
func ErrorResponse(err error) *sdkv1.RunResponse {
return &sdkv1.RunResponse{
Status: errorStatusOf(err).proto(),
Directive: &sdkv1.RunResponse_Exit{
Exit: &sdkv1.Exit{
Result: &sdkv1.CallResult{
Error: &sdkv1.Error{
Type: errorTypeOf(err),
Message: err.Error(),
},
},
},
},
}
}
45 changes: 16 additions & 29 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/stealthrocket/coroutine"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)

// Function is a Dispatch function.
Expand All @@ -19,7 +18,7 @@ type Function interface {
Name() string

// Run runs the function.
Run(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse
Run(context.Context, *sdkv1.RunRequest) Response

// bind is an internal hook for binding a function to
// a Dispatch endpoint, allowing the NewCall and Dispatch
Expand Down Expand Up @@ -47,15 +46,15 @@ func (f *GenericFunction[Input, Output]) Name() string {
}

// Run runs the function.
func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.RunRequest) Response {
var coro coroutine.Coroutine[any, any]
var zero Input

switch c := req.Directive.(type) {
case *sdkv1.RunRequest_PollResult:
coro = coroutine.NewWithReturn[any, any](f.entrypoint(zero))
if err := coro.Context().Unmarshal(c.PollResult.GetCoroutineState()); err != nil {
return ErrorResponse(fmt.Errorf("%w: invalid coroutine state: %v", ErrIncompatibleState, err))
return NewErrorfResponse("%w: invalid coroutine state: %v", ErrIncompatibleState, err)
}
case *sdkv1.RunRequest_Input:
var input Input
Expand All @@ -66,19 +65,17 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.Run
RecursionLimit: protowire.DefaultRecursionLimit,
}
if err := options.Unmarshal(c.Input.Value, message.Interface()); err != nil {
return ErrorResponse(fmt.Errorf("%w: invalid function input: %v", ErrInvalidArgument, err))
return NewErrorfResponse("%w: invalid function input: %v", ErrInvalidArgument, err)
}
input = message.Interface().(Input)
}
coro = coroutine.NewWithReturn[any, any](f.entrypoint(input))

default:
return ErrorResponse(fmt.Errorf("%w: unsupported coroutine directive: %T", ErrInvalidArgument, c))
return NewErrorfResponse("%w: unsupported coroutine directive: %T", ErrInvalidArgument, c)
}

res := &sdkv1.RunResponse{
Status: sdkv1.Status_STATUS_OK,
}
var res Response

// When running in volatile mode, we cannot snapshot the coroutine state
// and return it to the caller. Instead, we run the coroutine to completion
Expand All @@ -90,40 +87,30 @@ func (f *GenericFunction[Input, Output]) Run(ctx context.Context, req *sdkv1.Run
return nil
})
if canceled {
return ErrorResponse(context.Cause(ctx))
return NewErrorResponse(context.Cause(ctx))
}
}

if coro.Next() {
coroutineState, err := coro.Context().Marshal()
if err != nil {
return ErrorResponse(fmt.Errorf("%w: cannot serialize coroutine: %v", ErrPermanent, err))
return NewErrorfResponse("%w: cannot serialize coroutine: %v", ErrPermanent, err)
}
switch yield := coro.Recv().(type) {
// TODO
default:
res = ErrorResponse(fmt.Errorf("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield))
res = NewErrorfResponse("%w: unsupported coroutine yield: %T", ErrInvalidResponse, yield)
}
// TODO
_ = coroutineState
} else {
switch ret := coro.Result().(type) {
case proto.Message:
output, _ := anypb.New(ret)
if status := statusOf(ret); status != UnspecifiedStatus {
res.Status = status.proto()
}
res.Directive = &sdkv1.RunResponse_Exit{
Exit: &sdkv1.Exit{
Result: &sdkv1.CallResult{
Output: output,
},
},
}
res = NewOutputResponse(ret)
case error:
res = ErrorResponse(ret)
res = NewErrorResponse(ret)
default:
res = ErrorResponse(fmt.Errorf("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret))
res = NewErrorfResponse("%w: unsupported coroutine return: %T", ErrInvalidResponse, ret)
}
}

Expand Down Expand Up @@ -172,15 +159,15 @@ func (f *GenericFunction[Input, Output]) entrypoint(input Input) func() any {
}

// NewPrimitiveFunction creates a PrimitiveFunction.
func NewPrimitiveFunction(name string, fn func(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse) *PrimitiveFunction {
func NewPrimitiveFunction(name string, fn func(context.Context, *sdkv1.RunRequest) Response) *PrimitiveFunction {
return &PrimitiveFunction{name: name, fn: fn}
}

// PrimitiveFunction is a function that's close to the underlying
// Dispatch protocol, accepting a RunRequest and returning a RunResponse.
// Dispatch protocol, accepting a Request and returning a Response.
type PrimitiveFunction struct {
name string
fn func(context.Context, *sdkv1.RunRequest) *sdkv1.RunResponse
fn func(context.Context, *sdkv1.RunRequest) Response

endpoint *Dispatch
}
Expand All @@ -191,7 +178,7 @@ func (f *PrimitiveFunction) Name() string {
}

// Run runs the function.
func (f *PrimitiveFunction) Run(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
func (f *PrimitiveFunction) Run(ctx context.Context, req *sdkv1.RunRequest) Response {
return f.fn(ctx, req)
}

Expand Down
Loading

0 comments on commit 288c7e1

Please sign in to comment.