Skip to content

Commit

Permalink
Avoid wrapperspb in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 7, 2024
1 parent b3916fc commit 1c234a4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 34 deletions.
9 changes: 4 additions & 5 deletions dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"connectrpc.com/connect"
"github.com/dispatchrun/dispatch-go"
"github.com/dispatchrun/dispatch-go/dispatchtest"
"google.golang.org/protobuf/types/known/wrapperspb"
)

func TestDispatchEndpoint(t *testing.T) {
Expand Down Expand Up @@ -118,12 +117,12 @@ func TestDispatchCallEnvConfig(t *testing.T) {
t.Fatal(err)
}

fn := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
fn := dispatch.NewFunction("function2", func(ctx context.Context, input string) (string, error) {
panic("not implemented")
})
endpoint.Register(fn)

_, err = fn.Dispatch(context.Background(), wrapperspb.String("foo"), dispatch.Version("xyzzy"))
_, err = fn.Dispatch(context.Background(), "foo", dispatch.Version("xyzzy"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -157,7 +156,7 @@ func TestDispatchCallsBatch(t *testing.T) {
fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req dispatch.Request) dispatch.Response {
panic("not implemented")
})
fn2 := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
fn2 := dispatch.NewFunction("function2", func(ctx context.Context, input string) (string, error) {
panic("not implemented")
})

Expand All @@ -168,7 +167,7 @@ func TestDispatchCallsBatch(t *testing.T) {
if err != nil {
t.Fatal(err)
}
call2, err := fn2.NewCall(wrapperspb.String("foo"), dispatch.Version("xyzzy"))
call2, err := fn2.NewCall("foo", dispatch.Version("xyzzy"))
if err != nil {
t.Fatal(err)
}
Expand Down
28 changes: 14 additions & 14 deletions dispatchlambda/lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,35 @@ import (
)

func TestHandlerEmptyPayload(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, nil
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return "", nil
})
h := dispatchlambda.Handler(fn)
_, err := h.Invoke(context.Background(), nil)
assertInvokeError(t, err, "Bad Request", "empty payload")
}

func TestHandlerShortPayload(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, nil
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return "", nil
})
h := dispatchlambda.Handler(fn)
_, err := h.Invoke(context.Background(), []byte(`@`))
assertInvokeError(t, err, "Bad Request", "payload is too short")
}

func TestHandlerNonBase64Payload(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, nil
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return "", nil
})
h := dispatchlambda.Handler(fn)
_, err := h.Invoke(context.Background(), []byte(`"not base64"`))
assertInvokeError(t, err, "Bad Request", "payload is not base64 encoded")
}

func TestHandlerInvokePayloadNotProtobufMessage(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, nil
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return "", nil
})
h := dispatchlambda.Handler(fn)
ctx := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{
Expand All @@ -56,8 +56,8 @@ func TestHandlerInvokePayloadNotProtobufMessage(t *testing.T) {
}

func TestHandlerInvokeError(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, errors.New("invoke error")
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return "", errors.New("invoke error")
})
h := dispatchlambda.Handler(fn)
ctx := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{
Expand Down Expand Up @@ -114,8 +114,8 @@ func TestHandlerInvokeError(t *testing.T) {
}

func TestHandlerInvokeFunction(t *testing.T) {
fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return wrapperspb.String("output"), nil
fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) {
return input + "output", nil
})
h := dispatchlambda.Handler(fn)

Expand Down Expand Up @@ -179,8 +179,8 @@ func TestHandlerInvokeFunction(t *testing.T) {
if err := out.UnmarshalTo(&output); err != nil {
t.Fatalf("unexpected error unmarshaling output: %v", err)
}
if output.Value != "output" {
t.Errorf("expected coroutine to return an output with value %q, got %q", "output", output.Value)
if output.Value != "inputoutput" {
t.Errorf("expected coroutine to return an output with value %q, got %q", "inputoutput", output.Value)
}
default:
t.Errorf("expected coroutine to return an error, got %T", coro)
Expand Down
8 changes: 4 additions & 4 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (f *GenericFunction[I, O]) Name() string {
func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response {
boxedInput, ok := req.Input()
if !ok {
return NewResponseErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req)
return NewResponseErrorf("%w: unsupported request: %v", ErrInvalidArgument, req)
}
var input I
if err := boxedInput.Unmarshal(&input); err != nil {
Expand All @@ -55,7 +55,7 @@ func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response {
}
boxedOutput, err := NewAny(output)
if err != nil {
return NewResponseErrorf("%w: cannot serialize return value %v: %v", ErrInvalidResponse, output, err)
return NewResponseErrorf("%w: invalid output %v: %v", ErrInvalidResponse, output, err)
}
return NewResponse(StatusOf(output), Output(boxedOutput))
}
Expand All @@ -69,11 +69,11 @@ func (f *GenericFunction[I, O]) NewCall(input I, opts ...CallOption) (Call, erro
if f.endpoint == nil {
return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint")
}
anyInput, err := NewAny(input)
boxedInput, err := NewAny(input)
if err != nil {
return Call{}, fmt.Errorf("cannot serialize input: %v", err)
}
opts = append(slices.Clip(opts), Input(anyInput))
opts = append(slices.Clip(opts), Input(boxedInput))
return NewCall(f.endpoint.URL(), f.name, opts...), nil
}

Expand Down
21 changes: 10 additions & 11 deletions function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"testing"

"github.com/dispatchrun/dispatch-go"
"google.golang.org/protobuf/types/known/wrapperspb"
)

func TestFunctionRunError(t *testing.T) {
fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return nil, errors.New("oops")
fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) {
return "", errors.New("oops")
})

req := dispatch.NewRequest("foo", dispatch.Input(dispatch.String("hello")))
Expand All @@ -29,8 +28,8 @@ func TestFunctionRunError(t *testing.T) {
}

func TestFunctionRunResult(t *testing.T) {
fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
return wrapperspb.String("world"), nil
fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) {
return "world", nil
})

req := dispatch.NewRequest("foo", dispatch.Input(dispatch.String("hello")))
Expand Down Expand Up @@ -66,17 +65,17 @@ func TestPrimitiveFunctionNewCallAndDispatchWithoutEndpoint(t *testing.T) {
}

func TestFunctionNewCallAndDispatchWithoutEndpoint(t *testing.T) {
fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) {
panic("not implemented")
})

wantErr := "cannot build function call: function has not been registered with a Dispatch endpoint"

_, err := fn.NewCall(wrapperspb.String("bar"))
_, err := fn.NewCall("bar")
if err == nil || err.Error() != wantErr {
t.Fatalf("unexpected error: %v", err)
}
_, err = fn.Dispatch(context.Background(), wrapperspb.String("bar"))
_, err = fn.Dispatch(context.Background(), "bar")
if err == nil || err.Error() != wantErr {
t.Fatalf("unexpected error: %v", err)
}
Expand Down Expand Up @@ -119,18 +118,18 @@ func TestFunctionDispatchWithoutClient(t *testing.T) {
t.Fatal(err)
}

fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) {
fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) {
panic("not implemented")
})
endpoint.Register(fn)

// It's possible to create a call since an endpoint URL is available.
if _, err := fn.NewCall(wrapperspb.String("bar")); err != nil {
if _, err := fn.NewCall("bar"); err != nil {
t.Fatal(err)
}

// However, a client is not available.
_, err = fn.Dispatch(context.Background(), wrapperspb.String("bar"))
_, err = fn.Dispatch(context.Background(), "bar")
if err == nil {
t.Fatal("expected an error")
} else if err.Error() != "cannot dispatch function call: Dispatch API key has not been set. Use APIKey(..), or set the DISPATCH_API_KEY environment variable" {
Expand Down

0 comments on commit 1c234a4

Please sign in to comment.