From 7b7f2de4cdf361a783d52ab33bac3723fcabb934 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 20 Oct 2023 18:54:34 -0400 Subject: [PATCH] Add support for experimental secondary dispatching When enabled and configured, (for now, check) dispatches will also be made to the secondary dispatchers returned by the CEL expression. This will enable items such as cache prewarming by dispatching to a backup cluster while the primary cluster is still running --- internal/dispatch/combined/combined.go | 62 +++++++-- internal/dispatch/remote/cluster.go | 142 +++++++++++++++++++- internal/dispatch/remote/cluster_test.go | 159 +++++++++++++++++++++-- internal/dispatch/remote/expr.go | 103 +++++++++++++++ internal/dispatch/remote/expr_test.go | 125 ++++++++++++++++++ pkg/cmd/serve.go | 3 + pkg/cmd/server/server.go | 5 + pkg/cmd/server/zz_generated.options.go | 32 +++++ 8 files changed, 606 insertions(+), 25 deletions(-) create mode 100644 internal/dispatch/remote/expr.go create mode 100644 internal/dispatch/remote/expr_test.go diff --git a/internal/dispatch/combined/combined.go b/internal/dispatch/combined/combined.go index 156ca5a019..c5ed59c574 100644 --- a/internal/dispatch/combined/combined.go +++ b/internal/dispatch/combined/combined.go @@ -3,6 +3,7 @@ package combined import ( + "fmt" "time" "github.com/authzed/grpcutil" @@ -23,15 +24,17 @@ import ( type Option func(*optionState) type optionState struct { - metricsEnabled bool - prometheusSubsystem string - upstreamAddr string - upstreamCAPath string - grpcPresharedKey string - grpcDialOpts []grpc.DialOption - cache cache.Cache - concurrencyLimits graph.ConcurrencyLimits - remoteDispatchTimeout time.Duration + metricsEnabled bool + prometheusSubsystem string + upstreamAddr string + upstreamCAPath string + grpcPresharedKey string + grpcDialOpts []grpc.DialOption + cache cache.Cache + concurrencyLimits graph.ConcurrencyLimits + remoteDispatchTimeout time.Duration + secondaryUpstreamAddrs map[string]string + secondaryUpstreamExprs map[string]string } // MetricsEnabled enables issuing prometheus metrics @@ -63,6 +66,23 @@ func UpstreamCAPath(path string) Option { } } +// SecondaryUpstreamAddrs sets a named map of upstream addresses for secondary +// dispatching. +func SecondaryUpstreamAddrs(addrs map[string]string) Option { + return func(state *optionState) { + state.secondaryUpstreamAddrs = addrs + } +} + +// SecondaryUpstreamExprs sets a named map from dispatch type to the associated +// CEL expression to run to determine which secondary dispatch addresses (if any) +// to use for that incoming request. +func SecondaryUpstreamExprs(addrs map[string]string) Option { + return func(state *optionState) { + state.secondaryUpstreamExprs = addrs + } +} + // GrpcPresharedKey sets the preshared key used to authenticate for optional // cluster dispatching. func GrpcPresharedKey(key string) Option { @@ -141,10 +161,32 @@ func NewDispatcher(options ...Option) (dispatch.Dispatcher, error) { if err != nil { return nil, err } + + secondaryClients := make(map[string]remote.SecondaryDispatch, len(opts.secondaryUpstreamAddrs)) + for name, addr := range opts.secondaryUpstreamAddrs { + secondaryConn, err := grpc.Dial(addr, opts.grpcDialOpts...) + if err != nil { + return nil, err + } + secondaryClients[name] = remote.SecondaryDispatch{ + Name: name, + Client: v1.NewDispatchServiceClient(secondaryConn), + } + } + + secondaryExprs := make(map[string]*remote.DispatchExpr, len(opts.secondaryUpstreamExprs)) + for name, exprString := range opts.secondaryUpstreamExprs { + parsed, err := remote.ParseDispatchExpression(name, exprString) + if err != nil { + return nil, fmt.Errorf("error parsing secondary dispatch expr `%s` for method `%s`: %w", exprString, name, err) + } + secondaryExprs[name] = parsed + } + redispatch = remote.NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, remote.ClusterDispatcherConfig{ KeyHandler: &keys.CanonicalKeyHandler{}, DispatchOverallTimeout: opts.remoteDispatchTimeout, - }) + }, secondaryClients, secondaryExprs) } cachingRedispatch.SetDelegate(redispatch) diff --git a/internal/dispatch/remote/cluster.go b/internal/dispatch/remote/cluster.go index 81ff45bd41..5cee5f57c7 100644 --- a/internal/dispatch/remote/cluster.go +++ b/internal/dispatch/remote/cluster.go @@ -5,11 +5,15 @@ import ( "errors" "fmt" "io" + "strings" "time" "github.com/authzed/consistent" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + "google.golang.org/protobuf/proto" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/keys" @@ -18,7 +22,18 @@ import ( "github.com/authzed/spicedb/pkg/spiceerrors" ) -type clusterClient interface { +var dispatchCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: "spicedb", + Subsystem: "dispatch", + Name: "remote_dispatch_handler_total", + Help: "which dispatcher handled a request", +}, []string{"request_kind", "handler_name"}) + +func init() { + prometheus.MustRegister(dispatchCounter) +} + +type ClusterClient interface { DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest, opts ...grpc.CallOption) (*v1.DispatchCheckResponse, error) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest, opts ...grpc.CallOption) (*v1.DispatchExpandResponse, error) DispatchReachableResources(ctx context.Context, in *v1.DispatchReachableResourcesRequest, opts ...grpc.CallOption) (v1.DispatchService_DispatchReachableResourcesClient, error) @@ -35,9 +50,16 @@ type ClusterDispatcherConfig struct { DispatchOverallTimeout time.Duration } +// SecondaryDispatch defines a struct holding a client and its name for secondary +// dispatching. +type SecondaryDispatch struct { + Name string + Client ClusterClient +} + // NewClusterDispatcher creates a dispatcher implementation that uses the provided client // to dispatch requests to peer nodes in the cluster. -func NewClusterDispatcher(client clusterClient, conn *grpc.ClientConn, config ClusterDispatcherConfig) dispatch.Dispatcher { +func NewClusterDispatcher(client ClusterClient, conn *grpc.ClientConn, config ClusterDispatcherConfig, secondaryDispatch map[string]SecondaryDispatch, secondaryDispatchExprs map[string]*DispatchExpr) dispatch.Dispatcher { keyHandler := config.KeyHandler if keyHandler == nil { keyHandler = &keys.DirectKeyHandler{} @@ -53,14 +75,18 @@ func NewClusterDispatcher(client clusterClient, conn *grpc.ClientConn, config Cl conn: conn, keyHandler: keyHandler, dispatchOverallTimeout: dispatchOverallTimeout, + secondaryDispatch: secondaryDispatch, + secondaryDispatchExprs: secondaryDispatchExprs, } } type clusterDispatcher struct { - clusterClient clusterClient + clusterClient ClusterClient conn *grpc.ClientConn keyHandler keys.Handler dispatchOverallTimeout time.Duration + secondaryDispatch map[string]SecondaryDispatch + secondaryDispatchExprs map[string]*DispatchExpr } func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { @@ -75,18 +101,120 @@ func (cr *clusterDispatcher) DispatchCheck(ctx context.Context, req *v1.Dispatch ctx = context.WithValue(ctx, consistent.CtxKey, requestKey) - withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) - defer cancelFn() + resp, err := dispatchRequest(ctx, cr, "check", req, func(ctx context.Context, client ClusterClient) (*v1.DispatchCheckResponse, error) { + resp, err := client.DispatchCheck(ctx, req) + if err != nil { + return resp, err + } - resp, err := cr.clusterClient.DispatchCheck(withTimeout, req) + err = adjustMetadataForDispatch(resp.Metadata) + return resp, err + }) if err != nil { return &v1.DispatchCheckResponse{Metadata: requestFailureMetadata}, err } - err = adjustMetadataForDispatch(resp.Metadata) return resp, err } +type requestMessage interface { + zerolog.LogObjectMarshaler + + GetMetadata() *v1.ResolverMeta +} + +type responseMessage interface { + proto.Message + + GetMetadata() *v1.ResponseMeta +} + +type respTuple[S responseMessage] struct { + resp S + err error +} + +type secondaryRespTuple[S responseMessage] struct { + handlerName string + resp S +} + +func dispatchRequest[Q requestMessage, S responseMessage](ctx context.Context, cr *clusterDispatcher, reqKey string, req Q, handler func(context.Context, ClusterClient) (S, error)) (S, error) { + withTimeout, cancelFn := context.WithTimeout(ctx, cr.dispatchOverallTimeout) + defer cancelFn() + + if len(cr.secondaryDispatchExprs) == 0 || len(cr.secondaryDispatch) == 0 { + return handler(withTimeout, cr.clusterClient) + } + + // If no secondary dispatches are defined, just invoke directly. + expr, ok := cr.secondaryDispatchExprs[reqKey] + if !ok { + return handler(withTimeout, cr.clusterClient) + } + + // Otherwise invoke in parallel with any secondary matches. + primaryResultChan := make(chan respTuple[S], 1) + secondaryResultChan := make(chan secondaryRespTuple[S], len(cr.secondaryDispatch)) + + // Run the main dispatch. + go func() { + resp, err := handler(withTimeout, cr.clusterClient) + primaryResultChan <- respTuple[S]{resp, err} + }() + + result, err := RunDispatchExpr(expr, req) + if err != nil { + log.Warn().Err(err).Msg("error when trying to evaluate the dispatch expression") + } + + log.Trace().Str("secondary-dispatchers", strings.Join(result, ",")).Object("request", req).Msg("running secondary dispatchers") + + for _, secondaryDispatchName := range result { + secondary, ok := cr.secondaryDispatch[secondaryDispatchName] + if !ok { + log.Warn().Str("secondary-dispatcher-name", secondaryDispatchName).Msg("received unknown secondary dispatcher") + continue + } + + log.Trace().Str("secondary-dispatcher", secondary.Name).Object("request", req).Msg("running secondary dispatcher") + go func() { + resp, err := handler(withTimeout, secondary.Client) + if err != nil { + // For secondary dispatches, ignore any errors, as only the primary will be handled in + // that scenario. + log.Trace().Str("secondary", secondary.Name).Err(err).Msg("got ignored secondary dispatch error") + return + } + + secondaryResultChan <- secondaryRespTuple[S]{resp: resp, handlerName: secondary.Name} + }() + } + + var foundError error + select { + case <-withTimeout.Done(): + return *new(S), fmt.Errorf("check dispatch has timed out") + + case r := <-primaryResultChan: + if r.err == nil { + dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1) + return r.resp, nil + } + + // Otherwise, if an error was found, log it and we'll return after *all* the secondaries have run. + // This allows an otherwise error-state to be handled by one of the secondaries. + foundError = r.err + + case r := <-secondaryResultChan: + dispatchCounter.WithLabelValues(reqKey, r.handlerName).Add(1) + return r.resp, nil + } + + dispatchCounter.WithLabelValues(reqKey, "(primary)").Add(1) + return *new(S), foundError +} + func adjustMetadataForDispatch(metadata *v1.ResponseMeta) error { if metadata == nil { return spiceerrors.MustBugf("received a nil metadata") diff --git a/internal/dispatch/remote/cluster_test.go b/internal/dispatch/remote/cluster_test.go index b7b25011ff..6e23e434f4 100644 --- a/internal/dispatch/remote/cluster_test.go +++ b/internal/dispatch/remote/cluster_test.go @@ -16,20 +16,23 @@ import ( "google.golang.org/grpc/test/bufconn" "github.com/authzed/spicedb/internal/dispatch/keys" - core "github.com/authzed/spicedb/pkg/proto/core/v1" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" ) type fakeDispatchSvc struct { v1.UnimplementedDispatchServiceServer - sleepTime time.Duration + sleepTime time.Duration + dispatchCount uint32 } func (fds *fakeDispatchSvc) DispatchCheck(context.Context, *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { time.Sleep(fds.sleepTime) return &v1.DispatchCheckResponse{ - Metadata: emptyMetadata, + Metadata: &v1.ResponseMeta{ + DispatchCount: fds.dispatchCount, + }, }, nil } @@ -90,16 +93,16 @@ func TestDispatchTimeout(t *testing.T) { dispatcher := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ KeyHandler: &keys.DirectKeyHandler{}, DispatchOverallTimeout: tc.timeout, - }) + }, nil, nil) require.True(t, dispatcher.ReadyState().IsReady) // Invoke a dispatched "check" and ensure it times out, as the fake dispatch will wait // longer than the configured timeout. resp, err := dispatcher.DispatchCheck(context.Background(), &v1.DispatchCheckRequest{ - ResourceRelation: &core.RelationReference{Namespace: "sometype", Relation: "somerel"}, + ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - Subject: &core.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, }) if tc.sleepTime > tc.timeout { require.Error(t, err) @@ -113,10 +116,10 @@ func TestDispatchTimeout(t *testing.T) { // Invoke a dispatched "LookupSubjects" and test as well. stream := dispatch.NewCollectingDispatchStream[*v1.DispatchLookupSubjectsResponse](context.Background()) err = dispatcher.DispatchLookupSubjects(&v1.DispatchLookupSubjectsRequest{ - ResourceRelation: &core.RelationReference{Namespace: "sometype", Relation: "somerel"}, + ResourceRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, ResourceIds: []string{"foo"}, Metadata: &v1.ResolverMeta{DepthRemaining: 50}, - SubjectRelation: &core.RelationReference{Namespace: "sometype", Relation: "somerel"}, + SubjectRelation: &corev1.RelationReference{Namespace: "sometype", Relation: "somerel"}, }, stream) if tc.sleepTime > tc.timeout { require.Error(t, err) @@ -129,3 +132,143 @@ func TestDispatchTimeout(t *testing.T) { }) } } + +func TestSecondaryDispatch(t *testing.T) { + for _, tc := range []struct { + name string + expr string + request *v1.DispatchCheckRequest + primarySleepTime time.Duration + expectedResult uint32 + }{ + { + "no multidispatch", + "['invalid']", + &v1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + ResourceIds: []string{"foo"}, + Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + }, + 0 * time.Millisecond, + 1, + }, + { + "basic multidispatch", + "['secondary']", + &v1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + ResourceIds: []string{"foo"}, + Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + }, + 1 * time.Second, + 2, + }, + { + "basic multidispatch, expr doesn't call secondary", + "['notconfigured']", + &v1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + ResourceIds: []string{"foo"}, + Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + }, + 1 * time.Second, + 1, + }, + { + "expr matches request", + "request.resource_relation.namespace == 'somenamespace' ? ['secondary'] : []", + &v1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + ResourceIds: []string{"foo"}, + Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + }, + 1 * time.Second, + 2, + }, + { + "expr does not match request", + "request.resource_relation.namespace == 'somenamespace' ? ['secondary'] : []", + &v1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "someothernamespace", + Relation: "somerelation", + }, + ResourceIds: []string{"foo"}, + Metadata: &v1.ResolverMeta{DepthRemaining: 50}, + Subject: &corev1.ObjectAndRelation{Namespace: "foo", ObjectId: "bar", Relation: "..."}, + }, + 1 * time.Second, + 1, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + conn := connectionForDispatching(t, &fakeDispatchSvc{dispatchCount: 1, sleepTime: tc.primarySleepTime}) + secondaryConn := connectionForDispatching(t, &fakeDispatchSvc{dispatchCount: 2, sleepTime: 0 * time.Millisecond}) + + parsed, err := ParseDispatchExpression("check", tc.expr) + require.NoError(t, err) + + dispatcher := NewClusterDispatcher(v1.NewDispatchServiceClient(conn), conn, ClusterDispatcherConfig{ + KeyHandler: &keys.DirectKeyHandler{}, + DispatchOverallTimeout: 30 * time.Second, + }, map[string]SecondaryDispatch{ + "secondary": {Name: "secondary", Client: v1.NewDispatchServiceClient(secondaryConn)}, + }, map[string]*DispatchExpr{ + "check": parsed, + }) + require.True(t, dispatcher.ReadyState().IsReady) + + resp, err := dispatcher.DispatchCheck(context.Background(), tc.request) + require.NoError(t, err) + require.Equal(t, tc.expectedResult, resp.Metadata.DispatchCount) + }) + } +} + +func connectionForDispatching(t *testing.T, svc v1.DispatchServiceServer) *grpc.ClientConn { + listener := bufconn.Listen(humanize.MiByte) + s := grpc.NewServer() + + v1.RegisterDispatchServiceServer(s, svc) + + go func() { + // Ignore any errors + _ = s.Serve(listener) + }() + + conn, err := grpc.DialContext( + context.Background(), + "", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return listener.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) + require.NoError(t, err) + + t.Cleanup(func() { + conn.Close() + listener.Close() + s.Stop() + }) + + return conn +} diff --git a/internal/dispatch/remote/expr.go b/internal/dispatch/remote/expr.go new file mode 100644 index 0000000000..61c31007d4 --- /dev/null +++ b/internal/dispatch/remote/expr.go @@ -0,0 +1,103 @@ +package remote + +import ( + "fmt" + + "github.com/authzed/cel-go/cel" + "github.com/authzed/cel-go/common" + "github.com/authzed/cel-go/common/types" + "github.com/authzed/cel-go/common/types/ref" + "google.golang.org/protobuf/proto" + + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +// DispatchExpr is a CEL expression that can be run to determine the secondary dispatchers, if any, +// to invoke for the incoming request. +type DispatchExpr struct { + env *cel.Env + registry *types.Registry + methodName string + exprAst *cel.Ast +} + +var dispatchRequestTypes = []proto.Message{ + &dispatchv1.DispatchCheckRequest{}, + &corev1.RelationReference{}, + &corev1.ObjectAndRelation{}, +} + +// ParseDispatchExpression parses a dispatch expression via CEL. +func ParseDispatchExpression(methodName string, exprString string) (*DispatchExpr, error) { + registry, err := types.NewRegistry(dispatchRequestTypes...) + if err != nil { + return nil, fmt.Errorf("unable to initialize dispatch expression type registry") + } + + opts := make([]cel.EnvOption, 0) + opts = append(opts, cel.OptionalTypes(cel.OptionalTypesVersion(0))) + opts = append(opts, cel.Variable("request", cel.DynType)) + + celEnv, err := cel.NewEnv(opts...) + if err != nil { + return nil, err + } + + ast, issues := celEnv.CompileSource(common.NewStringSource(exprString, methodName)) + if issues != nil && issues.Err() != nil { + return nil, issues.Err() + } + + if !ast.OutputType().IsEquivalentType(cel.ListType(cel.StringType)) { + return nil, fmt.Errorf("dispatch expression must result in a list[string] value: found `%s`", ast.OutputType().String()) + } + + return &DispatchExpr{ + env: celEnv, + registry: registry, + methodName: methodName, + exprAst: ast, + }, nil +} + +// RunDispatchExpr runs a dispatch CEL expression over the given request and returns the secondary dispatchers +// to invoke, if any. +func RunDispatchExpr[R any](de *DispatchExpr, request R) ([]string, error) { + celopts := make([]cel.ProgramOption, 0, 3) + + celopts = append(celopts, cel.EvalOptions(cel.OptTrackState)) + celopts = append(celopts, cel.EvalOptions(cel.OptPartialEval)) + celopts = append(celopts, cel.CostLimit(50)) + + prg, err := de.env.Program(de.exprAst, celopts...) + if err != nil { + return nil, err + } + + // Mark any unspecified variables as unknown, to ensure that partial application + // will result in producing a type of Unknown. + activation, err := de.env.PartialVars(map[string]any{ + "request": de.registry.NativeToValue(request), + }) + if err != nil { + return nil, err + } + + val, _, err := prg.Eval(activation) + if err != nil { + return nil, fmt.Errorf("unable to evaluate dispatch expression: %w", err) + } + + // If the value produced has Unknown type, then it means required context was missing. + if types.IsUnknown(val) { + return nil, fmt.Errorf("unable to eval dispatch expression; did you make sure you use `request.`?") + } + + values := val.Value().([]ref.Val) + convertedValues := make([]string, 0, len(values)) + for _, value := range values { + convertedValues = append(convertedValues, value.Value().(string)) + } + return convertedValues, nil +} diff --git a/internal/dispatch/remote/expr_test.go b/internal/dispatch/remote/expr_test.go new file mode 100644 index 0000000000..00f270f13c --- /dev/null +++ b/internal/dispatch/remote/expr_test.go @@ -0,0 +1,125 @@ +package remote + +import ( + "testing" + + "github.com/stretchr/testify/require" + + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" +) + +func TestParseDispatchExpression(t *testing.T) { + tcs := []struct { + name string + expr string + expectedError string + }{ + { + "empty", + "", + "mismatched input ''", + }, + { + "returns string", + "'somestring'", + "a list[string] value", + }, + { + "invalid expression", + "a.b.c!d", + "mismatched input '!'", + }, + { + "valid expression", + "['prewarm']", + "", + }, + { + "valid big expression", + "request.resource_relation.namespace == 'foo' ? ['prewarm'] : []", + "", + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + _, err := ParseDispatchExpression("somemethod", tc.expr) + if tc.expectedError != "" { + require.ErrorContains(t, err, tc.expectedError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestRunCheckDispatchExpr(t *testing.T) { + tcs := []struct { + name string + expr string + request *dispatchv1.DispatchCheckRequest + expectedResult []string + expectedError string + }{ + { + "static", + "['prewarm']", + nil, + []string{"prewarm"}, + "", + }, + { + "basic", + "request.resource_relation.namespace == 'somenamespace' ? ['prewarm'] : ['other']", + &dispatchv1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + }, + []string{"prewarm"}, + "", + }, + { + "basic other branch", + "request.resource_relation.namespace == 'somethingelse' ? ['prewarm'] : ['other']", + &dispatchv1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + }, + []string{"other"}, + "", + }, + { + "invalid field", + "request.resource_relation.invalidfield == 'somethingelse' ? ['prewarm'] : ['other']", + &dispatchv1.DispatchCheckRequest{ + ResourceRelation: &corev1.RelationReference{ + Namespace: "somenamespace", + Relation: "somerelation", + }, + }, + nil, + "no such field 'invalidfield'", + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + parsed, err := ParseDispatchExpression("check", tc.expr) + require.NoError(t, err) + + resp, err := RunDispatchExpr(parsed, tc.request) + if tc.expectedError != "" { + require.ErrorContains(t, err, tc.expectedError) + } else { + require.Equal(t, tc.expectedResult, resp) + } + }) + } +} diff --git a/pkg/cmd/serve.go b/pkg/cmd/serve.go index 958a3c250f..66dffd96ed 100644 --- a/pkg/cmd/serve.go +++ b/pkg/cmd/serve.go @@ -115,6 +115,9 @@ func RegisterServeFlags(cmd *cobra.Command, config *server.Config) error { cmd.Flags().Uint16Var(&config.DispatchHashringReplicationFactor, "dispatch-hashring-replication-factor", 100, "set the replication factor of the consistent hasher used for the dispatcher") cmd.Flags().Uint8Var(&config.DispatchHashringSpread, "dispatch-hashring-spread", 1, "set the spread of the consistent hasher used for the dispatcher") + cmd.Flags().StringToStringVar(&config.DispatchSecondaryUpstreamAddrs, "experimental-dispatch-secondary-upstream-addrs", nil, "secondary upstream addresses for dispatches, each with a name") + cmd.Flags().StringToStringVar(&config.DispatchSecondaryUpstreamExprs, "experimental-dispatch-secondary-upstream-exprs", nil, "map from request type (currently supported: `check`) to its associated CEL expression, which returns the secondary upstream(s) to be used for the request") + // Flags for configuring API behavior cmd.Flags().BoolVar(&config.DisableV1SchemaAPI, "disable-v1-schema-api", false, "disables the V1 schema API") cmd.Flags().BoolVar(&config.DisableVersionResponse, "disable-version-response", false, "disables version response support in the API") diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 5796a1ff46..54f1093723 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -97,6 +97,9 @@ type Config struct { DispatchHashringReplicationFactor uint16 `debugmap:"visible"` DispatchHashringSpread uint8 `debugmap:"visible"` + DispatchSecondaryUpstreamAddrs map[string]string `debugmap:"visible"` + DispatchSecondaryUpstreamExprs map[string]string `debugmap:"visible"` + DispatchCacheConfig CacheConfig `debugmap:"visible"` ClusterDispatchCacheConfig CacheConfig `debugmap:"visible"` @@ -261,6 +264,8 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { dispatcher, err = combineddispatch.NewDispatcher( combineddispatch.UpstreamAddr(c.DispatchUpstreamAddr), combineddispatch.UpstreamCAPath(c.DispatchUpstreamCAPath), + combineddispatch.SecondaryUpstreamAddrs(c.DispatchSecondaryUpstreamAddrs), + combineddispatch.SecondaryUpstreamExprs(c.DispatchSecondaryUpstreamExprs), combineddispatch.GrpcPresharedKey(dispatchPresharedKey), combineddispatch.GrpcDialOpts( grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor()), diff --git a/pkg/cmd/server/zz_generated.options.go b/pkg/cmd/server/zz_generated.options.go index a5153c09dd..2cf7cd1a93 100644 --- a/pkg/cmd/server/zz_generated.options.go +++ b/pkg/cmd/server/zz_generated.options.go @@ -69,6 +69,8 @@ func (c *Config) ToOption() ConfigOption { to.Dispatcher = c.Dispatcher to.DispatchHashringReplicationFactor = c.DispatchHashringReplicationFactor to.DispatchHashringSpread = c.DispatchHashringSpread + to.DispatchSecondaryUpstreamAddrs = c.DispatchSecondaryUpstreamAddrs + to.DispatchSecondaryUpstreamExprs = c.DispatchSecondaryUpstreamExprs to.DispatchCacheConfig = c.DispatchCacheConfig to.ClusterDispatchCacheConfig = c.ClusterDispatchCacheConfig to.DisableV1SchemaAPI = c.DisableV1SchemaAPI @@ -123,6 +125,8 @@ func (c Config) DebugMap() map[string]any { debugMap["Dispatcher"] = helpers.DebugValue(c.Dispatcher, false) debugMap["DispatchHashringReplicationFactor"] = helpers.DebugValue(c.DispatchHashringReplicationFactor, false) debugMap["DispatchHashringSpread"] = helpers.DebugValue(c.DispatchHashringSpread, false) + debugMap["DispatchSecondaryUpstreamAddrs"] = helpers.DebugValue(c.DispatchSecondaryUpstreamAddrs, false) + debugMap["DispatchSecondaryUpstreamExprs"] = helpers.DebugValue(c.DispatchSecondaryUpstreamExprs, false) debugMap["DispatchCacheConfig"] = helpers.DebugValue(c.DispatchCacheConfig, false) debugMap["ClusterDispatchCacheConfig"] = helpers.DebugValue(c.ClusterDispatchCacheConfig, false) debugMap["DisableV1SchemaAPI"] = helpers.DebugValue(c.DisableV1SchemaAPI, false) @@ -386,6 +390,34 @@ func WithDispatchHashringSpread(dispatchHashringSpread uint8) ConfigOption { } } +// WithDispatchSecondaryUpstreamAddrs returns an option that can append DispatchSecondaryUpstreamAddrss to Config.DispatchSecondaryUpstreamAddrs +func WithDispatchSecondaryUpstreamAddrs(key string, value string) ConfigOption { + return func(c *Config) { + c.DispatchSecondaryUpstreamAddrs[key] = value + } +} + +// SetDispatchSecondaryUpstreamAddrs returns an option that can set DispatchSecondaryUpstreamAddrs on a Config +func SetDispatchSecondaryUpstreamAddrs(dispatchSecondaryUpstreamAddrs map[string]string) ConfigOption { + return func(c *Config) { + c.DispatchSecondaryUpstreamAddrs = dispatchSecondaryUpstreamAddrs + } +} + +// WithDispatchSecondaryUpstreamExprs returns an option that can append DispatchSecondaryUpstreamExprss to Config.DispatchSecondaryUpstreamExprs +func WithDispatchSecondaryUpstreamExprs(key string, value string) ConfigOption { + return func(c *Config) { + c.DispatchSecondaryUpstreamExprs[key] = value + } +} + +// SetDispatchSecondaryUpstreamExprs returns an option that can set DispatchSecondaryUpstreamExprs on a Config +func SetDispatchSecondaryUpstreamExprs(dispatchSecondaryUpstreamExprs map[string]string) ConfigOption { + return func(c *Config) { + c.DispatchSecondaryUpstreamExprs = dispatchSecondaryUpstreamExprs + } +} + // WithDispatchCacheConfig returns an option that can set DispatchCacheConfig on a Config func WithDispatchCacheConfig(dispatchCacheConfig CacheConfig) ConfigOption { return func(c *Config) {