diff --git a/dispatch_test.go b/dispatch_test.go index 44338de..4c21f41 100644 --- a/dispatch_test.go +++ b/dispatch_test.go @@ -8,7 +8,6 @@ import ( "connectrpc.com/connect" "github.com/dispatchrun/dispatch-go" "github.com/dispatchrun/dispatch-go/dispatchtest" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestDispatchEndpoint(t *testing.T) { @@ -118,12 +117,12 @@ func TestDispatchCallEnvConfig(t *testing.T) { t.Fatal(err) } - fn := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { + fn := dispatch.NewFunction("function2", func(ctx context.Context, input string) (string, error) { panic("not implemented") }) endpoint.Register(fn) - _, err = fn.Dispatch(context.Background(), wrapperspb.String("foo"), dispatch.Version("xyzzy")) + _, err = fn.Dispatch(context.Background(), "foo", dispatch.Version("xyzzy")) if err != nil { t.Fatal(err) } @@ -157,7 +156,7 @@ func TestDispatchCallsBatch(t *testing.T) { fn1 := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req dispatch.Request) dispatch.Response { panic("not implemented") }) - fn2 := dispatch.NewFunction("function2", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { + fn2 := dispatch.NewFunction("function2", func(ctx context.Context, input string) (string, error) { panic("not implemented") }) @@ -168,7 +167,7 @@ func TestDispatchCallsBatch(t *testing.T) { if err != nil { t.Fatal(err) } - call2, err := fn2.NewCall(wrapperspb.String("foo"), dispatch.Version("xyzzy")) + call2, err := fn2.NewCall("foo", dispatch.Version("xyzzy")) if err != nil { t.Fatal(err) } diff --git a/dispatchlambda/lambda_test.go b/dispatchlambda/lambda_test.go index 9c77425..f6661ae 100644 --- a/dispatchlambda/lambda_test.go +++ b/dispatchlambda/lambda_test.go @@ -17,8 +17,8 @@ import ( ) func TestHandlerEmptyPayload(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, nil + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return "", nil }) h := dispatchlambda.Handler(fn) _, err := h.Invoke(context.Background(), nil) @@ -26,8 +26,8 @@ func TestHandlerEmptyPayload(t *testing.T) { } func TestHandlerShortPayload(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, nil + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return "", nil }) h := dispatchlambda.Handler(fn) _, err := h.Invoke(context.Background(), []byte(`@`)) @@ -35,8 +35,8 @@ func TestHandlerShortPayload(t *testing.T) { } func TestHandlerNonBase64Payload(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, nil + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return "", nil }) h := dispatchlambda.Handler(fn) _, err := h.Invoke(context.Background(), []byte(`"not base64"`)) @@ -44,8 +44,8 @@ func TestHandlerNonBase64Payload(t *testing.T) { } func TestHandlerInvokePayloadNotProtobufMessage(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, nil + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return "", nil }) h := dispatchlambda.Handler(fn) ctx := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{ @@ -56,8 +56,8 @@ func TestHandlerInvokePayloadNotProtobufMessage(t *testing.T) { } func TestHandlerInvokeError(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, errors.New("invoke error") + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return "", errors.New("invoke error") }) h := dispatchlambda.Handler(fn) ctx := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{ @@ -114,8 +114,8 @@ func TestHandlerInvokeError(t *testing.T) { } func TestHandlerInvokeFunction(t *testing.T) { - fn := dispatch.NewFunction("handler", func(ctx context.Context, input *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return wrapperspb.String("output"), nil + fn := dispatch.NewFunction("handler", func(ctx context.Context, input string) (string, error) { + return input + "output", nil }) h := dispatchlambda.Handler(fn) @@ -179,8 +179,8 @@ func TestHandlerInvokeFunction(t *testing.T) { if err := out.UnmarshalTo(&output); err != nil { t.Fatalf("unexpected error unmarshaling output: %v", err) } - if output.Value != "output" { - t.Errorf("expected coroutine to return an output with value %q, got %q", "output", output.Value) + if output.Value != "inputoutput" { + t.Errorf("expected coroutine to return an output with value %q, got %q", "inputoutput", output.Value) } default: t.Errorf("expected coroutine to return an error, got %T", coro) diff --git a/function.go b/function.go index 58e2b97..0155a8d 100644 --- a/function.go +++ b/function.go @@ -43,7 +43,7 @@ func (f *GenericFunction[I, O]) Name() string { func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response { boxedInput, ok := req.Input() if !ok { - return NewResponseErrorf("%w: unsupported request directive: %v", ErrInvalidArgument, req) + return NewResponseErrorf("%w: unsupported request: %v", ErrInvalidArgument, req) } var input I if err := boxedInput.Unmarshal(&input); err != nil { @@ -55,7 +55,7 @@ func (f *GenericFunction[I, O]) Run(ctx context.Context, req Request) Response { } boxedOutput, err := NewAny(output) if err != nil { - return NewResponseErrorf("%w: cannot serialize return value %v: %v", ErrInvalidResponse, output, err) + return NewResponseErrorf("%w: invalid output %v: %v", ErrInvalidResponse, output, err) } return NewResponse(StatusOf(output), Output(boxedOutput)) } @@ -69,11 +69,11 @@ func (f *GenericFunction[I, O]) NewCall(input I, opts ...CallOption) (Call, erro if f.endpoint == nil { return Call{}, fmt.Errorf("cannot build function call: function has not been registered with a Dispatch endpoint") } - anyInput, err := NewAny(input) + boxedInput, err := NewAny(input) if err != nil { return Call{}, fmt.Errorf("cannot serialize input: %v", err) } - opts = append(slices.Clip(opts), Input(anyInput)) + opts = append(slices.Clip(opts), Input(boxedInput)) return NewCall(f.endpoint.URL(), f.name, opts...), nil } diff --git a/function_test.go b/function_test.go index 96ba34e..a09cbb2 100644 --- a/function_test.go +++ b/function_test.go @@ -6,12 +6,11 @@ import ( "testing" "github.com/dispatchrun/dispatch-go" - "google.golang.org/protobuf/types/known/wrapperspb" ) func TestFunctionRunError(t *testing.T) { - fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return nil, errors.New("oops") + fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) { + return "", errors.New("oops") }) req := dispatch.NewRequest("foo", dispatch.Input(dispatch.String("hello"))) @@ -29,8 +28,8 @@ func TestFunctionRunError(t *testing.T) { } func TestFunctionRunResult(t *testing.T) { - fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { - return wrapperspb.String("world"), nil + fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) { + return "world", nil }) req := dispatch.NewRequest("foo", dispatch.Input(dispatch.String("hello"))) @@ -66,17 +65,17 @@ func TestPrimitiveFunctionNewCallAndDispatchWithoutEndpoint(t *testing.T) { } func TestFunctionNewCallAndDispatchWithoutEndpoint(t *testing.T) { - fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { + fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) { panic("not implemented") }) wantErr := "cannot build function call: function has not been registered with a Dispatch endpoint" - _, err := fn.NewCall(wrapperspb.String("bar")) + _, err := fn.NewCall("bar") if err == nil || err.Error() != wantErr { t.Fatalf("unexpected error: %v", err) } - _, err = fn.Dispatch(context.Background(), wrapperspb.String("bar")) + _, err = fn.Dispatch(context.Background(), "bar") if err == nil || err.Error() != wantErr { t.Fatalf("unexpected error: %v", err) } @@ -119,18 +118,18 @@ func TestFunctionDispatchWithoutClient(t *testing.T) { t.Fatal(err) } - fn := dispatch.NewFunction("foo", func(ctx context.Context, req *wrapperspb.StringValue) (*wrapperspb.StringValue, error) { + fn := dispatch.NewFunction("foo", func(ctx context.Context, input string) (string, error) { panic("not implemented") }) endpoint.Register(fn) // It's possible to create a call since an endpoint URL is available. - if _, err := fn.NewCall(wrapperspb.String("bar")); err != nil { + if _, err := fn.NewCall("bar"); err != nil { t.Fatal(err) } // However, a client is not available. - _, err = fn.Dispatch(context.Background(), wrapperspb.String("bar")) + _, err = fn.Dispatch(context.Background(), "bar") if err == nil { t.Fatal("expected an error") } else if err.Error() != "cannot dispatch function call: Dispatch API key has not been set. Use APIKey(..), or set the DISPATCH_API_KEY environment variable" {