From ca4865d6dd6f3d8b77f1943ccfd6c9e78223912d Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 30 Sep 2024 08:42:42 -0700 Subject: [PATCH] balancer: automatically stop producers on subchannel state changes (#7663) --- balancer/balancer.go | 13 +- balancer/weightedroundrobin/balancer.go | 22 ++-- balancer_wrapper.go | 49 +++++--- clientconn.go | 53 ++------- interop/orcalb.go | 11 +- orca/producer.go | 19 +-- orca/producer_test.go | 35 +++--- producer_ext_test.go | 101 +++++++--------- test/balancer_test.go | 150 ------------------------ 9 files changed, 133 insertions(+), 320 deletions(-) diff --git a/balancer/balancer.go b/balancer/balancer.go index 8d125d2aa207..3a2092f1056e 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -142,8 +142,11 @@ type SubConn interface { Connect() // GetOrBuildProducer returns a reference to the existing Producer for this // ProducerBuilder in this SubConn, or, if one does not currently exist, - // creates a new one and returns it. Returns a close function which must - // be called when the Producer is no longer needed. + // creates a new one and returns it. Returns a close function which may be + // called when the Producer is no longer needed. Otherwise the producer + // will automatically be closed upon connection loss or subchannel close. + // Should only be called on a SubConn in state Ready. Otherwise the + // producer will be unable to create streams. GetOrBuildProducer(ProducerBuilder) (p Producer, close func()) // Shutdown shuts down the SubConn gracefully. Any started RPCs will be // allowed to complete. No future calls should be made on the SubConn. @@ -452,8 +455,10 @@ type ProducerBuilder interface { // Build creates a Producer. The first parameter is always a // grpc.ClientConnInterface (a type to allow creating RPCs/streams on the // associated SubConn), but is declared as `any` to avoid a dependency - // cycle. Should also return a close function that will be called when all - // references to the Producer have been given up. + // cycle. Build also returns a close function that will be called when all + // references to the Producer have been given up for a SubConn, or when a + // connectivity state change occurs on the SubConn. The close function + // should always block until all asynchronous cleanup work is completed. Build(grpcClientConnInterface any) (p Producer, close func()) } diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go index 88bf64ec4ec4..1ea9eba4c894 100644 --- a/balancer/weightedroundrobin/balancer.go +++ b/balancer/weightedroundrobin/balancer.go @@ -526,17 +526,21 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) { w.cfg = cfg w.mu.Unlock() - newPeriod := cfg.OOBReportingPeriod if cfg.EnableOOBLoadReport == oldCfg.EnableOOBLoadReport && - newPeriod == oldCfg.OOBReportingPeriod { + cfg.OOBReportingPeriod == oldCfg.OOBReportingPeriod { // Load reporting wasn't enabled before or after, or load reporting was // enabled before and after, and had the same period. (Note that with // load reporting disabled, OOBReportingPeriod is always 0.) return } - // (Optionally stop and) start the listener to use the new config's - // settings for OOB reporting. + if w.connectivityState == connectivity.Ready { + // (Re)start the listener to use the new config's settings for OOB + // reporting. + w.updateORCAListener(cfg) + } +} +func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) { if w.stopORCAListener != nil { w.stopORCAListener() } @@ -545,9 +549,9 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) { return } if w.logger.V(2) { - w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, newPeriod) + w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, cfg.OOBReportingPeriod) } - opts := orca.OOBListenerOptions{ReportInterval: time.Duration(newPeriod)} + opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)} w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts) } @@ -569,11 +573,9 @@ func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connect w.mu.Lock() w.nonEmptySince = time.Time{} w.lastUpdated = time.Time{} + cfg := w.cfg w.mu.Unlock() - case connectivity.Shutdown: - if w.stopORCAListener != nil { - w.stopORCAListener() - } + w.updateORCAListener(cfg) } oldCS := w.connectivityState diff --git a/balancer_wrapper.go b/balancer_wrapper.go index efdbe7cf4fae..2a4f2878aef4 100644 --- a/balancer_wrapper.go +++ b/balancer_wrapper.go @@ -24,12 +24,14 @@ import ( "sync" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/balancer/gracefulswitch" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/resolver" + "google.golang.org/grpc/status" ) var setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address)) @@ -256,17 +258,20 @@ type acBalancerWrapper struct { ccb *ccBalancerWrapper // read-only stateListener func(balancer.SubConnState) - mu sync.Mutex - producers map[balancer.ProducerBuilder]*refCountedProducer + producersMu sync.Mutex + producers map[balancer.ProducerBuilder]*refCountedProducer } // updateState is invoked by grpc to push a subConn state update to the // underlying balancer. -func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error, readyChan chan struct{}) { +func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) { acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || acbw.ccb.balancer == nil { return } + // Invalidate all producers on any state change. + acbw.closeProducers() + // Even though it is optional for balancers, gracefulswitch ensures // opts.StateListener is set, so this cannot ever be nil. // TODO: delete this comment when UpdateSubConnState is removed. @@ -275,15 +280,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve setConnectedAddress(&scs, curAddr) } acbw.stateListener(scs) - acbw.ac.mu.Lock() - defer acbw.ac.mu.Unlock() - if s == connectivity.Ready { - // When changing states to READY, close stateReadyChan. Wait until - // after we notify the LB policy's listener(s) in order to prevent - // ac.getTransport() from unblocking before the LB policy starts - // tracking the subchannel as READY. - close(readyChan) - } }) } @@ -300,6 +296,7 @@ func (acbw *acBalancerWrapper) Connect() { } func (acbw *acBalancerWrapper) Shutdown() { + acbw.closeProducers() acbw.ccb.cc.removeAddrConn(acbw.ac, errConnDrain) } @@ -307,9 +304,10 @@ func (acbw *acBalancerWrapper) Shutdown() { // ready, blocks until it is or ctx expires. Returns an error when the context // expires or the addrConn is shut down. func (acbw *acBalancerWrapper) NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error) { - transport, err := acbw.ac.getTransport(ctx) - if err != nil { - return nil, err + transport := acbw.ac.getReadyTransport() + if transport == nil { + return nil, status.Errorf(codes.Unavailable, "SubConn state is not Ready") + } return newNonRetryClientStream(ctx, desc, method, transport, acbw.ac, opts...) } @@ -334,8 +332,8 @@ type refCountedProducer struct { } func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) (balancer.Producer, func()) { - acbw.mu.Lock() - defer acbw.mu.Unlock() + acbw.producersMu.Lock() + defer acbw.producersMu.Unlock() // Look up existing producer from this builder. pData := acbw.producers[pb] @@ -352,13 +350,26 @@ func (acbw *acBalancerWrapper) GetOrBuildProducer(pb balancer.ProducerBuilder) ( // and delete the refCountedProducer from the map if the total reference // count goes to zero. unref := func() { - acbw.mu.Lock() + acbw.producersMu.Lock() + // If closeProducers has already closed this producer instance, refs is + // set to 0, so the check after decrementing will never pass, and the + // producer will not be double-closed. pData.refs-- if pData.refs == 0 { defer pData.close() // Run outside the acbw mutex delete(acbw.producers, pb) } - acbw.mu.Unlock() + acbw.producersMu.Unlock() } return pData.producer, grpcsync.OnceFunc(unref) } + +func (acbw *acBalancerWrapper) closeProducers() { + acbw.producersMu.Lock() + defer acbw.producersMu.Unlock() + for pb, pData := range acbw.producers { + pData.refs = 0 + pData.close() + delete(acbw.producers, pb) + } +} diff --git a/clientconn.go b/clientconn.go index a680fefc1385..b47efb33c0e9 100644 --- a/clientconn.go +++ b/clientconn.go @@ -825,14 +825,13 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer. } ac := &addrConn{ - state: connectivity.Idle, - cc: cc, - addrs: copyAddresses(addrs), - scopts: opts, - dopts: cc.dopts, - channelz: channelz.RegisterSubChannel(cc.channelz, ""), - resetBackoff: make(chan struct{}), - stateReadyChan: make(chan struct{}), + state: connectivity.Idle, + cc: cc, + addrs: copyAddresses(addrs), + scopts: opts, + dopts: cc.dopts, + channelz: channelz.RegisterSubChannel(cc.channelz, ""), + resetBackoff: make(chan struct{}), } ac.ctx, ac.cancel = context.WithCancel(cc.ctx) // Start with our address set to the first address; this may be updated if @@ -1179,8 +1178,7 @@ type addrConn struct { addrs []resolver.Address // All addresses that the resolver resolved to. // Use updateConnectivityState for updating addrConn's connectivity state. - state connectivity.State - stateReadyChan chan struct{} // closed and recreated on every READY state change. + state connectivity.State backoffIdx int // Needs to be stateful for resetConnectBackoff. resetBackoff chan struct{} @@ -1193,14 +1191,6 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error) if ac.state == s { return } - if ac.state == connectivity.Ready { - // When leaving ready, re-create the ready channel. - ac.stateReadyChan = make(chan struct{}) - } - if s == connectivity.Shutdown { - // Wake any producer waiting to create a stream on the transport. - close(ac.stateReadyChan) - } ac.state = s ac.channelz.ChannelMetrics.State.Store(&s) if lastErr == nil { @@ -1208,7 +1198,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error) } else { channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) } - ac.acbw.updateState(s, ac.curAddr, lastErr, ac.stateReadyChan) + ac.acbw.updateState(s, ac.curAddr, lastErr) } // adjustParams updates parameters used to create transports upon @@ -1512,31 +1502,6 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport { return nil } -// getTransport waits until the addrconn is ready and returns the transport. -// If the context expires first, returns an appropriate status. If the -// addrConn is stopped first, returns an Unavailable status error. -func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) { - for ctx.Err() == nil { - ac.mu.Lock() - t, state, readyChan := ac.transport, ac.state, ac.stateReadyChan - ac.mu.Unlock() - if state == connectivity.Shutdown { - // Return an error immediately in only this case since a connection - // will never occur. - return nil, status.Errorf(codes.Unavailable, "SubConn shutting down") - } - - select { - case <-ctx.Done(): - case <-readyChan: - if state == connectivity.Ready { - return t, nil - } - } - } - return nil, status.FromContextError(ctx.Err()).Err() -} - // tearDown starts to tear down the addrConn. // // Note that tearDown doesn't remove ac from ac.cc.conns, so the addrConn struct diff --git a/interop/orcalb.go b/interop/orcalb.go index 5ff1dc64d973..572a7dfcd5cb 100644 --- a/interop/orcalb.go +++ b/interop/orcalb.go @@ -46,9 +46,8 @@ func (orcabb) Name() string { } type orcab struct { - cc balancer.ClientConn - sc balancer.SubConn - cancelWatch func() + cc balancer.ClientConn + sc balancer.SubConn reportMu sync.Mutex report *v3orcapb.OrcaLoadReport @@ -70,7 +69,6 @@ func (o *orcab) UpdateClientConnState(s balancer.ClientConnState) error { o.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(fmt.Errorf("error creating subconn: %v", err))}) return nil } - o.cancelWatch = orca.RegisterOOBListener(o.sc, o, orca.OOBListenerOptions{ReportInterval: time.Second}) o.sc.Connect() o.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable)}) return nil @@ -89,6 +87,7 @@ func (o *orcab) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnSt func (o *orcab) updateSubConnState(state balancer.SubConnState) { switch state.ConnectivityState { case connectivity.Ready: + orca.RegisterOOBListener(o.sc, o, orca.OOBListenerOptions{ReportInterval: time.Second}) o.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Ready, Picker: &orcaPicker{o: o}}) case connectivity.TransientFailure: o.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(fmt.Errorf("all subchannels in transient failure: %v", state.ConnectionError))}) @@ -102,9 +101,7 @@ func (o *orcab) updateSubConnState(state balancer.SubConnState) { } } -func (o *orcab) Close() { - o.cancelWatch() -} +func (o *orcab) Close() {} func (o *orcab) OnLoadReport(r *v3orcapb.OrcaLoadReport) { o.reportMu.Lock() diff --git a/orca/producer.go b/orca/producer.go index 6e7c4c9f301a..4d370310a0dd 100644 --- a/orca/producer.go +++ b/orca/producer.go @@ -46,6 +46,12 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { backoff: internal.DefaultBackoffFunc, } return p, func() { + p.mu.Lock() + if p.stop != nil { + p.stop() + p.stop = nil + } + p.mu.Unlock() <-p.stopped } } @@ -67,9 +73,9 @@ type OOBListenerOptions struct { ReportInterval time.Duration } -// RegisterOOBListener registers an out-of-band load report listener on sc. -// Any OOBListener may only be registered once per subchannel at a time. The -// returned stop function must be called when no longer needed. Do not +// RegisterOOBListener registers an out-of-band load report listener on a Ready +// sc. Any OOBListener may only be registered once per subchannel at a time. +// The returned stop function must be called when no longer needed. Do not // register a single OOBListener more than once per SubConn. func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOptions) (stop func()) { pr, closeFn := sc.GetOrBuildProducer(producerBuilderSingleton) @@ -77,9 +83,6 @@ func RegisterOOBListener(sc balancer.SubConn, l OOBListener, opts OOBListenerOpt p.registerListener(l, opts.ReportInterval) - // TODO: When we can register for SubConn state updates, automatically call - // stop() on SHUTDOWN. - // If stop is called multiple times, prevent it from having any effect on // subsequent calls. return grpcsync.OnceFunc(func() { @@ -96,13 +99,13 @@ type producer struct { // is incremented when stream errors occur and is reset when the stream // reports a result. backoff func(int) time.Duration + stopped chan struct{} // closed when the run goroutine exits mu sync.Mutex intervals map[time.Duration]int // map from interval time to count of listeners requesting that time listeners map[OOBListener]struct{} // set of registered listeners minInterval time.Duration - stop func() // stops the current run goroutine - stopped chan struct{} // closed when the run goroutine exits + stop func() // stops the current run goroutine } // registerListener adds the listener and its requested report interval to the diff --git a/orca/producer_test.go b/orca/producer_test.go index ece8a8db7145..9df18bf574c9 100644 --- a/orca/producer_test.go +++ b/orca/producer_test.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/roundrobin" "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" @@ -64,13 +65,19 @@ func (w *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubCon if len(addrs) != 1 { panic(fmt.Sprintf("got addrs=%v; want len(addrs) == 1", addrs)) } + var sc balancer.SubConn + opts.StateListener = func(scs balancer.SubConnState) { + if scs.ConnectivityState != connectivity.Ready { + return + } + l := getListenerInfo(addrs[0]) + l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts) + l.scChan <- sc + } sc, err := w.ClientConn.NewSubConn(addrs, opts) if err != nil { return sc, err } - l := getListenerInfo(addrs[0]) - l.listener.cleanup = orca.RegisterOOBListener(sc, l.listener, l.opts) - l.sc = sc return sc, nil } @@ -79,7 +86,7 @@ func (w *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubCon type listenerInfo struct { listener *testOOBListener opts orca.OOBListenerOptions - sc balancer.SubConn // Set by the LB policy + scChan chan balancer.SubConn // Pushed on by the LB policy } type listenerInfoKey struct{} @@ -143,7 +150,7 @@ func (s) TestProducer(t *testing.T) { oobLis := newTestOOBListener() lisOpts := orca.OOBListenerOptions{ReportInterval: 50 * time.Millisecond} - li := &listenerInfo{listener: oobLis, opts: lisOpts} + li := &listenerInfo{scChan: make(chan balancer.SubConn, 1), listener: oobLis, opts: lisOpts} addr := setListenerInfo(resolver.Address{Addr: lis.Addr().String()}, li) r.InitialState(resolver.State{Addresses: []resolver.Address{addr}}) cc, err := grpc.Dial("whatever:///whatever", grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"customLB":{}}]}`), grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -152,10 +159,6 @@ func (s) TestProducer(t *testing.T) { } defer cc.Close() - // Ensure the OOB listener is stopped before the client is closed to avoid - // a potential irrelevant error in the logs. - defer oobLis.Stop() - // Set a few metrics and wait for them on the client side. smr.SetCPUUtilization(10) smr.SetMemoryUtilization(0.1) @@ -202,6 +205,7 @@ testReport: t.Fatalf("timed out waiting for load report: %v", loadReportWant) } } + } // fakeORCAService is a simple implementation of an ORCA service that pushes @@ -313,7 +317,7 @@ func (s) TestProducerBackoff(t *testing.T) { oobLis := newTestOOBListener() lisOpts := orca.OOBListenerOptions{ReportInterval: reportInterval} - li := &listenerInfo{listener: oobLis, opts: lisOpts} + li := &listenerInfo{scChan: make(chan balancer.SubConn, 1), listener: oobLis, opts: lisOpts} r.InitialState(resolver.State{Addresses: []resolver.Address{setListenerInfo(resolver.Address{Addr: lis.Addr().String()}, li)}}) cc, err := grpc.Dial("whatever:///whatever", grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"customLB":{}}]}`), grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -321,10 +325,6 @@ func (s) TestProducerBackoff(t *testing.T) { } defer cc.Close() - // Ensure the OOB listener is stopped before the client is closed to avoid - // a potential irrelevant error in the logs. - defer oobLis.Stop() - // Define a load report to send and expect the client to see. loadReportWant := &v3orcapb.OrcaLoadReport{ CpuUtilization: 10, @@ -429,7 +429,7 @@ func (s) TestProducerMultipleListeners(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") oobLis1 := newTestOOBListener() lisOpts1 := orca.OOBListenerOptions{ReportInterval: reportInterval1} - li := &listenerInfo{listener: oobLis1, opts: lisOpts1} + li := &listenerInfo{scChan: make(chan balancer.SubConn, 1), listener: oobLis1, opts: lisOpts1} r.InitialState(resolver.State{Addresses: []resolver.Address{setListenerInfo(resolver.Address{Addr: lis.Addr().String()}, li)}}) cc, err := grpc.Dial("whatever:///whatever", grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"customLB":{}}]}`), grpc.WithResolvers(r), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -518,15 +518,16 @@ func (s) TestProducerMultipleListeners(t *testing.T) { fake.respCh <- loadReportWant checkReports(1, 0, 0) + sc := <-li.scChan // Register listener 2 with a less frequent interval; no need to recreate // stream. Report should go to both listeners. - oobLis2.cleanup = orca.RegisterOOBListener(li.sc, oobLis2, lisOpts2) + oobLis2.cleanup = orca.RegisterOOBListener(sc, oobLis2, lisOpts2) fake.respCh <- loadReportWant checkReports(2, 1, 0) // Register listener 3 with a more frequent interval; stream is recreated // with this interval. The next report will go to all three listeners. - oobLis3.cleanup = orca.RegisterOOBListener(li.sc, oobLis3, lisOpts3) + oobLis3.cleanup = orca.RegisterOOBListener(sc, oobLis3, lisOpts3) awaitRequest(reportInterval3) fake.respCh <- loadReportWant checkReports(3, 2, 1) diff --git a/producer_ext_test.go b/producer_ext_test.go index 628a11851f87..a7ee89869375 100644 --- a/producer_ext_test.go +++ b/producer_ext_test.go @@ -21,8 +21,8 @@ package grpc_test import ( "context" "strings" + "sync/atomic" "testing" - "time" "google.golang.org/grpc" "google.golang.org/grpc/balancer" @@ -30,72 +30,41 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" testgrpc "google.golang.org/grpc/interop/grpc_testing" ) -type producerBuilder struct{} - -type producer struct { - client testgrpc.TestServiceClient - stopped chan struct{} -} - -// Build constructs and returns a producer and its cleanup function -func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { - p := &producer{ - client: testgrpc.NewTestServiceClient(cci.(grpc.ClientConnInterface)), - stopped: make(chan struct{}), - } - return p, func() { - <-p.stopped - } -} - -func (p *producer) testStreamStart(t *testing.T, streamStarted chan<- struct{}) { - go func() { - defer close(p.stopped) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - if _, err := p.client.FullDuplexCall(ctx); err != nil { - t.Errorf("Unexpected error starting stream: %v", err) - } - close(streamStarted) - }() -} - -var producerBuilderSingleton = &producerBuilder{} - -// TestProducerStreamStartsAfterReady ensures producer streams only start after -// the subchannel reports as READY to the LB policy. -func (s) TestProducerStreamStartsAfterReady(t *testing.T) { +// TestProducerStopsBeforeStateChange confirms that producers are stopped before +// any state change notification is delivered to the LB policy. +func (s) TestProducerStopsBeforeStateChange(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "") - producerCh := make(chan balancer.Producer) - var producerClose func() - streamStarted := make(chan struct{}) - done := make(chan struct{}) + var lastProducer *testProducer bf := stub.BalancerFuncs{ UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error { + var sc balancer.SubConn sc, err := bd.ClientConn.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{ StateListener: func(scs balancer.SubConnState) { - if scs.ConnectivityState == connectivity.Ready { - timer := time.NewTimer(5 * time.Millisecond) - select { - case <-streamStarted: - t.Errorf("Producer stream started before Ready listener returned") - case <-timer.C: - } - close(done) + bd.ClientConn.UpdateState(balancer.State{ + ConnectivityState: scs.ConnectivityState, + // We do not pass a picker, but since we don't perform + // RPCs, that's okay. + }) + if !lastProducer.stopped.Load() { + t.Errorf("lastProducer not stopped before state change notification") } + t.Logf("State is now %v; recreating producer", scs.ConnectivityState) + p, _ := sc.GetOrBuildProducer(producerBuilderSingleton) + lastProducer = p.(*testProducer) }, }) if err != nil { return err } - var producer balancer.Producer - producer, producerClose = sc.GetOrBuildProducer(producerBuilderSingleton) - producerCh <- producer + p, _ := sc.GetOrBuildProducer(producerBuilderSingleton) + lastProducer = p.(*testProducer) sc.Connect() return nil }, @@ -122,16 +91,26 @@ func (s) TestProducerStreamStartsAfterReady(t *testing.T) { defer cc.Close() go cc.Connect() - p := <-producerCh - p.(*producer).testStreamStart(t, streamStarted) + testutils.AwaitState(ctx, t, cc, connectivity.Ready) + + cc.Close() + testutils.AwaitState(ctx, t, cc, connectivity.Shutdown) +} - select { - case <-done: - // Wait for the stream to start before exiting; otherwise the ClientConn - // will close and cause stream creation to fail. - <-streamStarted - producerClose() - case <-ctx.Done(): - t.Error("Timed out waiting for test to complete") +type producerBuilder struct{} + +type testProducer struct { + // There should be no race accessing this field, but use an atomic since + // the race checker probably can't detect that. + stopped atomic.Bool +} + +// Build constructs and returns a producer and its cleanup function +func (*producerBuilder) Build(cci any) (balancer.Producer, func()) { + p := &testProducer{} + return p, func() { + p.stopped.Store(true) } } + +var producerBuilderSingleton = &producerBuilder{} diff --git a/test/balancer_test.go b/test/balancer_test.go index 36d347ca6935..f27ec4d3fe90 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -904,156 +904,6 @@ func (s) TestMetadataInPickResult(t *testing.T) { } } -// producerTestBalancerBuilder and producerTestBalancer start a producer which -// makes an RPC before the subconn is READY, then connects the subconn, and -// pushes the resulting error (expected to be nil) to rpcErrChan. -type producerTestBalancerBuilder struct { - rpcErrChan chan error - ctxChan chan context.Context - connect bool -} - -func (bb *producerTestBalancerBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { - return &producerTestBalancer{cc: cc, rpcErrChan: bb.rpcErrChan, ctxChan: bb.ctxChan, connect: bb.connect} -} - -const producerTestBalancerName = "producer_test_balancer" - -func (bb *producerTestBalancerBuilder) Name() string { return producerTestBalancerName } - -type producerTestBalancer struct { - cc balancer.ClientConn - rpcErrChan chan error - ctxChan chan context.Context - connect bool -} - -func (b *producerTestBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { - // Create the subconn, but don't connect it. - sc, err := b.cc.NewSubConn(ccs.ResolverState.Addresses, balancer.NewSubConnOptions{}) - if err != nil { - return fmt.Errorf("error creating subconn: %v", err) - } - - // Create the producer. This will call the producer builder's Build - // method, which will try to start an RPC in a goroutine. - p := &testProducerBuilder{start: grpcsync.NewEvent(), rpcErrChan: b.rpcErrChan, ctxChan: b.ctxChan} - sc.GetOrBuildProducer(p) - - // Wait here until the producer is about to perform the RPC, which should - // block until connected. - <-p.start.Done() - - // Ensure the error chan doesn't get anything on it before we connect the - // subconn. - select { - case err := <-b.rpcErrChan: - go func() { b.rpcErrChan <- fmt.Errorf("Got unexpected data on rpcErrChan: %v", err) }() - default: - } - - if b.connect { - // Now we can connect, which will unblock the RPC above. - sc.Connect() - } - - // The stub server requires a READY picker to be reported, to unblock its - // Start method. We won't make RPCs in our test, so a nil picker is okay. - b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Ready, Picker: nil}) - return nil -} - -func (b *producerTestBalancer) ResolverError(err error) { - panic(fmt.Sprintf("Unexpected resolver error: %v", err)) -} - -func (b *producerTestBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {} -func (b *producerTestBalancer) Close() {} - -type testProducerBuilder struct { - start *grpcsync.Event - rpcErrChan chan error - ctxChan chan context.Context -} - -func (b *testProducerBuilder) Build(cci any) (balancer.Producer, func()) { - c := testgrpc.NewTestServiceClient(cci.(grpc.ClientConnInterface)) - // Perform the RPC in a goroutine instead of during build because the - // subchannel's mutex is held here. - go func() { - ctx := <-b.ctxChan - b.start.Fire() - _, err := c.EmptyCall(ctx, &testpb.Empty{}) - b.rpcErrChan <- err - }() - return nil, func() {} -} - -// TestBalancerProducerBlockUntilReady tests that we get no RPC errors from -// producers when subchannels aren't ready. -func (s) TestBalancerProducerBlockUntilReady(t *testing.T) { - // rpcErrChan is given to the LB policy to report the status of the - // producer's one RPC. - ctxChan := make(chan context.Context, 1) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - ctxChan <- ctx - - rpcErrChan := make(chan error) - balancer.Register(&producerTestBalancerBuilder{rpcErrChan: rpcErrChan, ctxChan: ctxChan, connect: true}) - - ss := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - - // Start the server & client with the test producer LB policy. - svcCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, producerTestBalancerName) - if err := ss.Start(nil, grpc.WithDefaultServiceConfig(svcCfg)); err != nil { - t.Fatalf("Error starting testing server: %v", err) - } - defer ss.Stop() - - // Receive the error from the producer's RPC, which should be nil. - if err := <-rpcErrChan; err != nil { - t.Fatalf("Received unexpected error from producer RPC: %v", err) - } -} - -// TestBalancerProducerHonorsContext tests that producers that perform RPC get -// context errors correctly. -func (s) TestBalancerProducerHonorsContext(t *testing.T) { - // rpcErrChan is given to the LB policy to report the status of the - // producer's one RPC. - ctxChan := make(chan context.Context, 1) - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - ctxChan <- ctx - - rpcErrChan := make(chan error) - balancer.Register(&producerTestBalancerBuilder{rpcErrChan: rpcErrChan, ctxChan: ctxChan, connect: false}) - - ss := &stubserver.StubServer{ - EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - - // Start the server & client with the test producer LB policy. - svcCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, producerTestBalancerName) - if err := ss.Start(nil, grpc.WithDefaultServiceConfig(svcCfg)); err != nil { - t.Fatalf("Error starting testing server: %v", err) - } - defer ss.Stop() - - cancel() - - // Receive the error from the producer's RPC, which should be canceled. - if err := <-rpcErrChan; status.Code(err) != codes.Canceled { - t.Fatalf("RPC error: %v; want status.Code(err)=%v", err, codes.Canceled) - } -} - // TestSubConnShutdown confirms that the Shutdown method on subconns and // RemoveSubConn method on ClientConn properly initiates subconn shutdown. func (s) TestSubConnShutdown(t *testing.T) {