Skip to content

Commit

Permalink
Test the various ways of dispatching calls
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 4, 2024
1 parent 36b62a3 commit ad8ddb4
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 54 deletions.
25 changes: 25 additions & 0 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,28 @@ func (c Call) CorrelationID() uint64 {
func (c Call) proto() *sdkv1.Call {
return c.message
}

// Equal is true if the call is equal to another.
func (c Call) Equal(other Call) bool {
if c.message == nil || other.message == nil {
return false
}
if c.Endpoint() != other.Endpoint() {
return false
}
if c.Function() != other.Function() {
return false
}
if c.CorrelationID() != other.CorrelationID() {
return false
}
if c.Expiration() != other.Expiration() {
return false
}
if c.Version() != other.Version() {
return false
}
input, _ := c.Input()
otherInput, _ := other.Input()
return input != nil && otherInput != nil && proto.Equal(input, otherInput)
}
8 changes: 5 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,11 @@ func (b *Batch) Reset() {
b.calls = b.calls[:0]
}

// Add adds a Call to the batch.
func (b *Batch) Add(call Call) {
b.calls = append(b.calls, call.proto())
// Add adds calls to the batch.
func (b *Batch) Add(calls ...Call) {
for i := range calls {
b.calls = append(b.calls, calls[i].proto())
}
}

// Dispatch dispatches the batch of function calls.
Expand Down
105 changes: 67 additions & 38 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ import (

"github.com/dispatchrun/dispatch-go"
"github.com/dispatchrun/dispatch-go/dispatchtest"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/wrapperspb"
)

func TestClient(t *testing.T) {
var recorder dispatchtest.Recorder
var recorder dispatchtest.CallRecorder

server := dispatchtest.NewServer(&recorder)
server := dispatchtest.NewDispatchServer(&recorder)

client := &dispatch.Client{ApiKey: "foobar", ApiUrl: server.URL}

input := wrapperspb.Int32(11)
call, err := dispatch.NewCall("http://example.com", "function1", input)
call, err := dispatch.NewCall("http://example.com", "function1", wrapperspb.Int32(11))
if err != nil {
t.Fatal(err)
}
Expand All @@ -28,44 +26,25 @@ func TestClient(t *testing.T) {
t.Fatal(err)
}

if len(recorder.Requests) != 1 {
t.Fatalf("expected one request to Dispatch, got %v", len(recorder.Requests))
}
req := &recorder.Requests[0]
if req.ApiKey != "foobar" {
t.Errorf("unexpected API key: %v", req.ApiKey)
}
if len(req.Calls) != 1 {
t.Fatalf("expected one call to Dispatch, got %v", len(req.Calls))
}
got := req.Calls[0]
if got.Endpoint() != call.Endpoint() {
t.Errorf("unexpected call endpoint: %v", got.Endpoint())
}
if got.Function() != call.Function() {
t.Errorf("unexpected call function: %v", got.Function())
}
gotInput, err := got.Input()
if err != nil {
t.Fatal(err)
}
if !proto.Equal(gotInput, input) {
t.Errorf("unexpected call input: %#v", gotInput)
}
dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{
{
ApiKey: "foobar",
Calls: []dispatch.Call{call},
},
})
}

func TestClientEnvConfig(t *testing.T) {
var recorder dispatchtest.Recorder
var recorder dispatchtest.CallRecorder

server := dispatchtest.NewServer(&recorder)
server := dispatchtest.NewDispatchServer(&recorder)

client := &dispatch.Client{Env: []string{
"DISPATCH_API_KEY=foobar",
"DISPATCH_API_URL=" + server.URL,
}}

input := wrapperspb.Int32(11)
call, err := dispatch.NewCall("http://example.com", "function1", input)
call, err := dispatch.NewCall("http://example.com", "function1", wrapperspb.Int32(11))
if err != nil {
t.Fatal(err)
}
Expand All @@ -75,11 +54,61 @@ func TestClientEnvConfig(t *testing.T) {
t.Fatal(err)
}

if len(recorder.Requests) != 1 {
t.Fatalf("expected one request to Dispatch, got %v", len(recorder.Requests))
dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{
{
ApiKey: "foobar",
Calls: []dispatch.Call{call},
},
})
}

func TestClientBatch(t *testing.T) {
var recorder dispatchtest.CallRecorder

server := dispatchtest.NewDispatchServer(&recorder)

client := &dispatch.Client{ApiKey: "foobar", ApiUrl: server.URL}

call1, err := dispatch.NewCall("http://example.com", "function1", wrapperspb.Int32(11))
if err != nil {
t.Fatal(err)
}
call2, err := dispatch.NewCall("http://example.com", "function2", wrapperspb.Int32(22))
if err != nil {
t.Fatal(err)
}
call3, err := dispatch.NewCall("http://example.com", "function3", wrapperspb.Int32(33))
if err != nil {
t.Fatal(err)
}
call4, err := dispatch.NewCall("http://example2.com", "function4", wrapperspb.Int32(44))
if err != nil {
t.Fatal(err)
}

batch := client.Batch()
batch.Add(call1, call2)
_, err = batch.Dispatch(context.Background())
if err != nil {
t.Fatal(err)
}
req := &recorder.Requests[0]
if req.ApiKey != "foobar" {
t.Errorf("unexpected API key: %v", req.ApiKey)

batch.Reset()
batch.Add(call3)
batch.Add(call4)
_, err = batch.Dispatch(context.Background())
if err != nil {
t.Fatal(err)
}

dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{
{
ApiKey: "foobar",
Calls: []dispatch.Call{call1, call2},
},
{
ApiKey: "foobar",
Calls: []dispatch.Call{call3, call4},
},
})
}
35 changes: 35 additions & 0 deletions dispatch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

sdkv1 "buf.build/gen/go/stealthrocket/dispatch-proto/protocolbuffers/go/dispatch/sdk/v1"
"connectrpc.com/connect"
Expand Down Expand Up @@ -112,3 +113,37 @@ func TestDispatchEndpoint(t *testing.T) {
t.Fatalf("expected a permission denied error, got %v", err)
}
}

func TestDispatchCalls(t *testing.T) {
var recorder dispatchtest.CallRecorder

server := dispatchtest.NewDispatchServer(&recorder)

d := &dispatch.Dispatch{
EndpointUrl: "http://example.com",
Client: dispatch.Client{ApiKey: "foobar", ApiUrl: server.URL},
}

fn := dispatch.NewPrimitiveFunction("function1", func(ctx context.Context, req *sdkv1.RunRequest) *sdkv1.RunResponse {
panic("not implemented")
})

d.Register(fn)

_, err := fn.Dispatch(context.Background(), wrapperspb.Int32(11), dispatch.WithExpiration(10*time.Second))
if err != nil {
t.Fatal(err)
}

wantCall, err := dispatch.NewCall("http://example.com", "function1", wrapperspb.Int32(11), dispatch.WithExpiration(10*time.Second))
if err != nil {
t.Fatal(err)
}

dispatchtest.AssertDispatchRequests(t, recorder.Requests, []dispatchtest.DispatchRequest{
{
ApiKey: "foobar",
Calls: []dispatch.Call{wantCall},
},
})
}
42 changes: 42 additions & 0 deletions dispatchtest/assert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package dispatchtest

import (
"testing"

"github.com/dispatchrun/dispatch-go"
)

func AssertCalls(t *testing.T, got, want []dispatch.Call) {
t.Helper()

if len(got) != len(want) {
t.Fatalf("unexpected number of calls: got %v, want %v", len(got), len(want))
}
for i, call := range got {
if !call.Equal(want[i]) {
t.Errorf("unexpected call %d: got %#v, want %#v", i, call, want[i])
}
}
}

func AssertCall(t *testing.T, got, want dispatch.Call) {
t.Helper()

if !got.Equal(want) {
t.Errorf("unexpected call: got %#v, want %#v", got, want)
}
}

func AssertDispatchRequests(t *testing.T, got, want []DispatchRequest) {
t.Helper()

if len(got) != len(want) {
t.Fatalf("unexpected number of requests: got %v, want %v", len(got), len(want))
}
for i, req := range got {
if req.ApiKey != want[i].ApiKey {
t.Errorf("unexpected API key on request %d: got %v, want %v", i, req.ApiKey, want[i].ApiKey)
}
AssertCalls(t, req.Calls, want[i].Calls)
}
}
26 changes: 13 additions & 13 deletions dispatchtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (
"github.com/dispatchrun/dispatch-go"
)

// Handler is a handler for dispatched function calls.
type Handler interface {
// DispatchServerHandler is a handler for dispatched function calls.
type DispatchServerHandler interface {
Handle(ctx context.Context, apiKey string, calls []dispatch.Call) ([]dispatch.ID, error)
}

// HandlerFunc creates a Handler from a function.
func HandlerFunc(fn func(ctx context.Context, apiKey string, calls []dispatch.Call) ([]dispatch.ID, error)) Handler {
// DispatchServerHandlerFunc creates a Handler from a function.
func DispatchServerHandlerFunc(fn func(ctx context.Context, apiKey string, calls []dispatch.Call) ([]dispatch.ID, error)) DispatchServerHandler {
return handlerFunc(fn)
}

Expand All @@ -30,15 +30,15 @@ func (h handlerFunc) Handle(ctx context.Context, apiKey string, calls []dispatch
return h(ctx, apiKey, calls)
}

// NewServer creates a new test Dispatch server.
func NewServer(handler Handler) *httptest.Server {
// NewDispatchServer creates a new test Dispatch server.
func NewDispatchServer(handler DispatchServerHandler) *httptest.Server {
mux := http.NewServeMux()
mux.Handle(sdkv1connect.NewDispatchServiceHandler(&dispatchServiceHandler{handler}))
return httptest.NewServer(mux)
}

type dispatchServiceHandler struct {
Handler
DispatchServerHandler
}

func (d *dispatchServiceHandler) Dispatch(ctx context.Context, req *connect.Request[sdkv1.DispatchRequest]) (*connect.Response[sdkv1.DispatchResponse], error) {
Expand Down Expand Up @@ -80,22 +80,22 @@ func wrapCall(c *sdkv1.Call) (dispatch.Call, error) {
dispatch.WithVersion(c.Version))
}

// Recorder is a Handler that captures requests to Dispatch.
type Recorder struct {
Requests []RecorderRequest
// CallRecorder is a DispatchServerHandler that captures requests to Dispatch.
type CallRecorder struct {
Requests []DispatchRequest
calls int
}

type RecorderRequest struct {
type DispatchRequest struct {
ApiKey string
Calls []dispatch.Call
}

func (r *Recorder) Handle(ctx context.Context, apiKey string, calls []dispatch.Call) ([]dispatch.ID, error) {
func (r *CallRecorder) Handle(ctx context.Context, apiKey string, calls []dispatch.Call) ([]dispatch.ID, error) {
base := r.calls
r.calls += len(calls)

r.Requests = append(r.Requests, RecorderRequest{
r.Requests = append(r.Requests, DispatchRequest{
ApiKey: apiKey,
Calls: calls,
})
Expand Down

0 comments on commit ad8ddb4

Please sign in to comment.