From dd433592c2e4a7e73b539031454fe8773dd22320 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Mon, 10 Jun 2024 08:14:04 +1000 Subject: [PATCH] Allow client options and request headers to be set --- dispatch_test.go | 6 +++--- dispatchserver/endpoint.go | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/dispatch_test.go b/dispatch_test.go index add19bb..2868777 100644 --- a/dispatch_test.go +++ b/dispatch_test.go @@ -36,7 +36,7 @@ func TestDispatchEndpoint(t *testing.T) { // Send a request for the identity function, and check that the // input was echoed back. req := dispatch.NewRequest("identity", dispatch.Input(dispatch.Int(11))) - res, err := client.Run(context.Background(), req) + res, err := client.Run(context.Background(), nil, req) if err != nil { t.Fatal(err) } else if res.Status() != dispatch.OKStatus { @@ -52,7 +52,7 @@ func TestDispatchEndpoint(t *testing.T) { } // Try running a function that has not been registered. - res, err = client.Run(context.Background(), dispatch.NewRequest("not_found", dispatch.Input(dispatch.Int(22)))) + res, err = client.Run(context.Background(), nil, dispatch.NewRequest("not_found", dispatch.Input(dispatch.Int(22)))) if err != nil { t.Fatal(err) } else if res.Status() != dispatch.NotFoundStatus { @@ -65,7 +65,7 @@ func TestDispatchEndpoint(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = nonSigningClient.Run(context.Background(), req) + _, err = nonSigningClient.Run(context.Background(), nil, req) if err == nil || connect.CodeOf(err) != connect.CodePermissionDenied { t.Fatalf("expected a permission denied error, got %v", err) } diff --git a/dispatchserver/endpoint.go b/dispatchserver/endpoint.go index 817cd4b..3ca7aeb 100644 --- a/dispatchserver/endpoint.go +++ b/dispatchserver/endpoint.go @@ -22,6 +22,7 @@ import ( type EndpointClient struct { httpClient connect.HTTPClient signingKey ed25519.PrivateKey + opts []connect.ClientOption client sdkv1connect.FunctionServiceClient } @@ -48,7 +49,8 @@ func NewEndpointClient(endpointUrl string, opts ...EndpointClientOption) (*Endpo if err != nil { return nil, err } - c.client = sdkv1connect.NewFunctionServiceClient(c.httpClient, endpointUrl, connect.WithInterceptors(validator)) + c.opts = append(c.opts, connect.WithInterceptors(validator)) + c.client = sdkv1connect.NewFunctionServiceClient(c.httpClient, endpointUrl, c.opts...) return c, nil } @@ -71,9 +73,21 @@ func HTTPClient(client connect.HTTPClient) EndpointClientOption { return func(c *EndpointClient) { c.httpClient = client } } +// ClientOptions sets options on the underlying connect (gRPC) client. +func ClientOptions(opts ...connect.ClientOption) EndpointClientOption { + return func(c *EndpointClient) { c.opts = append(c.opts, opts...) } +} + // Run sends a RunRequest and returns a RunResponse. -func (c *EndpointClient) Run(ctx context.Context, req dispatch.Request) (dispatch.Response, error) { - res, err := c.client.Run(ctx, connect.NewRequest(requestProto(req))) +func (c *EndpointClient) Run(ctx context.Context, header http.Header, req dispatch.Request) (dispatch.Response, error) { + connectReq := connect.NewRequest(requestProto(req)) + + connectReqHeader := connectReq.Header() + for name, values := range header { + connectReqHeader[name] = values + } + + res, err := c.client.Run(ctx, connectReq) if err != nil { return dispatch.Response{}, err }