Skip to content

Commit

Permalink
balancer: add StateListener to NewSubConnOptions for SubConn state up…
Browse files Browse the repository at this point in the history
…dates (#6481)
  • Loading branch information
dfawley authored Jul 31, 2023
1 parent 94df716 commit c635404
Show file tree
Hide file tree
Showing 19 changed files with 303 additions and 237 deletions.
8 changes: 8 additions & 0 deletions balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ type NewSubConnOptions struct {
// HealthCheckEnabled indicates whether health check service should be
// enabled on this SubConn
HealthCheckEnabled bool
// StateListener is called when the state of the subconn changes. If nil,
// Balancer.UpdateSubConnState will be called instead. Will never be
// invoked until after Connect() is called on the SubConn created with
// these options.
StateListener func(SubConnState)
}

// State contains the balancer's state relevant to the gRPC ClientConn.
Expand Down Expand Up @@ -349,6 +354,9 @@ type Balancer interface {
ResolverError(error)
// UpdateSubConnState is called by gRPC when the state of a SubConn
// changes.
//
// Deprecated: Use NewSubConnOptions.StateListener when creating the
// SubConn instead.
UpdateSubConnState(SubConn, SubConnState)
// Close closes the balancer. The balancer is not required to call
// ClientConn.RemoveSubConn for its existing SubConns.
Expand Down
182 changes: 91 additions & 91 deletions balancer/weightedtarget/weightedtarget_test.go

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions balancer_conn_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnStat
// updateSubConnState schedules, on ccb.serializer, the delivery of a SubConn
// state change built from connectivity state s and connection error err. The
// Schedule call is made while ccb.mu is held; the delivery itself runs inside
// the scheduled callback.
//
// NOTE(review): this rendered diff shows the Balancer.UpdateSubConnState call
// and the stateListener call together without +/- markers; presumably only the
// stateListener call remains in the committed version — confirm against the
// repository.
func (ccb *ccBalancerWrapper) updateSubConnState(sc balancer.SubConn, s connectivity.State, err error) {
ccb.mu.Lock()
ccb.serializer.Schedule(func(_ context.Context) {
ccb.balancer.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
//
// The type assertion is unchecked: it panics if sc is not an
// *acBalancerWrapper.
sc.(*acBalancerWrapper).stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
})
ccb.mu.Unlock()
}
Expand Down Expand Up @@ -300,7 +302,11 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
channelz.Warningf(logger, ccb.cc.channelzID, "acBalancerWrapper: NewSubConn: failed to newAddrConn: %v", err)
return nil, err
}
acbw := &acBalancerWrapper{ac: ac, producers: make(map[balancer.ProducerBuilder]*refCountedProducer)}
acbw := &acBalancerWrapper{
ac: ac,
producers: make(map[balancer.ProducerBuilder]*refCountedProducer),
stateListener: opts.StateListener,
}
ac.acbw = acbw
return acbw, nil
}
Expand Down Expand Up @@ -366,7 +372,8 @@ func (ccb *ccBalancerWrapper) Target() string {
// acBalancerWrapper is a wrapper on top of ac for balancers.
// It implements balancer.SubConn interface.
type acBalancerWrapper struct {
ac *addrConn // read-only
ac *addrConn // read-only
stateListener func(balancer.SubConnState)

mu sync.Mutex
producers map[balancer.ProducerBuilder]*refCountedProducer
Expand Down
36 changes: 20 additions & 16 deletions internal/balancer/gracefulswitch/gracefulswitch.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ func (gsb *Balancer) ExitIdle() {
}
}

// UpdateSubConnState forwards the update to the appropriate child.
func (gsb *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
// updateSubConnState forwards the update to the appropriate child: it invokes
// cb if cb is non-nil, and the child's UpdateSubConnState method otherwise.
func (gsb *Balancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState, cb func(balancer.SubConnState)) {
gsb.currentMu.Lock()
defer gsb.currentMu.Unlock()
gsb.mu.Lock()
Expand All @@ -214,13 +214,26 @@ func (gsb *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubC
} else if gsb.balancerPending != nil && gsb.balancerPending.subconns[sc] {
balToUpdate = gsb.balancerPending
}
gsb.mu.Unlock()
if balToUpdate == nil {
// SubConn belonged to a stale lb policy that has not yet fully closed,
// or the balancer was already closed.
gsb.mu.Unlock()
return
}
balToUpdate.UpdateSubConnState(sc, state)
if state.ConnectivityState == connectivity.Shutdown {
delete(balToUpdate.subconns, sc)
}
gsb.mu.Unlock()
if cb != nil {
cb(state)
} else {
balToUpdate.UpdateSubConnState(sc, state)
}
}

// UpdateSubConnState forwards the update to the appropriate child.
//
// It delegates to updateSubConnState with a nil callback, so the update is
// delivered through the child's UpdateSubConnState method rather than through
// a listener callback.
func (gsb *Balancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
gsb.updateSubConnState(sc, state, nil)
}

// Close closes any active child balancers.
Expand Down Expand Up @@ -254,18 +267,6 @@ type balancerWrapper struct {
subconns map[balancer.SubConn]bool // subconns created by this balancer
}

// UpdateSubConnState forwards a SubConn state update to the wrapped Balancer.
// When the new state is connectivity.Shutdown, sc is first removed from
// bw.subconns while holding bw.gsb.mu.
//
// NOTE(review): in this commit view the same Shutdown bookkeeping also appears
// inside gsb.updateSubConnState; this method looks like the removed side of
// the diff — confirm against the repository.
func (bw *balancerWrapper) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
if state.ConnectivityState == connectivity.Shutdown {
bw.gsb.mu.Lock()
delete(bw.subconns, sc)
bw.gsb.mu.Unlock()
}
// There is no need to protect this read with a mutex, as the write to the
// Balancer field happens in SwitchTo, which completes before this can be
// called.
bw.Balancer.UpdateSubConnState(sc, state)
}

// Close closes the underlying LB policy and removes the subconns it created. bw
// must not be referenced via balancerCurrent or balancerPending in gsb when
// called. gsb.mu must not be held. Does not panic with a nil receiver.
Expand Down Expand Up @@ -335,6 +336,9 @@ func (bw *balancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.Ne
}
bw.gsb.mu.Unlock()

var sc balancer.SubConn
oldListener := opts.StateListener
opts.StateListener = func(state balancer.SubConnState) { bw.gsb.updateSubConnState(sc, state, oldListener) }
sc, err := bw.gsb.cc.NewSubConn(addrs, opts)
if err != nil {
return nil, err
Expand Down
98 changes: 31 additions & 67 deletions internal/balancer/gracefulswitch/gracefulswitch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,7 @@ func (s) TestCurrentLeavingReady(t *testing.T) {

// TestBalancerSubconns tests the SubConn functionality of the graceful switch
// load balancer. This tests the SubConn update flow in both directions, and
// makes sure updates end up at the correct component. Also, it tests that on an
// UpdateSubConnState() call from the ClientConn, the graceful switch load
// balancer forwards it to the correct child balancer.
// makes sure updates end up at the correct component.
func (s) TestBalancerSubconns(t *testing.T) {
tcc, gsb := setup(t)
gsb.SwitchTo(mockBalancerBuilder1{})
Expand All @@ -365,7 +363,7 @@ func (s) TestBalancerSubconns(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an NewSubConn call on the ClientConn")
case sc := <-tcc.NewSubConnCh:
if !cmp.Equal(sc1, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if sc != sc1 {
t.Fatalf("NewSubConn, want %v, got %v", sc1, sc)
}
}
Expand All @@ -380,47 +378,20 @@ func (s) TestBalancerSubconns(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an NewSubConn call on the ClientConn")
case sc := <-tcc.NewSubConnCh:
if !cmp.Equal(sc2, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if sc != sc2 {
t.Fatalf("NewSubConn, want %v, got %v", sc2, sc)
}
}
scState := balancer.SubConnState{ConnectivityState: connectivity.Ready}
// Updating the SubConnState for sc1 should cause the graceful switch
// balancer to forward the Update to balancerCurrent for sc1, as that is the
// balancer that created this SubConn.
gsb.UpdateSubConnState(sc1, scState)

// This update should get forwarded to balancerCurrent, as that is the LB
// that created this SubConn.
if err := gsb.balancerCurrent.Balancer.(*mockBalancer).waitForSubConnUpdate(ctx, subConnWithState{sc: sc1, state: scState}); err != nil {
t.Fatal(err)
}
// This update should not get forwarded to balancerPending, as that is not
// the LB that created this SubConn.
sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer sCancel()
if err := gsb.balancerPending.Balancer.(*mockBalancer).waitForSubConnUpdate(sCtx, subConnWithState{sc: sc1, state: scState}); err == nil {
t.Fatalf("balancerPending should not have received a subconn update for sc1")
}
sc1.(*testutils.TestSubConn).UpdateState(scState)

// Updating the SubConnState for sc2 should cause the graceful switch
// balancer to forward the Update to balancerPending for sc2, as that is the
// balancer that created this SubConn.
gsb.UpdateSubConnState(sc2, scState)

// This update should get forwarded to balancerPending, as that is the LB
// that created this SubConn.
if err := gsb.balancerPending.Balancer.(*mockBalancer).waitForSubConnUpdate(ctx, subConnWithState{sc: sc2, state: scState}); err != nil {
t.Fatal(err)
}

// This update should not get forwarded to balancerCurrent, as that is not
// the LB that created this SubConn.
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer sCancel()
if err := gsb.balancerCurrent.Balancer.(*mockBalancer).waitForSubConnUpdate(sCtx, subConnWithState{sc: sc2, state: scState}); err == nil {
t.Fatalf("balancerCurrent should not have received a subconn update for sc2")
}
sc2.(*testutils.TestSubConn).UpdateState(scState)

// Updating the addresses for both SubConns and removing both SubConns
// should get forwarded to the ClientConn.
Expand Down Expand Up @@ -448,7 +419,7 @@ func (s) TestBalancerSubconns(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an UpdateAddresses call on the ClientConn")
case sc := <-tcc.RemoveSubConnCh:
if !cmp.Equal(sc1, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if sc != sc1 {
t.Fatalf("RemoveSubConn, want %v, got %v", sc1, sc)
}
}
Expand All @@ -458,7 +429,7 @@ func (s) TestBalancerSubconns(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an UpdateAddresses call on the ClientConn")
case sc := <-tcc.RemoveSubConnCh:
if !cmp.Equal(sc2, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if sc != sc2 {
t.Fatalf("RemoveSubConn, want %v, got %v", sc2, sc)
}
}
Expand All @@ -476,7 +447,8 @@ func (s) TestBalancerClose(t *testing.T) {
gsb.SwitchTo(mockBalancerBuilder1{})
gsb.SwitchTo(mockBalancerBuilder2{})

sc1, err := gsb.balancerCurrent.Balancer.(*mockBalancer).newSubConn([]resolver.Address{}, balancer.NewSubConnOptions{}) // Will eventually get back a SubConn with an identifying property id 1
sc1, err := gsb.balancerCurrent.Balancer.(*mockBalancer).newSubConn([]resolver.Address{}, balancer.NewSubConnOptions{})
// Will eventually get back a SubConn with an identifying property id 1
if err != nil {
t.Fatalf("error constructing newSubConn in gsb: %v", err)
}
Expand All @@ -488,7 +460,8 @@ func (s) TestBalancerClose(t *testing.T) {
case <-tcc.NewSubConnCh:
}

sc2, err := gsb.balancerPending.Balancer.(*mockBalancer).newSubConn([]resolver.Address{}, balancer.NewSubConnOptions{}) // Will eventually get back a SubConn with an identifying property id 2
sc2, err := gsb.balancerPending.Balancer.(*mockBalancer).newSubConn([]resolver.Address{}, balancer.NewSubConnOptions{})
// Will eventually get back a SubConn with an identifying property id 2
if err != nil {
t.Fatalf("error constructing newSubConn in gsb: %v", err)
}
Expand All @@ -512,10 +485,8 @@ func (s) TestBalancerClose(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an UpdateAddresses call on the ClientConn")
case sc := <-tcc.RemoveSubConnCh:
if !cmp.Equal(sc1, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if !cmp.Equal(sc2, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
t.Fatalf("RemoveSubConn, want either %v or %v, got %v", sc1, sc2, sc)
}
if sc != sc1 && sc != sc2 {
t.Fatalf("RemoveSubConn, want either %v or %v, got %v", sc1, sc2, sc)
}
}

Expand All @@ -525,10 +496,8 @@ func (s) TestBalancerClose(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an UpdateAddresses call on the ClientConn")
case sc := <-tcc.RemoveSubConnCh:
if !cmp.Equal(sc1, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if !cmp.Equal(sc2, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
t.Fatalf("RemoveSubConn, want either %v or %v, got %v", sc1, sc2, sc)
}
if sc != sc1 && sc != sc2 {
t.Fatalf("RemoveSubConn, want either %v or %v, got %v", sc1, sc2, sc)
}
}

Expand Down Expand Up @@ -654,7 +623,7 @@ func (s) TestPendingReplacedByAnotherPending(t *testing.T) {
case <-ctx.Done():
t.Fatalf("timeout while waiting for a RemoveSubConn call on the ClientConn")
case sc := <-tcc.RemoveSubConnCh:
if !cmp.Equal(sc1, sc, cmp.AllowUnexported(testutils.TestSubConn{})) {
if sc != sc1 {
t.Fatalf("RemoveSubConn, want %v, got %v", sc1, sc)
}
}
Expand Down Expand Up @@ -735,7 +704,7 @@ func (s) TestUpdateSubConnStateRace(t *testing.T) {
return
default:
}
gsb.UpdateSubConnState(sc, balancer.SubConnState{
sc.(*testutils.TestSubConn).UpdateState(balancer.SubConnState{
ConnectivityState: connectivity.Ready,
})
}
Expand Down Expand Up @@ -771,7 +740,7 @@ func (s) TestInlineCallbackInBuild(t *testing.T) {
}
select {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an NewSubConn() call on the ClientConn")
t.Fatalf("timeout while waiting for a NewSubConn() call on the ClientConn")
case <-tcc.NewSubConnCh:
}
select {
Expand All @@ -796,7 +765,7 @@ func (s) TestInlineCallbackInBuild(t *testing.T) {
}
select {
case <-ctx.Done():
t.Fatalf("timeout while waiting for an NewSubConn() call on the ClientConn")
t.Fatalf("timeout while waiting for a NewSubConn() call on the ClientConn")
case <-tcc.NewSubConnCh:
}
select {
Expand Down Expand Up @@ -945,20 +914,6 @@ func (mb1 *mockBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS ba
return nil
}

// waitForSubConnUpdate blocks until the mockBalancer reports a SubConn state
// update or ctx expires. It returns nil only when the reported update equals
// wantSCS; a receive failure or a mismatching update is returned as an error.
func (mb1 *mockBalancer) waitForSubConnUpdate(ctx context.Context, wantSCS subConnWithState) error {
	received, recvErr := mb1.scStateCh.Receive(ctx)
	if recvErr != nil {
		return fmt.Errorf("error waiting for SubConnUpdate: %v", recvErr)
	}
	gotSCS := received.(subConnWithState)
	// AllowUnexported lets cmp look at the unexported fields of these types.
	allowUnexported := cmp.AllowUnexported(subConnWithState{}, testutils.TestSubConn{})
	if cmp.Equal(gotSCS, wantSCS, allowUnexported) {
		return nil
	}
	return fmt.Errorf("error in SubConnUpdate: received SubConnState: %+v, want %+v", gotSCS, wantSCS)
}

// waitForResolverError verifies if the mockBalancer receives the provided
// resolver error before the context expires.
func (mb1 *mockBalancer) waitForResolverError(ctx context.Context, wantErr error) error {
Expand Down Expand Up @@ -994,7 +949,10 @@ func (mb1 *mockBalancer) updateState(state balancer.State) {
mb1.cc.UpdateState(state)
}

func (mb1 *mockBalancer) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
func (mb1 *mockBalancer) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (sc balancer.SubConn, err error) {
if opts.StateListener == nil {
opts.StateListener = func(state balancer.SubConnState) { mb1.UpdateSubConnState(sc, state) }
}
return mb1.cc.NewSubConn(addrs, opts)
}

Expand Down Expand Up @@ -1061,7 +1019,10 @@ func (vb *verifyBalancer) Close() {
vb.closed.Fire()
}

func (vb *verifyBalancer) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
func (vb *verifyBalancer) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (sc balancer.SubConn, err error) {
if opts.StateListener == nil {
opts.StateListener = func(state balancer.SubConnState) { vb.UpdateSubConnState(sc, state) }
}
return vb.cc.NewSubConn(addrs, opts)
}

Expand Down Expand Up @@ -1111,7 +1072,10 @@ func (bcb *buildCallbackBal) updateState(state balancer.State) {
bcb.cc.UpdateState(state)
}

func (bcb *buildCallbackBal) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
func (bcb *buildCallbackBal) newSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (sc balancer.SubConn, err error) {
if opts.StateListener == nil {
opts.StateListener = func(state balancer.SubConnState) { bcb.UpdateSubConnState(sc, state) }
}
return bcb.cc.NewSubConn(addrs, opts)
}

Expand Down
21 changes: 17 additions & 4 deletions internal/balancergroup/balancergroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,9 @@ func (bg *BalancerGroup) connect(sb *subBalancerWrapper) {

// Following are actions from the parent grpc.ClientConn, forward to sub-balancers.

// UpdateSubConnState handles the state for the subconn. It finds the
// corresponding balancer and forwards the update.
func (bg *BalancerGroup) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
// updateSubConnState handles the state for the subconn. It finds the
// corresponding balancer and forwards the update to cb, or to that balancer's
// updateSubConnState method if cb is nil.
func (bg *BalancerGroup) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState, cb func(balancer.SubConnState)) {
bg.incomingMu.Lock()
config, ok := bg.scToSubBalancer[sc]
if !ok {
Expand All @@ -465,10 +465,20 @@ func (bg *BalancerGroup) UpdateSubConnState(sc balancer.SubConn, state balancer.
bg.incomingMu.Unlock()

bg.outgoingMu.Lock()
config.updateSubConnState(sc, state)
if cb != nil {
cb(state)
} else {
config.updateSubConnState(sc, state)
}
bg.outgoingMu.Unlock()
}

// UpdateSubConnState handles the state for the subconn. It finds the
// corresponding balancer and forwards the update.
//
// It delegates to updateSubConnState with a nil callback, so the update is
// delivered through the sub-balancer's updateSubConnState method rather than
// through a listener callback.
func (bg *BalancerGroup) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
bg.updateSubConnState(sc, state, nil)
}

// UpdateClientConnState handles ClientState (including balancer config and
// addresses) from resolver. It finds the balancer and forwards the update.
func (bg *BalancerGroup) UpdateClientConnState(id string, s balancer.ClientConnState) error {
Expand Down Expand Up @@ -507,6 +517,9 @@ func (bg *BalancerGroup) newSubConn(config *subBalancerWrapper, addrs []resolver
bg.incomingMu.Unlock()
return nil, fmt.Errorf("NewSubConn is called after balancer group is closed")
}
var sc balancer.SubConn
oldListener := opts.StateListener
opts.StateListener = func(state balancer.SubConnState) { bg.updateSubConnState(sc, state, oldListener) }
sc, err := bg.cc.NewSubConn(addrs, opts)
if err != nil {
bg.incomingMu.Unlock()
Expand Down
Loading

0 comments on commit c635404

Please sign in to comment.