From 36d5e415cea033ee370bb214f6431fa03c87714e Mon Sep 17 00:00:00 2001 From: Evan Cordell Date: Thu, 4 May 2023 13:27:01 -0400 Subject: [PATCH] balancer: rewrite the consistent hashring balancer to avoid recomputation
The previous implementation recomputed the hashring any time a subconnection moved between ready and idle, which happened frequently. The new implementation includes idle and connecting subconns in the ring, and triggers a connection attempt when one is picked. It also adds to and removes from a long-lived ring instead of recomputing the ring from scratch each time. ReplicationFactor and Spread can now be passed in as part of the service config instead of being registered globally with the balancer.
--- cmd/spicedb/main.go | 3 +- internal/testserver/cluster.go | 9 +- pkg/balancer/hashring.go | 323 +++++++++++++++++++++------ pkg/balancer/hashring_test.go | 392 +++++++++++++++++++++++++++++++++ pkg/cmd/server/server.go | 27 +-- 5 files changed, 657 insertions(+), 97 deletions(-) create mode 100644 pkg/balancer/hashring_test.go
diff --git a/cmd/spicedb/main.go b/cmd/spicedb/main.go index aa7e275339..c8517262d9 100644 --- a/cmd/spicedb/main.go +++ b/cmd/spicedb/main.go @@ -11,7 +11,6 @@ import ( _ "google.golang.org/grpc/xds" log "github.com/authzed/spicedb/internal/logging" - consistentbalancer "github.com/authzed/spicedb/pkg/balancer" "github.com/authzed/spicedb/pkg/cmd" cmdutil "github.com/authzed/spicedb/pkg/cmd/server" "github.com/authzed/spicedb/pkg/cmd/testserver" @@ -24,7 +23,7 @@ func main() { kuberesolver.RegisterInCluster() // Enable consistent hashring gRPC load balancer - balancer.Register(consistentbalancer.NewConsistentHashringBuilder(cmdutil.ConsistentHashringPicker)) + balancer.Register(cmdutil.ConsistentHashringBuilder) log.SetGlobalLogger(zerolog.New(os.Stdout))
diff --git a/internal/testserver/cluster.go b/internal/testserver/cluster.go index 3e993a76f1..724e367b90 100644 --- a/internal/testserver/cluster.go +++ b/internal/testserver/cluster.go @@ -59,9 +59,7 @@ var testResolverBuilder = &SafeManualResolverBuilder{} func init() { // register hashring balancer - balancer.Register(hashbalancer.NewConsistentHashringBuilder( - hashbalancer.NewConsistentHashringPickerBuilder(xxhash.Sum64, 1500, 1)), - ) + balancer.Register(hashbalancer.NewConsistentHashringBuilder(xxhash.Sum64)) // Register a manual resolver.Builder that we can feed addresses for tests // Registration is not thread safe, so we register a single resolver.Builder @@ -168,7 +166,10 @@ func TestClusterWithDispatchAndCacheConfig(t testing.TB, size uint, ds datastore combineddispatch.UpstreamAddr("test://" + prefix), combineddispatch.PrometheusSubsystem(fmt.Sprintf("%s_%d_client_dispatch", prefix, i)), combineddispatch.GrpcDialOpts( - grpc.WithDefaultServiceConfig(hashbalancer.BalancerServiceConfig), + grpc.WithDefaultServiceConfig((&hashbalancer.ConsistentHashringBalancerConfig{ + ReplicationFactor: 1500, + Spread: 1, + }).MustToServiceConfigJSON()), grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { // it's possible grpc tries to dial before we have set the // buffconn dialers, we have to return a "TempError" so that
diff --git a/pkg/balancer/hashring.go b/pkg/balancer/hashring.go index 0107a962ba..6b080c83fe 100644 --- a/pkg/balancer/hashring.go +++ b/pkg/balancer/hashring.go @@ -1,61 +1,81 @@ package balancer import ( + "encoding/json" + "errors" "fmt" "math/rand" "sync" "time" - "github.com/rs/zerolog" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" +
"google.golang.org/grpc/connectivity" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" "github.com/authzed/spicedb/pkg/consistent" ) +// This is based off of the example implementation in grpc-go: +// https://github.com/grpc/grpc-go/blob/afcbdc9ace7b4af94d014620727ea331cc3047fe/balancer/base/balancer.go +// The original work is copyright gRPC authors and licensed under the Apache License, Version 2.0. + type ctxKey string const ( // BalancerName is the name of consistent-hashring balancer. BalancerName = "consistent-hashring" - // BalancerServiceConfig is a service config that sets the default balancer - // to the consistent-hashring balancer - BalancerServiceConfig = `{"loadBalancingPolicy":"consistent-hashring"}` - // CtxKey is the key for the grpc request's context.Context which points to // the key to hash for the request. The value it points to must be []byte CtxKey ctxKey = "requestKey" ) +var defaultBalancerServiceConfig = &ConsistentHashringBalancerConfig{ + ReplicationFactor: 100, + Spread: 1, +} +var DefaultBalancerServiceConfigJSON = defaultBalancerServiceConfig.MustToServiceConfigJSON() + +type ConsistentHashringBalancerConfig struct { + serviceconfig.LoadBalancingConfig `json:"-"` + ReplicationFactor uint16 `json:"replicationFactor,omitempty"` + Spread uint8 `json:"spread,omitempty"` +} + +func (c *ConsistentHashringBalancerConfig) ToServiceConfigJSON() (string, error) { + type wrapper struct { + Config []map[string]*ConsistentHashringBalancerConfig `json:"loadBalancingConfig"` + } + out := wrapper{Config: []map[string]*ConsistentHashringBalancerConfig{{ + BalancerName: c, + }}} + j, err := json.Marshal(out) + if err != nil { + return "", err + } + return string(j), nil +} + +func (c *ConsistentHashringBalancerConfig) MustToServiceConfigJSON() string { + o, err := c.ToServiceConfigJSON() + if err != nil { + panic(err) + } + return o +} + var logger = grpclog.Component("consistenthashring") // NewConsistentHashringBuilder creates a new balancer.Builder that -// will create a consistent hashring balancer with the picker builder. +// will create a consistent hashring balancer. // Before making a connection, register it with grpc with: -// `balancer.Register(consistent.NewConsistentHashringBuilder(hasher, factor, spread))` -func NewConsistentHashringBuilder(pickerBuilder base.PickerBuilder) balancer.Builder { - return base.NewBalancerBuilder( - BalancerName, - pickerBuilder, - base.Config{HealthCheck: true}, - ) -} - -// NewConsistentHashringPickerBuilder creates a new picker builder -// that will create consistent hashrings according to the supplied -// config. If the ReplicationFactor is changed, that new parameter -// will be used when the next picker is created. 
-func NewConsistentHashringPickerBuilder( - hasher consistent.HasherFunc, - initialReplicationFactor uint16, - spread uint8, -) *ConsistentHashringPickerBuilder { - return &ConsistentHashringPickerBuilder{ - hasher: hasher, - replicationFactor: initialReplicationFactor, - spread: spread, +// `balancer.Register(consistent.NewConsistentHashringBuilder(hasher))` +func NewConsistentHashringBuilder(hasher consistent.HasherFunc) *ConsistentHashringBuilder { + return &ConsistentHashringBuilder{ + hasher: hasher, } } @@ -72,69 +92,223 @@ func (s subConnMember) Key() string { var _ consistent.Member = &subConnMember{} -// ConsistentHashringPickerBuilder is an implementation of base.PickerBuilder and -// is used to build pickers based on updates to the node architecture. -type ConsistentHashringPickerBuilder struct { +type ConsistentHashringBuilder struct { sync.Mutex + hasher consistent.HasherFunc + config ConsistentHashringBalancerConfig +} - hasher consistent.HasherFunc - replicationFactor uint16 - spread uint8 +func (b *ConsistentHashringBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + bal := &ConsistentHashringBalancer{ + cc: cc, + subConns: resolver.NewAddressMap(), + scStates: make(map[balancer.SubConn]connectivity.State), + csEvltr: &balancer.ConnectivityStateEvaluator{}, + state: connectivity.Connecting, + hasher: b.hasher, + } + // Initialize picker to a picker that always returns + // ErrNoSubConnAvailable, because when state of a SubConn changes, we + // may call UpdateState with this picker. + bal.picker = base.NewErrPicker(balancer.ErrNoSubConnAvailable) + return bal } -func (b *ConsistentHashringPickerBuilder) MarshalZerologObject(e *zerolog.Event) { - e.Uint16("consistent-hashring-replication-factor", b.replicationFactor) - e.Uint8("consistent-hashring-spread", b.spread) +func (b *ConsistentHashringBuilder) Name() string { + return BalancerName } -func (b *ConsistentHashringPickerBuilder) MustReplicationFactor(rf uint16) { - if rf == 0 { - panic("invalid ReplicationFactor") +func (b *ConsistentHashringBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var lbCfg ConsistentHashringBalancerConfig + if err := json.Unmarshal(js, &lbCfg); err != nil { + return nil, fmt.Errorf("wrr: unable to unmarshal LB policy config: %s, error: %w", string(js), err) + } + + logger.Infof("parsed balancer config %s", js) + + if lbCfg.ReplicationFactor == 0 { + lbCfg.ReplicationFactor = 100 + } + + if lbCfg.Spread == 0 { + lbCfg.Spread = 1 } b.Lock() - defer b.Unlock() - b.replicationFactor = rf + b.config = lbCfg + b.Unlock() + + return &lbCfg, nil +} + +type ConsistentHashringBalancer struct { + state connectivity.State + cc balancer.ClientConn + picker balancer.Picker + csEvltr *balancer.ConnectivityStateEvaluator + subConns *resolver.AddressMap + scStates map[balancer.SubConn]connectivity.State + + config *ConsistentHashringBalancerConfig + hashring *consistent.Hashring + hasher consistent.HasherFunc + + resolverErr error // the last error reported by the resolver; cleared on successful resolution + connErr error // the last connection error; cleared upon leaving TransientFailure } -func (b *ConsistentHashringPickerBuilder) MustSpread(spread uint8) { - if spread == 0 { - panic("invalid Spread") +func (b *ConsistentHashringBalancer) ResolverError(err error) { + b.resolverErr = err + if b.subConns.Len() == 0 { + b.state = connectivity.TransientFailure + b.picker = base.NewErrPicker(errors.Join(b.connErr, b.resolverErr)) } - b.Lock() 
- defer b.Unlock() - b.spread = spread + if b.state != connectivity.TransientFailure { + // The picker will not change since the balancer does not currently + // report an error. + return + } + b.cc.UpdateState(balancer.State{ + ConnectivityState: b.state, + Picker: b.picker, + }) }
-func (b *ConsistentHashringPickerBuilder) Build(info base.PickerBuildInfo) balancer.Picker { - logger.Infof("consistentHashringPicker: Build called with info: %v", info) - if len(info.ReadySCs) == 0 { - return base.NewErrPicker(balancer.ErrNoSubConnAvailable) +func (b *ConsistentHashringBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + if logger.V(2) { + logger.Info("got new ClientConn state: ", s) } + // Successful resolution; clear resolver error and ensure we return nil. + b.resolverErr = nil
- b.Lock() - hashring := consistent.MustNewHashring(b.hasher, b.replicationFactor) - b.Unlock() + if s.BalancerConfig != nil { + svcConfig := s.BalancerConfig.(*ConsistentHashringBalancerConfig) + if b.config == nil || svcConfig.ReplicationFactor != b.config.ReplicationFactor { + b.hashring = consistent.MustNewHashring(b.hasher, svcConfig.ReplicationFactor) + b.config = svcConfig + } + } + if b.hashring == nil { + b.picker = base.NewErrPicker(errors.Join(b.connErr, b.resolverErr)) + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + return fmt.Errorf("no hashring configured") + }
- for sc, scInfo := range info.ReadySCs { - if err := hashring.Add(subConnMember{ - SubConn: sc, - key: scInfo.Address.Addr + scInfo.Address.ServerName, - }); err != nil { - return base.NewErrPicker(err) + // addrsSet is the set converted from addrs, it's used for quick lookup of an address. + addrsSet := resolver.NewAddressMap() + for _, a := range s.ResolverState.Addresses { + addrsSet.Set(a, nil) + if _, ok := b.subConns.Get(a); !ok { + // a is a new address (not existing in b.subConns). + sc, err := b.cc.NewSubConn([]resolver.Address{a}, balancer.NewSubConnOptions{HealthCheckEnabled: false}) + if err != nil { + logger.Warningf("base.baseBalancer: failed to create new SubConn: %v", err) + continue + } + b.subConns.Set(a, sc) + b.scStates[sc] = connectivity.Idle + b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) + sc.Connect() + if err := b.hashring.Add(subConnMember{ + SubConn: sc, + key: a.ServerName + a.Addr, + }); err != nil { + return fmt.Errorf("couldn't add to hashring") + } } }
+ for _, a := range b.subConns.Keys() { + sci, _ := b.subConns.Get(a) + sc := sci.(balancer.SubConn) + // a was removed by resolver. + if _, ok := addrsSet.Get(a); !ok { + b.cc.RemoveSubConn(sc) + b.subConns.Delete(a) + // Keep the state of this sc in b.scStates until sc's state becomes Shutdown. + // The entry will be deleted in UpdateSubConnState.
- if b.spread == 0 { - return base.NewErrPicker(fmt.Errorf("received invalid spread for consistent hash ring picker builder: %d", b.spread)) + if err := b.hashring.Remove(subConnMember{ + SubConn: sc, + key: a.ServerName + a.Addr, + }); err != nil { + return fmt.Errorf("couldn't remove from hashring") + } + } + } + if logger.V(2) { + logger.Infof("%d hashring members found", len(b.hashring.Members())) + for _, m := range b.hashring.Members() { + logger.Infof("hashring member %s", m.Key()) + } }
- return &consistentHashringPicker{ - hashring: hashring, - spread: b.spread, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), + // If resolver state contains no addresses, return an error so ClientConn + // will trigger re-resolve.
Also records this as a resolver error, so when + // the overall state turns transient failure, the error message will have + // the zero address information. + if len(s.ResolverState.Addresses) == 0 { + b.ResolverError(errors.New("produced zero addresses")) + return balancer.ErrBadResolverState } + + if b.state == connectivity.TransientFailure { + b.picker = base.NewErrPicker(errors.Join(b.connErr, b.resolverErr)) + } else { + b.picker = &consistentHashringPicker{ + hashring: b.hashring, + spread: b.config.Spread, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } + } + + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) + return nil +} + +func (b *ConsistentHashringBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + s := state.ConnectivityState + if logger.V(2) { + logger.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s) + } + oldS, ok := b.scStates[sc] + if !ok { + if logger.V(2) { + logger.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) + } + return + } + if oldS == connectivity.TransientFailure && + (s == connectivity.Connecting || s == connectivity.Idle) { + // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or + // CONNECTING transitions to prevent the aggregated state from being + // always CONNECTING when many backends exist but are all down. + if s == connectivity.Idle { + sc.Connect() + } + return + } + b.scStates[sc] = s + switch s { + case connectivity.Idle: + sc.Connect() + case connectivity.Shutdown: + // When an address was removed by resolver, b called RemoveSubConn but + // kept the sc's state in scStates. Remove state for this sc here. + delete(b.scStates, sc) + case connectivity.TransientFailure: + // Save error to be reported via picker. + b.connErr = state.ConnectionError + } + + b.state = b.csEvltr.RecordTransition(oldS, s) + + b.cc.UpdateState(balancer.State{ConnectivityState: b.state, Picker: b.picker}) +} + +// Close is a nop because base balancer doesn't have internal state to clean up, +// and it doesn't need to call RemoveSubConn for the SubConns. 
+func (b *ConsistentHashringBalancer) Close() { } type consistentHashringPicker struct { @@ -151,15 +325,18 @@ func (p *consistentHashringPicker) Pick(info balancer.PickInfo) (balancer.PickRe return balancer.PickResult{}, err } - // rand is not safe for concurrent use - p.Lock() - index := p.rand.Intn(int(p.spread)) - p.Unlock() + index := 0 + + if p.spread > 1 { + // TODO: should look into other options for this to avoid locking; we mostly use spread 1 so it's not urgent + // rand is not safe for concurrent use + p.Lock() + index = p.rand.Intn(int(p.spread)) + p.Unlock() + } chosen := members[index].(subConnMember) return balancer.PickResult{ SubConn: chosen.SubConn, }, nil } - -var _ base.PickerBuilder = &ConsistentHashringPickerBuilder{} diff --git a/pkg/balancer/hashring_test.go b/pkg/balancer/hashring_test.go new file mode 100644 index 0000000000..e7e3c0ceca --- /dev/null +++ b/pkg/balancer/hashring_test.go @@ -0,0 +1,392 @@ +package balancer + +import ( + "context" + "errors" + "fmt" + "math/rand" + "reflect" + "sync" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/samber/lo" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + + "github.com/authzed/spicedb/pkg/consistent" +) + +type fakeSubConn struct { + balancer.SubConn + id string +} + +func (fakeSubConn) Connect() {} + +// Note: this is testing picker behavior and not the hashring +// behavior itself, see `pkg/consistent` for tests of the hashring. +func TestConsistentHashringPickerPick(t *testing.T) { + rnd := rand.New(rand.NewSource(1)) + tests := []struct { + name string + spread uint8 + rf uint16 + info balancer.PickInfo + want balancer.PickResult + wantErr bool + }{ + { + name: "pick one", + spread: 1, + rf: 100, + info: balancer.PickInfo{ + Ctx: context.WithValue(context.Background(), CtxKey, []byte("test")), + }, + want: balancer.PickResult{ + SubConn: &fakeSubConn{id: "1"}, + }, + }, + { + name: "pick another", + spread: 1, + rf: 100, + info: balancer.PickInfo{ + Ctx: context.WithValue(context.Background(), CtxKey, []byte("test2")), + }, + want: balancer.PickResult{ + SubConn: &fakeSubConn{id: "3"}, + }, + }, + { + name: "pick with spread", + spread: 2, + rf: 100, + info: balancer.PickInfo{ + Ctx: context.WithValue(context.Background(), CtxKey, []byte("test")), + }, + want: balancer.PickResult{ + // without spread, this would always be 1. 
+ // it can be 1 or 3 with spread 2, but pinning the seed makes it always 3 in the test + SubConn: &fakeSubConn{id: "3"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &consistentHashringPicker{ + hashring: consistent.MustNewHashring(xxhash.Sum64, tt.rf), + spread: tt.spread, + rand: rnd, + } + require.NoError(t, p.hashring.Add(subConnMember{key: "1", SubConn: &fakeSubConn{id: "1"}})) + require.NoError(t, p.hashring.Add(subConnMember{key: "2", SubConn: &fakeSubConn{id: "2"}})) + require.NoError(t, p.hashring.Add(subConnMember{key: "3", SubConn: &fakeSubConn{id: "3"}})) + got, err := p.Pick(tt.info) + if (err != nil) != tt.wantErr { + t.Errorf("Pick() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Pick() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConsistentHashringBalancerConfigToServiceConfigJSON(t *testing.T) { + tests := []struct { + name string + replicationFactor uint16 + spread uint8 + want string + wantErr bool + }{ + { + name: "sets rf and spread", + replicationFactor: 300, + spread: 2, + want: `{"loadBalancingConfig":[{"consistent-hashring":{"replicationFactor":300,"spread":2}}]}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ConsistentHashringBalancerConfig{ + ReplicationFactor: tt.replicationFactor, + Spread: tt.spread, + } + got, err := c.ToServiceConfigJSON() + if (err != nil) != tt.wantErr { + t.Errorf("ToServiceConfigJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToServiceConfigJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConsistentHashringBalancer_UpdateClientConnState(t *testing.T) { + type balancerState struct { + ConnectivityState connectivity.State + err error + members []string + spread uint8 + replicationFactor uint16 + } + tests := []struct { + name string + s []balancer.ClientConnState + expectedStates []balancerState + expectedConnState connectivity.State + wantErr bool + }{ + { + name: "no hashring", + expectedStates: []balancerState{}, + expectedConnState: connectivity.TransientFailure, + wantErr: true, + }, + { + name: "configures hashring, no addresses", + s: []balancer.ClientConnState{{ + ResolverState: resolver.State{}, + BalancerConfig: &ConsistentHashringBalancerConfig{ + ReplicationFactor: 100, + Spread: 1, + }, + }}, + expectedStates: []balancerState{ + { + ConnectivityState: connectivity.TransientFailure, + err: errors.Join(nil, fmt.Errorf("produced zero addresses")), + }, + }, + expectedConnState: connectivity.TransientFailure, + wantErr: true, + }, + { + name: "configures hashring, 3 addresses", + s: []balancer.ClientConnState{{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "3"}, + }, + }, + BalancerConfig: &ConsistentHashringBalancerConfig{ + ReplicationFactor: 100, + Spread: 1, + }, + }}, + expectedStates: []balancerState{ + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t3"}, + replicationFactor: 100, + spread: 1, + }, + }, + expectedConnState: connectivity.Idle, + }, + { + name: "existing hashring with 3 nodes, 1 removed", + s: []balancer.ClientConnState{{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "3"}, + }, + }, + BalancerConfig: &ConsistentHashringBalancerConfig{ + 
ReplicationFactor: 100, + Spread: 1, + }, + }, { + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + }, + }, + }}, + expectedStates: []balancerState{ + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t3"}, + replicationFactor: 100, + spread: 1, + }, + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2"}, + replicationFactor: 100, + spread: 1, + }, + }, + expectedConnState: connectivity.Idle, + }, + { + name: "existing hashring with 3 nodes, 1 added", + s: []balancer.ClientConnState{{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "3"}, + }, + }, + BalancerConfig: &ConsistentHashringBalancerConfig{ + ReplicationFactor: 100, + Spread: 1, + }, + }, { + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "3"}, + {ServerName: "t", Addr: "4"}, + }, + }, + }}, + expectedStates: []balancerState{ + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t3"}, + replicationFactor: 100, + spread: 1, + }, + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t3", "t4"}, + replicationFactor: 100, + spread: 1, + }, + }, + expectedConnState: connectivity.Idle, + }, + { + name: "existing hashring with 3 nodes, 1 replaced", + s: []balancer.ClientConnState{{ + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "3"}, + }, + }, + BalancerConfig: &ConsistentHashringBalancerConfig{ + ReplicationFactor: 100, + Spread: 1, + }, + }, { + ResolverState: resolver.State{ + Addresses: []resolver.Address{ + {ServerName: "t", Addr: "1"}, + {ServerName: "t", Addr: "2"}, + {ServerName: "t", Addr: "4"}, + }, + }, + }}, + expectedStates: []balancerState{ + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t3"}, + replicationFactor: 100, + spread: 1, + }, + { + ConnectivityState: connectivity.Connecting, + members: []string{"t1", "t2", "t4"}, + replicationFactor: 100, + spread: 1, + }, + }, + expectedConnState: connectivity.Idle, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := NewConsistentHashringBuilder(xxhash.Sum64) + cc := newFakeClientConn() + bb := b.Build(cc, balancer.BuildOptions{}) + cb := bb.(*ConsistentHashringBalancer) + + tt := tt + + done := make(chan struct{}) + go func() { + i := 0 + if len(tt.expectedStates) == 0 { + done <- struct{}{} + return + } + for { + s := <-cc.stateCh + expected := tt.expectedStates[i] + require.Equal(t, expected.ConnectivityState, s.ConnectivityState) + if expected.err != nil { + require.Equal(t, base.NewErrPicker(expected.err), s.Picker) + } else { + p := s.Picker.(*consistentHashringPicker) + require.Equal(t, expected.spread, p.spread) + require.ElementsMatch(t, expected.members, lo.Map(p.hashring.Members(), func(m consistent.Member, index int) string { + return m.Key() + })) + } + i++ + done <- struct{}{} + } + }() + + for _, state := range tt.s { + if err := cb.UpdateClientConnState(state); (err != nil) != tt.wantErr { + t.Errorf("UpdateClientConnState() error = %v, wantErr %v", err, tt.wantErr) + } + <-done + } + require.Equal(t, tt.expectedConnState, cb.csEvltr.CurrentState()) + }) + } +} + +type fakeClientConn 
struct { + balancer.ClientConn + + stateCh chan balancer.State + + mu sync.Mutex + subConns map[balancer.SubConn]resolver.Address +} + +func newFakeClientConn() *fakeClientConn { + return &fakeClientConn{ + subConns: make(map[balancer.SubConn]resolver.Address), + stateCh: make(chan balancer.State), + } +} + +func (c *fakeClientConn) NewSubConn(addrs []resolver.Address, _ balancer.NewSubConnOptions) (balancer.SubConn, error) { + sc := &fakeSubConn{} + c.mu.Lock() + defer c.mu.Unlock() + c.subConns[sc] = addrs[0] + return sc, nil +} + +func (c *fakeClientConn) RemoveSubConn(sc balancer.SubConn) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.subConns, sc) +} + +func (c *fakeClientConn) UpdateState(s balancer.State) { + c.stateCh <- s +} diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 232d9bc9c6..d21e16a41d 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -45,15 +45,8 @@ import ( "github.com/authzed/spicedb/pkg/datastore" ) -const ( - hashringReplicationFactor = 100 - backendsPerKey = 1 -) - -var ConsistentHashringPicker = balancer.NewConsistentHashringPickerBuilder( +var ConsistentHashringBuilder = balancer.NewConsistentHashringBuilder( xxhash.Sum64, - hashringReplicationFactor, - backendsPerKey, ) //go:generate go run github.com/ecordell/optgen -output zz_generated.options.go . Config @@ -246,14 +239,12 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { specificConcurrencyLimits := c.DispatchConcurrencyLimits concurrencyLimits := specificConcurrencyLimits.WithOverallDefaultLimit(c.GlobalDispatchConcurrencyLimit) - // Set the hashring values to take effect the next time the replicas are updated - // Applies to ALL running servers. - if c.DispatchHashringReplicationFactor > 0 { - ConsistentHashringPicker.MustReplicationFactor(c.DispatchHashringReplicationFactor) - } - - if c.DispatchHashringSpread > 0 { - ConsistentHashringPicker.MustSpread(c.DispatchHashringSpread) + hashringConfigJSON, err := (&balancer.ConsistentHashringBalancerConfig{ + ReplicationFactor: c.DispatchHashringReplicationFactor, + Spread: c.DispatchHashringSpread, + }).ToServiceConfigJSON() + if err != nil { + return nil, fmt.Errorf("failed to create hashring balancer config: %w", err) } dispatcher, err = combineddispatch.NewDispatcher( @@ -262,7 +253,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { combineddispatch.GrpcPresharedKey(dispatchPresharedKey), combineddispatch.GrpcDialOpts( grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor()), - grpc.WithDefaultServiceConfig(balancer.BalancerServiceConfig), + grpc.WithDefaultServiceConfig(hashringConfigJSON), ), combineddispatch.MetricsEnabled(c.DispatchClientMetricsEnabled), combineddispatch.PrometheusSubsystem(c.DispatchClientMetricsPrefix), @@ -272,7 +263,7 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { if err != nil { return nil, fmt.Errorf("failed to create dispatcher: %w", err) } - log.Ctx(ctx).Info().EmbedObject(concurrencyLimits).EmbedObject(ConsistentHashringPicker).Msg("configured dispatcher") + log.Ctx(ctx).Info().EmbedObject(concurrencyLimits).RawJSON("balancerconfig", []byte(hashringConfigJSON)).Msg("configured dispatcher") } closeables.AddWithError(dispatcher.Close)
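Illustrative end-to-end sketch, not part of the diff: how a dispatch client is expected to wire up the rewritten balancer. Registration mirrors cmd/spicedb/main.go and internal/testserver/cluster.go, the service config mirrors pkg/cmd/server/server.go, and the per-request key contract comes from the CtxKey comment in pkg/balancer/hashring.go; the target address, credentials, and request key value are assumptions made for the example.

package main

import (
	"context"

	"github.com/cespare/xxhash/v2"
	"google.golang.org/grpc"
	grpcbalancer "google.golang.org/grpc/balancer"
	"google.golang.org/grpc/credentials/insecure"

	consistentbalancer "github.com/authzed/spicedb/pkg/balancer"
)

func main() {
	// Register the rewritten balancer once at startup, as main.go and the
	// test cluster do in this patch.
	grpcbalancer.Register(consistentbalancer.NewConsistentHashringBuilder(xxhash.Sum64))

	// ReplicationFactor and Spread now travel in the per-connection service
	// config rather than being registered globally with a picker builder.
	serviceConfig := (&consistentbalancer.ConsistentHashringBalancerConfig{
		ReplicationFactor: 100,
		Spread:            1,
	}).MustToServiceConfigJSON()

	// "dispatch.example:50053" and the insecure credentials are illustrative,
	// not taken from this patch.
	conn, err := grpc.Dial("dispatch.example:50053",
		grpc.WithDefaultServiceConfig(serviceConfig),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// The picker hashes the []byte stored under CtxKey, so callers set the
	// routing key on the context before each RPC made on conn.
	ctx := context.WithValue(context.Background(), consistentbalancer.CtxKey, []byte("some-request-key"))
	_ = ctx // pass ctx to RPCs issued on conn
}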