From 87f0254f112632b1745897553e48aff5884c9217 Mon Sep 17 00:00:00 2001 From: Purnesh Dixit Date: Fri, 22 Nov 2024 00:45:02 +0530 Subject: [PATCH 1/9] xdsclient: fix new watcher hang when registering for removed resource (#7853) --- xds/internal/xdsclient/authority.go | 5 ++ .../xdsclient/tests/lds_watchers_test.go | 81 +++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index 27abb64ef6d5..04bd278d2c47 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -641,6 +641,11 @@ func (a *authority) watchResource(rType xdsresource.Type, resourceName string, w resource := state.cache a.watcherCallbackSerializer.TrySchedule(func(context.Context) { watcher.OnUpdate(resource, func() {}) }) } + // If the metadata field is updated to indicate that the management + // server does not have this resource, notify the new watcher. + if state.md.Status == xdsresource.ServiceStatusNotExist { + a.watcherCallbackSerializer.TrySchedule(func(context.Context) { watcher.OnResourceDoesNotExist(func() {}) }) + } cleanup = a.unwatchResource(rType, resourceName, watcher) }, func() { if a.logger.V(2) { diff --git a/xds/internal/xdsclient/tests/lds_watchers_test.go b/xds/internal/xdsclient/tests/lds_watchers_test.go index 2ea2c50ce18b..38e1f1760383 100644 --- a/xds/internal/xdsclient/tests/lds_watchers_test.go +++ b/xds/internal/xdsclient/tests/lds_watchers_test.go @@ -871,6 +871,87 @@ func (s) TestLDSWatch_ResourceRemoved(t *testing.T) { } } +// TestLDSWatch_NewWatcherForRemovedResource covers the case where a new +// watcher registers for a resource that has been removed. The test verifies +// the following scenarios: +// 1. When a resource is deleted by the management server, any active +// watchers of that resource should be notified with a "resource removed" +// error through their watch callback. +// 2. If a new watcher attempts to register for a resource that has already +// been deleted, its watch callback should be immediately invoked with a +// "resource removed" error. +func (s) TestLDSWatch_NewWatcherForRemovedResource(t *testing.T) { + mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{}) + + nodeID := uuid.New().String() + bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address) + + // Create an xDS client with the above bootstrap contents. + client, close, err := xdsclient.NewForTesting(xdsclient.OptionsForTesting{ + Name: t.Name(), + Contents: bc, + }) + if err != nil { + t.Fatalf("Failed to create xDS client: %v", err) + } + defer close() + + // Register watch for the listener resource and have the + // callbacks push the received updates on to a channel. + lw1 := newListenerWatcher() + ldsCancel1 := xdsresource.WatchListener(client, ldsName, lw1) + defer ldsCancel1() + + // Configure the management server to return listener resource, + // corresponding to the registered watch. + resource := e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{e2e.DefaultClientListener(ldsName, rdsName)}, + SkipValidation: true, + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := mgmtServer.Update(ctx, resource); err != nil { + t.Fatalf("Failed to update management server with resource: %v, err: %v", resource, err) + } + + // Verify the contents of the received update for existing watch. 
+ wantUpdate := listenerUpdateErrTuple{ + update: xdsresource.ListenerUpdate{ + RouteConfigName: rdsName, + HTTPFilters: []xdsresource.HTTPFilter{{Name: "router"}}, + }, + } + if err := verifyListenerUpdate(ctx, lw1.updateCh, wantUpdate); err != nil { + t.Fatal(err) + } + + // Remove the listener resource on the management server. + resource = e2e.UpdateOptions{ + NodeID: nodeID, + Listeners: []*v3listenerpb.Listener{}, + SkipValidation: true, + } + if err := mgmtServer.Update(ctx, resource); err != nil { + t.Fatalf("Failed to update management server with resource: %v, err: %v", resource, err) + } + + // The existing watcher should receive a resource removed error. + updateError := listenerUpdateErrTuple{err: xdsresource.NewErrorf(xdsresource.ErrorTypeResourceNotFound, "")} + if err := verifyListenerUpdate(ctx, lw1.updateCh, updateError); err != nil { + t.Fatal(err) + } + + // New watchers attempting to register for a deleted resource should also + // receive a "resource removed" error. + lw2 := newListenerWatcher() + ldsCancel2 := xdsresource.WatchListener(client, ldsName, lw2) + defer ldsCancel2() + if err := verifyListenerUpdate(ctx, lw2.updateCh, updateError); err != nil { + t.Fatal(err) + } +} + // TestLDSWatch_NACKError covers the case where an update from the management // server is NACK'ed by the xdsclient. The test verifies that the error is // propagated to the watcher. From 44a5eb9231ec6e753dab5ded7cda6f81788fc3f7 Mon Sep 17 00:00:00 2001 From: Purnesh Dixit Date: Fri, 22 Nov 2024 01:02:44 +0530 Subject: [PATCH 2/9] xdsclient: fix new watcher to get both old good update and nack error (if exist) from the cache (#7851) --- xds/internal/xdsclient/authority.go | 11 +- .../xdsclient/tests/lds_watchers_test.go | 198 ++++++++++++++++-- 2 files changed, 187 insertions(+), 22 deletions(-) diff --git a/xds/internal/xdsclient/authority.go b/xds/internal/xdsclient/authority.go index 04bd278d2c47..24673a8d9077 100644 --- a/xds/internal/xdsclient/authority.go +++ b/xds/internal/xdsclient/authority.go @@ -633,7 +633,8 @@ func (a *authority) watchResource(rType xdsresource.Type, resourceName string, w // Always add the new watcher to the set of watchers. state.watchers[watcher] = true - // If we have a cached copy of the resource, notify the new watcher. + // If we have a cached copy of the resource, notify the new watcher + // immediately. if state.cache != nil { if a.logger.V(2) { a.logger.Infof("Resource type %q with resource name %q found in cache: %s", rType.TypeName(), resourceName, state.cache.ToJSON()) @@ -641,6 +642,14 @@ func (a *authority) watchResource(rType xdsresource.Type, resourceName string, w resource := state.cache a.watcherCallbackSerializer.TrySchedule(func(context.Context) { watcher.OnUpdate(resource, func() {}) }) } + // If last update was NACK'd, notify the new watcher of error + // immediately as well. + if state.md.Status == xdsresource.ServiceStatusNACKed { + if a.logger.V(2) { + a.logger.Infof("Resource type %q with resource name %q was NACKed: %s", rType.TypeName(), resourceName, state.cache.ToJSON()) + } + a.watcherCallbackSerializer.TrySchedule(func(context.Context) { watcher.OnError(state.md.ErrState.Err, func() {}) }) + } // If the metadata field is updated to indicate that the management // server does not have this resource, notify the new watcher. 
if state.md.Status == xdsresource.ServiceStatusNotExist { diff --git a/xds/internal/xdsclient/tests/lds_watchers_test.go b/xds/internal/xdsclient/tests/lds_watchers_test.go index 38e1f1760383..7b49b9b17b74 100644 --- a/xds/internal/xdsclient/tests/lds_watchers_test.go +++ b/xds/internal/xdsclient/tests/lds_watchers_test.go @@ -71,22 +71,47 @@ func newListenerWatcher() *listenerWatcher { return &listenerWatcher{updateCh: testutils.NewChannel()} } -func (cw *listenerWatcher) OnUpdate(update *xdsresource.ListenerResourceData, onDone xdsresource.OnDoneFunc) { - cw.updateCh.Send(listenerUpdateErrTuple{update: update.Resource}) +func (lw *listenerWatcher) OnUpdate(update *xdsresource.ListenerResourceData, onDone xdsresource.OnDoneFunc) { + lw.updateCh.Send(listenerUpdateErrTuple{update: update.Resource}) onDone() } -func (cw *listenerWatcher) OnError(err error, onDone xdsresource.OnDoneFunc) { +func (lw *listenerWatcher) OnError(err error, onDone xdsresource.OnDoneFunc) { // When used with a go-control-plane management server that continuously // resends resources which are NACKed by the xDS client, using a `Replace()` // here and in OnResourceDoesNotExist() simplifies tests which will have // access to the most recently received error. - cw.updateCh.Replace(listenerUpdateErrTuple{err: err}) + lw.updateCh.Replace(listenerUpdateErrTuple{err: err}) onDone() } -func (cw *listenerWatcher) OnResourceDoesNotExist(onDone xdsresource.OnDoneFunc) { - cw.updateCh.Replace(listenerUpdateErrTuple{err: xdsresource.NewErrorf(xdsresource.ErrorTypeResourceNotFound, "Listener not found in received response")}) +func (lw *listenerWatcher) OnResourceDoesNotExist(onDone xdsresource.OnDoneFunc) { + lw.updateCh.Replace(listenerUpdateErrTuple{err: xdsresource.NewErrorf(xdsresource.ErrorTypeResourceNotFound, "Listener not found in received response")}) + onDone() +} + +type listenerWatcherMultiple struct { + updateCh *testutils.Channel +} + +// TODO: delete this once `newListenerWatcher` is modified to handle multiple +// updates (https://github.com/grpc/grpc-go/issues/7864). +func newListenerWatcherMultiple(size int) *listenerWatcherMultiple { + return &listenerWatcherMultiple{updateCh: testutils.NewChannelWithSize(size)} +} + +func (lw *listenerWatcherMultiple) OnUpdate(update *xdsresource.ListenerResourceData, onDone xdsresource.OnDoneFunc) { + lw.updateCh.Send(listenerUpdateErrTuple{update: update.Resource}) + onDone() +} + +func (lw *listenerWatcherMultiple) OnError(err error, onDone xdsresource.OnDoneFunc) { + lw.updateCh.Send(listenerUpdateErrTuple{err: err}) + onDone() +} + +func (lw *listenerWatcherMultiple) OnResourceDoesNotExist(onDone xdsresource.OnDoneFunc) { + lw.updateCh.Send(listenerUpdateErrTuple{err: xdsresource.NewErrorf(xdsresource.ErrorTypeResourceNotFound, "Listener not found in received response")}) onDone() } @@ -155,6 +180,18 @@ func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, want return nil } +func verifyUnknownListenerError(ctx context.Context, updateCh *testutils.Channel, wantErr string) error { + u, err := updateCh.Receive(ctx) + if err != nil { + return fmt.Errorf("timeout when waiting for a listener error from the management server: %v", err) + } + gotErr := u.(listenerUpdateErrTuple).err + if gotErr == nil || !strings.Contains(gotErr.Error(), wantErr) { + return fmt.Errorf("update received with error: %v, want %q", gotErr, wantErr) + } + return nil +} + // TestLDSWatch covers the case where a single watcher exists for a single // listener resource. 
 // The test verifies the following scenarios:
 // 1. An update from the management server containing the resource being
@@ -953,8 +990,9 @@ func (s) TestLDSWatch_NewWatcherForRemovedResource(t *testing.T) {
 }
 
 // TestLDSWatch_NACKError covers the case where an update from the management
-// server is NACK'ed by the xdsclient. The test verifies that the error is
-// propagated to the watcher.
+// server is NACKed by the xdsclient. The test verifies that the error is
+// propagated to the existing watcher. After a NACK, if a new watcher registers
+// for the resource, the error is propagated to the new watcher as well.
 func (s) TestLDSWatch_NACKError(t *testing.T) {
 	mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{})
 
@@ -992,19 +1030,141 @@ func (s) TestLDSWatch_NACKError(t *testing.T) {
 	}
 
-	// Verify that the expected error is propagated to the watcher.
-	u, err := lw.updateCh.Receive(ctx)
+	// Verify that the expected error is propagated to the existing watcher.
+	if err := verifyUnknownListenerError(ctx, lw.updateCh, wantListenerNACKErr); err != nil {
+		t.Fatal(err)
+	}
+
+	// Verify that the expected error is propagated to the new watcher as well.
+	lw2 := newListenerWatcher()
+	ldsCancel2 := xdsresource.WatchListener(client, ldsName, lw2)
+	defer ldsCancel2()
+	if err := verifyUnknownListenerError(ctx, lw2.updateCh, wantListenerNACKErr); err != nil {
+		t.Fatal(err)
+	}
+}
+
+// TestLDSWatch_ResourceCaching_NACKError covers the case where a watch is
+// registered for a resource which is already present in the cache with an old
+// good update as well as the latest NACK error. The test verifies that a new
+// watcher receives both the good update and the error without a new resource
+// request being sent to the management server.
+func (s) TestLDSWatch_ResourceCaching_NACKError(t *testing.T) {
+	firstRequestReceived := false
+	firstAckReceived := grpcsync.NewEvent()
+	secondAckReceived := grpcsync.NewEvent()
+	secondRequestReceived := grpcsync.NewEvent()
+
+	mgmtServer := e2e.StartManagementServer(t, e2e.ManagementServerOptions{
+		OnStreamRequest: func(id int64, req *v3discoverypb.DiscoveryRequest) error {
+			// The first request has an empty version string.
+			if !firstRequestReceived && req.GetVersionInfo() == "" {
+				firstRequestReceived = true
+				return nil
+			}
+			// The first ack has a non-empty version string.
+			if !firstAckReceived.HasFired() && req.GetVersionInfo() != "" {
+				firstAckReceived.Fire()
+				return nil
+			}
+			// The second ack has a non-empty version string.
+			if !secondAckReceived.HasFired() && req.GetVersionInfo() != "" {
+				secondAckReceived.Fire()
+				return nil
+			}
+			// Any requests after the first request and two acks are not expected.
+			secondRequestReceived.Fire()
+			return nil
+		},
+	})
+
+	nodeID := uuid.New().String()
+	bc := e2e.DefaultBootstrapContents(t, nodeID, mgmtServer.Address)
+
+	// Create an xDS client with the above bootstrap contents.
+	client, close, err := xdsclient.NewForTesting(xdsclient.OptionsForTesting{
+		Name:     t.Name(),
+		Contents: bc,
+	})
 	if err != nil {
-		t.Fatalf("timeout when waiting for a listener resource from the management server: %v", err)
+		t.Fatalf("Failed to create xDS client: %v", err)
 	}
-	gotErr := u.(listenerUpdateErrTuple).err
-	if gotErr == nil || !strings.Contains(gotErr.Error(), wantListenerNACKErr) {
-		t.Fatalf("update received with error: %v, want %q", gotErr, wantListenerNACKErr)
+	defer close()
+
+	// Register a watch for a listener resource and have the watch
+	// callback push the received update onto a channel.
+	lw1 := newListenerWatcher()
+	ldsCancel1 := xdsresource.WatchListener(client, ldsName, lw1)
+	defer ldsCancel1()
+
+	// Configure the management server to return a single listener
+	// resource, corresponding to the one we registered a watch for.
+	resources := e2e.UpdateOptions{
+		NodeID:         nodeID,
+		Listeners:      []*v3listenerpb.Listener{e2e.DefaultClientListener(ldsName, rdsName)},
+		SkipValidation: true,
+	}
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	if err := mgmtServer.Update(ctx, resources); err != nil {
+		t.Fatalf("Failed to update management server with resources: %v, err: %v", resources, err)
+	}
+
+	// Verify the contents of the received update.
+	wantUpdate := listenerUpdateErrTuple{
+		update: xdsresource.ListenerUpdate{
+			RouteConfigName: rdsName,
+			HTTPFilters:     []xdsresource.HTTPFilter{{Name: "router"}},
+		},
+	}
+	if err := verifyListenerUpdate(ctx, lw1.updateCh, wantUpdate); err != nil {
+		t.Fatal(err)
+	}
+
+	// Configure the management server to return a single listener resource
+	// which is expected to be NACKed by the client.
+	resources = e2e.UpdateOptions{
+		NodeID:         nodeID,
+		Listeners:      []*v3listenerpb.Listener{badListenerResource(t, ldsName)},
+		SkipValidation: true,
+	}
+	if err := mgmtServer.Update(ctx, resources); err != nil {
+		t.Fatalf("Failed to update management server with resources: %v, err: %v", resources, err)
+	}
+
+	// Verify that the expected error is propagated to the existing watcher.
+	if err := verifyUnknownListenerError(ctx, lw1.updateCh, wantListenerNACKErr); err != nil {
+		t.Fatal(err)
+	}
+
+	// Register another watch for the same resource. This should get the update
+	// and error from the cache.
+	lw2 := newListenerWatcherMultiple(2)
+	ldsCancel2 := xdsresource.WatchListener(client, ldsName, lw2)
+	defer ldsCancel2()
+	if err := verifyListenerUpdate(ctx, lw2.updateCh, wantUpdate); err != nil {
+		t.Fatal(err)
+	}
+	// Verify that the expected error is propagated to the new watcher as well.
+	if err := verifyUnknownListenerError(ctx, lw2.updateCh, wantListenerNACKErr); err != nil {
+		t.Fatal(err)
+	}
+
+	// No request should get sent out as part of this watch.
+	sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
+	defer sCancel()
+	select {
+	case <-sCtx.Done():
+	case <-secondRequestReceived.Done():
+		t.Fatal("xdsClient sent out request instead of using update from cache")
+	}
+}
+
 // TestLDSWatch_PartialValid covers the case where a response from the
 // management server contains both valid and invalid resources and is expected
-// to be NACK'ed by the xdsclient. The test verifies that watchers corresponding
+// to be NACKed by the xdsclient. The test verifies that watchers corresponding
 // to the valid resource receive the update, while watchers corresponding to the
 // invalid resource receive an error.
 func (s) TestLDSWatch_PartialValid(t *testing.T) {
@@ -1071,13 +1231,9 @@ func (s) TestLDSWatch_PartialValid(t *testing.T) {
 
 	// Verify that the expected error is propagated to the watcher which
 	// requested for the bad resource.
-	u, err := lw1.updateCh.Receive(ctx)
-	if err != nil {
-		t.Fatalf("timeout when waiting for a listener resource from the management server: %v", err)
-	}
-	gotErr := u.(listenerUpdateErrTuple).err
-	if gotErr == nil || !strings.Contains(gotErr.Error(), wantListenerNACKErr) {
-		t.Fatalf("update received with error: %v, want %q", gotErr, wantListenerNACKErr)
+	if err := verifyUnknownListenerError(ctx, lw1.updateCh, wantListenerNACKErr); err != nil {
+		t.Fatal(err)
 	}
 
 	// Verify that the watcher watching the good resource receives a good

From 93f1cc163b21b863fff761f0122db733d03aa657 Mon Sep 17 00:00:00 2001
From: Brad Town
Date: Fri, 22 Nov 2024 10:46:40 -0800
Subject: [PATCH 3/9] credentials/alts: avoid SRV and TXT lookups for handshaker service (#7861)

---
 credentials/alts/internal/handshaker/service/service.go | 4 +++-
 internal/resolver/dns/dns_resolver.go                   | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/credentials/alts/internal/handshaker/service/service.go b/credentials/alts/internal/handshaker/service/service.go
index b3af03590729..fbfde5d047fe 100644
--- a/credentials/alts/internal/handshaker/service/service.go
+++ b/credentials/alts/internal/handshaker/service/service.go
@@ -47,8 +47,10 @@ func Dial(hsAddress string) (*grpc.ClientConn, error) {
 	if !ok {
 		// Create a new connection to the handshaker service. Note that
 		// this connection stays open until the application is closed.
+		// Disable the service config to avoid unnecessary TXT record lookups that
+		// cause timeouts with some versions of systemd-resolved.
 		var err error
-		hsConn, err = grpc.Dial(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()))
+		hsConn, err = grpc.Dial(hsAddress, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDisableServiceConfig())
 		if err != nil {
 			return nil, err
 		}
diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go
index cc5d5e05c010..b080ae30bc1b 100644
--- a/internal/resolver/dns/dns_resolver.go
+++ b/internal/resolver/dns/dns_resolver.go
@@ -237,7 +237,9 @@ func (d *dnsResolver) watcher() {
 }
 
 func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
-	if !EnableSRVLookups {
+	// Skip this particular host to avoid timeouts with some versions of
+	// systemd-resolved.
+	if !EnableSRVLookups || d.host == "metadata.google.internal."
{ return nil, nil } var newAddrs []resolver.Address From 13d5a168d98ac33cb1c28083d29b76d92d24085d Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Fri, 22 Nov 2024 19:20:03 -0500 Subject: [PATCH 4/9] balancer/weightedroundrobin: Switch Weighted Round Robin to use pick first instead of SubConns (#7826) --- balancer/endpointsharding/endpointsharding.go | 27 +- balancer/weightedroundrobin/balancer.go | 569 +++++++++--------- balancer/weightedroundrobin/balancer_test.go | 59 ++ balancer/weightedroundrobin/metrics_test.go | 10 +- balancer/weightedroundrobin/scheduler.go | 14 +- 5 files changed, 374 insertions(+), 305 deletions(-) diff --git a/balancer/endpointsharding/endpointsharding.go b/balancer/endpointsharding/endpointsharding.go index b5b92143194b..263c024a84c7 100644 --- a/balancer/endpointsharding/endpointsharding.go +++ b/balancer/endpointsharding/endpointsharding.go @@ -28,19 +28,33 @@ package endpointsharding import ( "encoding/json" "errors" + "fmt" + rand "math/rand/v2" "sync" "sync/atomic" - rand "math/rand/v2" - "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/pickfirst" + "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/balancer/gracefulswitch" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" ) +// PickFirstConfig is a pick first config without shuffling enabled. +var PickFirstConfig string + +func init() { + name := pickfirst.Name + if !envconfig.NewPickFirstEnabled { + name = pickfirstleaf.Name + } + PickFirstConfig = fmt.Sprintf("[{%q: {}}]", name) +} + // ChildState is the balancer state of a child along with the endpoint which // identifies the child balancer. type ChildState struct { @@ -100,9 +114,6 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState // Update/Create new children. for _, endpoint := range state.ResolverState.Endpoints { - if len(endpoint.Addresses) == 0 { - continue - } if _, ok := newChildren.Get(endpoint); ok { // Endpoint child was already created, continue to avoid duplicate // update. @@ -143,6 +154,9 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState } } es.children.Store(newChildren) + if newChildren.Len() == 0 { + return balancer.ErrBadResolverState + } return ret } @@ -306,6 +320,3 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) { func ParseConfig(cfg json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return gracefulswitch.ParseConfig(cfg) } - -// PickFirstConfig is a pick first config without shuffling enabled. 
-const PickFirstConfig = "[{\"pick_first\": {}}]" diff --git a/balancer/weightedroundrobin/balancer.go b/balancer/weightedroundrobin/balancer.go index a0511772d2fa..c9c5b576bb0c 100644 --- a/balancer/weightedroundrobin/balancer.go +++ b/balancer/weightedroundrobin/balancer.go @@ -19,9 +19,7 @@ package weightedroundrobin import ( - "context" "encoding/json" - "errors" "fmt" rand "math/rand/v2" "sync" @@ -30,12 +28,13 @@ import ( "unsafe" "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/balancer/endpointsharding" "google.golang.org/grpc/balancer/weightedroundrobin/internal" "google.golang.org/grpc/balancer/weightedtarget" "google.golang.org/grpc/connectivity" estats "google.golang.org/grpc/experimental/stats" "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" iserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/orca" "google.golang.org/grpc/resolver" @@ -84,23 +83,31 @@ var ( }) ) +// endpointSharding which specifies pick first children. +var endpointShardingLBConfig serviceconfig.LoadBalancingConfig + func init() { balancer.Register(bb{}) + var err error + endpointShardingLBConfig, err = endpointsharding.ParseConfig(json.RawMessage(endpointsharding.PickFirstConfig)) + if err != nil { + logger.Fatal(err) + } } type bb struct{} func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &wrrBalancer{ - cc: cc, - subConns: resolver.NewAddressMap(), - csEvltr: &balancer.ConnectivityStateEvaluator{}, - scMap: make(map[balancer.SubConn]*weightedSubConn), - connectivityState: connectivity.Connecting, - target: bOpts.Target.String(), - metricsRecorder: bOpts.MetricsRecorder, + ClientConn: cc, + target: bOpts.Target.String(), + metricsRecorder: bOpts.MetricsRecorder, + addressWeights: resolver.NewAddressMap(), + endpointToWeight: resolver.NewEndpointMap(), + scToWeight: make(map[balancer.SubConn]*endpointWeight), } + b.child = endpointsharding.NewBalancer(b, bOpts) b.logger = prefixLogger(b) b.logger.Infof("Created") return b @@ -141,123 +148,189 @@ func (bb) Name() string { return Name } +// updateEndpointsLocked updates endpoint weight state based off new update, by +// starting and clearing any endpoint weights needed. +// +// Caller must hold b.mu. +func (b *wrrBalancer) updateEndpointsLocked(endpoints []resolver.Endpoint) { + endpointSet := resolver.NewEndpointMap() + addressSet := resolver.NewAddressMap() + for _, endpoint := range endpoints { + endpointSet.Set(endpoint, nil) + for _, addr := range endpoint.Addresses { + addressSet.Set(addr, nil) + } + var ew *endpointWeight + if ewi, ok := b.endpointToWeight.Get(endpoint); ok { + ew = ewi.(*endpointWeight) + } else { + ew = &endpointWeight{ + logger: b.logger, + connectivityState: connectivity.Connecting, + // Initially, we set load reports to off, because they are not + // running upon initial endpointWeight creation. + cfg: &lbConfig{EnableOOBLoadReport: false}, + metricsRecorder: b.metricsRecorder, + target: b.target, + locality: b.locality, + } + for _, addr := range endpoint.Addresses { + b.addressWeights.Set(addr, ew) + } + b.endpointToWeight.Set(endpoint, ew) + } + ew.updateConfig(b.cfg) + } + + for _, endpoint := range b.endpointToWeight.Keys() { + if _, ok := endpointSet.Get(endpoint); ok { + // Existing endpoint also in new endpoint list; skip. 
+			continue
+		}
+		b.endpointToWeight.Delete(endpoint)
+		for _, addr := range endpoint.Addresses {
+			// Old endpoints to be deleted can share addresses with new
+			// endpoints, so only delete an address if no new endpoint uses it.
+			if _, ok := addressSet.Get(addr); !ok {
+				b.addressWeights.Delete(addr)
+			}
+		}
+		// The SubConn map will get handled in updateSubConnState
+		// when it receives a SHUTDOWN signal.
+	}
+}
+
 // wrrBalancer implements the weighted round robin LB policy.
 type wrrBalancer struct {
-	// The following fields are immutable.
-	cc              balancer.ClientConn
-	logger          *grpclog.PrefixLogger
-	target          string
-	metricsRecorder estats.MetricsRecorder
-
-	// The following fields are only accessed on calls into the LB policy, and
-	// do not need a mutex.
-	cfg               *lbConfig            // active config
-	subConns          *resolver.AddressMap // active weightedSubConns mapped by address
-	scMap             map[balancer.SubConn]*weightedSubConn
-	connectivityState connectivity.State // aggregate state
-	csEvltr           *balancer.ConnectivityStateEvaluator
-	resolverErr       error // the last error reported by the resolver; cleared on successful resolution
-	connErr           error // the last connection error; cleared upon leaving TransientFailure
-	stopPicker        func()
-	locality          string
+	// The following fields are set at initialization time and read only after that,
+	// so they do not need to be protected by a mutex.
+	child balancer.Balancer
+	balancer.ClientConn // Embed to intercept NewSubConn operation
+	logger              *grpclog.PrefixLogger
+	target              string
+	metricsRecorder     estats.MetricsRecorder
+
+	mu               sync.Mutex
+	cfg              *lbConfig // active config
+	locality         string
+	stopPicker       *grpcsync.Event
+	addressWeights   *resolver.AddressMap  // addr -> endpointWeight
+	endpointToWeight *resolver.EndpointMap // endpoint -> endpointWeight
+	scToWeight       map[balancer.SubConn]*endpointWeight
 }
 
 func (b *wrrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error {
 	b.logger.Infof("UpdateCCS: %v", ccs)
-	b.resolverErr = nil
 	cfg, ok := ccs.BalancerConfig.(*lbConfig)
 	if !ok {
 		return fmt.Errorf("wrr: received nil or illegal BalancerConfig (type %T): %v", ccs.BalancerConfig, ccs.BalancerConfig)
 	}
+	// Note: empty endpoints and duplicate addresses across endpoints won't
+	// explicitly error, but the resulting behavior is undefined.
+	b.mu.Lock()
 	b.cfg = cfg
 	b.locality = weightedtarget.LocalityFromResolverState(ccs.ResolverState)
-	b.updateAddresses(ccs.ResolverState.Addresses)
-
-	if len(ccs.ResolverState.Addresses) == 0 {
-		b.ResolverError(errors.New("resolver produced zero addresses")) // will call regeneratePicker
-		return balancer.ErrBadResolverState
-	}
+	b.updateEndpointsLocked(ccs.ResolverState.Endpoints)
+	b.mu.Unlock()
+
+	// This causes the child to update its picker inline, which in turn
+	// causes an inline picker update.
+	return b.child.UpdateClientConnState(balancer.ClientConnState{
+		BalancerConfig: endpointShardingLBConfig,
+		ResolverState:  ccs.ResolverState,
+	})
+}
 
-	b.regeneratePicker()
+func (b *wrrBalancer) UpdateState(state balancer.State) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
 
-	return nil
-}
+	if b.stopPicker != nil {
+		b.stopPicker.Fire()
+		b.stopPicker = nil
+	}
 
-func (b *wrrBalancer) updateAddresses(addrs []resolver.Address) {
-	addrsSet := resolver.NewAddressMap()
+	childStates := endpointsharding.ChildStatesFromPicker(state.Picker)
 
-	// Loop through new address list and create subconns for any new addresses.
-	for _, addr := range addrs {
-		if _, ok := addrsSet.Get(addr); ok {
-			// Redundant address; skip.
- continue - } - addrsSet.Set(addr, nil) + var readyPickersWeight []pickerWeightedEndpoint - var wsc *weightedSubConn - wsci, ok := b.subConns.Get(addr) - if ok { - wsc = wsci.(*weightedSubConn) - } else { - // addr is a new address (not existing in b.subConns). - var sc balancer.SubConn - sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ - StateListener: func(state balancer.SubConnState) { - b.updateSubConnState(sc, state) - }, - }) - if err != nil { - b.logger.Warningf("Failed to create new SubConn for address %v: %v", addr, err) + for _, childState := range childStates { + if childState.State.ConnectivityState == connectivity.Ready { + ewv, ok := b.endpointToWeight.Get(childState.Endpoint) + if !ok { + // Should never happen, simply continue and ignore this endpoint + // for READY pickers. continue } - wsc = &weightedSubConn{ - SubConn: sc, - logger: b.logger, - connectivityState: connectivity.Idle, - // Initially, we set load reports to off, because they are not - // running upon initial weightedSubConn creation. - cfg: &lbConfig{EnableOOBLoadReport: false}, - - metricsRecorder: b.metricsRecorder, - target: b.target, - locality: b.locality, - } - b.subConns.Set(addr, wsc) - b.scMap[sc] = wsc - b.csEvltr.RecordTransition(connectivity.Shutdown, connectivity.Idle) - sc.Connect() + ew := ewv.(*endpointWeight) + readyPickersWeight = append(readyPickersWeight, pickerWeightedEndpoint{ + picker: childState.State.Picker, + weightedEndpoint: ew, + }) } - // Update config for existing weightedSubConn or send update for first - // time to new one. Ensures an OOB listener is running if needed - // (and stops the existing one if applicable). - wsc.updateConfig(b.cfg) + } + // If no ready pickers are present, simply defer to the round robin picker + // from endpoint sharding, which will round robin across the most relevant + // pick first children in the highest precedence connectivity state. + if len(readyPickersWeight) == 0 { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: state.Picker, + }) + return } - // Loop through existing subconns and remove ones that are not in addrs. - for _, addr := range b.subConns.Keys() { - if _, ok := addrsSet.Get(addr); ok { - // Existing address also in new address list; skip. - continue - } - // addr was removed by resolver. Remove. - wsci, _ := b.subConns.Get(addr) - wsc := wsci.(*weightedSubConn) - wsc.SubConn.Shutdown() - b.subConns.Delete(addr) + p := &picker{ + v: rand.Uint32(), // start the scheduler at a random point + cfg: b.cfg, + weightedPickers: readyPickersWeight, + metricsRecorder: b.metricsRecorder, + locality: b.locality, + target: b.target, } + + b.stopPicker = grpcsync.NewEvent() + p.start(b.stopPicker) + + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: state.ConnectivityState, + Picker: p, + }) } -func (b *wrrBalancer) ResolverError(err error) { - b.resolverErr = err - if b.subConns.Len() == 0 { - b.connectivityState = connectivity.TransientFailure +type pickerWeightedEndpoint struct { + picker balancer.Picker + weightedEndpoint *endpointWeight +} + +func (b *wrrBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + addr := addrs[0] // The new pick first policy for DualStack will only ever create a SubConn with one address. 
+	var sc balancer.SubConn
+
+	oldListener := opts.StateListener
+	opts.StateListener = func(state balancer.SubConnState) {
+		b.updateSubConnState(sc, state)
+		oldListener(state)
 	}
-	if b.connectivityState != connectivity.TransientFailure {
-		// No need to update the picker since no error is being returned.
-		return
+
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	ewi, ok := b.addressWeights.Get(addr)
+	if !ok {
+		// SubConn state updates can come in for a no longer relevant endpoint
+		// weight (from the old system after a new config update is applied).
+		return nil, fmt.Errorf("balancer is being closed; no new SubConns allowed")
+	}
+	sc, err := b.ClientConn.NewSubConn([]resolver.Address{addr}, opts)
+	if err != nil {
+		return nil, err
 	}
-	b.regeneratePicker()
+	b.scToWeight[sc] = ewi.(*endpointWeight)
+	return sc, nil
+}
+
+func (b *wrrBalancer) ResolverError(err error) {
+	// Will cause inline picker update from endpoint sharding.
+	b.child.ResolverError(err)
 }
 
 func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
@@ -265,134 +338,84 @@ func (b *wrrBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.Sub
 }
 
 func (b *wrrBalancer) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
-	wsc := b.scMap[sc]
-	if wsc == nil {
-		b.logger.Errorf("UpdateSubConnState called with an unknown SubConn: %p, %v", sc, state)
+	b.mu.Lock()
+	ew := b.scToWeight[sc]
+	// Updates can come in for a SubConn that is no longer relevant; there is
+	// nothing to do here but forward the state to the state listener, which
+	// happens in the wrapped listener. The SubConn will eventually get cleared
+	// from scToWeight once it receives a Shutdown signal.
+	if ew == nil {
+		b.mu.Unlock()
 		return
 	}
-	if b.logger.V(2) {
-		logger.Infof("UpdateSubConnState(%+v, %+v)", sc, state)
-	}
-
-	cs := state.ConnectivityState
-
-	if cs == connectivity.TransientFailure {
-		// Save error to be reported via picker.
-		b.connErr = state.ConnectionError
-	}
-
-	if cs == connectivity.Shutdown {
-		delete(b.scMap, sc)
-		// The subconn was removed from b.subConns when the address was removed
-		// in updateAddresses.
+	if state.ConnectivityState == connectivity.Shutdown {
+		delete(b.scToWeight, sc)
+	}
+	b.mu.Unlock()
+
+	// On the first READY SubConn transition for an endpoint, set pickedSC,
+	// clear endpoint weight tracking state, and potentially start an OOB watch.
+	if state.ConnectivityState == connectivity.Ready && ew.pickedSC == nil {
+		ew.pickedSC = sc
+		ew.mu.Lock()
+		ew.nonEmptySince = time.Time{}
+		ew.lastUpdated = time.Time{}
+		cfg := ew.cfg
+		ew.mu.Unlock()
+		ew.updateORCAListener(cfg)
+		return
 	}
 
-	oldCS := wsc.updateConnectivityState(cs)
-	b.connectivityState = b.csEvltr.RecordTransition(oldCS, cs)
-
-	// Regenerate picker when one of the following happens:
-	//  - this sc entered or left ready
-	//  - the aggregated state of balancer is TransientFailure
-	//    (may need to update error message)
-	if (cs == connectivity.Ready) != (oldCS == connectivity.Ready) ||
-		b.connectivityState == connectivity.TransientFailure {
-		b.regeneratePicker()
+	// If the pickedSC (the one pick first uses for an endpoint) transitions out
+	// of READY, stop the OOB listener if needed and clear pickedSC so the next
+	// created SubConn for the endpoint that goes READY will be chosen for the
+	// endpoint as the active SubConn.
+	if state.ConnectivityState != connectivity.Ready && ew.pickedSC == sc {
+		// The first SubConn that goes READY for an endpoint is what pick first
+		// will pick. Only once that SubConn goes out of READY will pick first
+		// restart this cycle of creating SubConns and using the first READY
+		// one. Once this occurs, the lower-level endpoint sharding will ping
+		// pick first to ExitIdle, which will trigger a connection attempt.
+		if ew.stopORCAListener != nil {
+			ew.stopORCAListener()
+		}
+		ew.pickedSC = nil
	}
 }
 
 // Close stops the balancer. It cancels any ongoing scheduler updates and
 // stops any ORCA listeners.
 func (b *wrrBalancer) Close() {
+	b.mu.Lock()
 	if b.stopPicker != nil {
-		b.stopPicker()
+		b.stopPicker.Fire()
 		b.stopPicker = nil
 	}
-	for _, wsc := range b.scMap {
-		// Ensure any lingering OOB watchers are stopped.
-		wsc.updateConnectivityState(connectivity.Shutdown)
-	}
-}
-
-// ExitIdle is ignored; we always connect to all backends.
-func (b *wrrBalancer) ExitIdle() {}
+	b.mu.Unlock()
 
-func (b *wrrBalancer) readySubConns() []*weightedSubConn {
-	var ret []*weightedSubConn
-	for _, v := range b.subConns.Values() {
-		wsc := v.(*weightedSubConn)
-		if wsc.connectivityState == connectivity.Ready {
-			ret = append(ret, wsc)
+	// Ensure any lingering OOB watchers are stopped.
+	for _, ewv := range b.endpointToWeight.Values() {
+		ew := ewv.(*endpointWeight)
+		if ew.stopORCAListener != nil {
+			ew.stopORCAListener()
 		}
 	}
-	return ret
 }
 
-// mergeErrors builds an error from the last connection error and the last
-// resolver error. Must only be called if b.connectivityState is
-// TransientFailure.
-func (b *wrrBalancer) mergeErrors() error {
-	// connErr must always be non-nil unless there are no SubConns, in which
-	// case resolverErr must be non-nil.
-	if b.connErr == nil {
-		return fmt.Errorf("last resolver error: %v", b.resolverErr)
+func (b *wrrBalancer) ExitIdle() {
+	if ei, ok := b.child.(balancer.ExitIdler); ok { // Should always be ok, as the child is endpoint sharding.
+		ei.ExitIdle()
 	}
-	if b.resolverErr == nil {
-		return fmt.Errorf("last connection error: %v", b.connErr)
-	}
-	return fmt.Errorf("last connection error: %v; last resolver error: %v", b.connErr, b.resolverErr)
-}
-
-func (b *wrrBalancer) regeneratePicker() {
-	if b.stopPicker != nil {
-		b.stopPicker()
-		b.stopPicker = nil
-	}
-
-	switch b.connectivityState {
-	case connectivity.TransientFailure:
-		b.cc.UpdateState(balancer.State{
-			ConnectivityState: connectivity.TransientFailure,
-			Picker:            base.NewErrPicker(b.mergeErrors()),
-		})
-		return
-	case connectivity.Connecting, connectivity.Idle:
-		// Idle could happen very briefly if all subconns are Idle and we've
-		// asked them to connect but they haven't reported Connecting yet.
-		// Report the same as Connecting since this is temporary.
-		b.cc.UpdateState(balancer.State{
-			ConnectivityState: connectivity.Connecting,
-			Picker:            base.NewErrPicker(balancer.ErrNoSubConnAvailable),
-		})
-		return
-	case connectivity.Ready:
-		b.connErr = nil
-	}
-
-	p := &picker{
-		v:               rand.Uint32(), // start the scheduler at a random point
-		cfg:             b.cfg,
-		subConns:        b.readySubConns(),
-		metricsRecorder: b.metricsRecorder,
-		locality:        b.locality,
-		target:          b.target,
-	}
-	var ctx context.Context
-	ctx, b.stopPicker = context.WithCancel(context.Background())
-	p.start(ctx)
-	b.cc.UpdateState(balancer.State{
-		ConnectivityState: b.connectivityState,
-		Picker:            p,
-	})
 }
 
 // picker is the WRR policy's picker. It uses live-updating backend weights to
 // update the scheduler periodically and ensure picks are routed proportional
 // to those weights.
type picker struct { - scheduler unsafe.Pointer // *scheduler; accessed atomically - v uint32 // incrementing value used by the scheduler; accessed atomically - cfg *lbConfig // active config when picker created - subConns []*weightedSubConn // all READY subconns + scheduler unsafe.Pointer // *scheduler; accessed atomically + v uint32 // incrementing value used by the scheduler; accessed atomically + cfg *lbConfig // active config when picker created + + weightedPickers []pickerWeightedEndpoint // all READY pickers // The following fields are immutable. target string @@ -400,14 +423,39 @@ type picker struct { metricsRecorder estats.MetricsRecorder } -func (p *picker) scWeights(recordMetrics bool) []float64 { - ws := make([]float64, len(p.subConns)) +func (p *picker) endpointWeights(recordMetrics bool) []float64 { + wp := make([]float64, len(p.weightedPickers)) now := internal.TimeNow() - for i, wsc := range p.subConns { - ws[i] = wsc.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod), recordMetrics) + for i, wpi := range p.weightedPickers { + wp[i] = wpi.weightedEndpoint.weight(now, time.Duration(p.cfg.WeightExpirationPeriod), time.Duration(p.cfg.BlackoutPeriod), recordMetrics) } + return wp +} + +func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { + // Read the scheduler atomically. All scheduler operations are threadsafe, + // and if the scheduler is replaced during this usage, we want to use the + // scheduler that was live when the pick started. + sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler)) - return ws + pickedPicker := p.weightedPickers[sched.nextIndex()] + pr, err := pickedPicker.picker.Pick(info) + if err != nil { + logger.Errorf("ready picker returned error: %v", err) + return balancer.PickResult{}, err + } + if !p.cfg.EnableOOBLoadReport { + oldDone := pr.Done + pr.Done = func(info balancer.DoneInfo) { + if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil { + pickedPicker.weightedEndpoint.OnLoadReport(load) + } + if oldDone != nil { + oldDone(info) + } + } + } + return pr, nil } func (p *picker) inc() uint32 { @@ -419,9 +467,9 @@ func (p *picker) regenerateScheduler() { atomic.StorePointer(&p.scheduler, unsafe.Pointer(&s)) } -func (p *picker) start(ctx context.Context) { +func (p *picker) start(stopPicker *grpcsync.Event) { p.regenerateScheduler() - if len(p.subConns) == 1 { + if len(p.weightedPickers) == 1 { // No need to regenerate weights with only one backend. return } @@ -431,7 +479,7 @@ func (p *picker) start(ctx context.Context) { defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-stopPicker.Done(): return case <-ticker.C: p.regenerateScheduler() @@ -440,29 +488,12 @@ func (p *picker) start(ctx context.Context) { }() } -func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - // Read the scheduler atomically. All scheduler operations are threadsafe, - // and if the scheduler is replaced during this usage, we want to use the - // scheduler that was live when the pick started. 
-	sched := *(*scheduler)(atomic.LoadPointer(&p.scheduler))
-
-	pickedSC := p.subConns[sched.nextIndex()]
-	pr := balancer.PickResult{SubConn: pickedSC.SubConn}
-	if !p.cfg.EnableOOBLoadReport {
-		pr.Done = func(info balancer.DoneInfo) {
-			if load, ok := info.ServerLoad.(*v3orcapb.OrcaLoadReport); ok && load != nil {
-				pickedSC.OnLoadReport(load)
-			}
-		}
-	}
-	return pr, nil
-}
-
-// weightedSubConn is the wrapper of a subconn that holds the subconn and its
-// weight (and other parameters relevant to computing the effective weight).
-// When needed, it also tracks connectivity state, listens for metrics updates
-// by implementing the orca.OOBListener interface and manages that listener.
-type weightedSubConn struct {
+// endpointWeight is the weight for an endpoint. It tracks the SubConn that will
+// be picked for the endpoint, and other parameters relevant to computing the
+// effective weight. When needed, it also tracks connectivity state, listens for
+// metrics updates by implementing the orca.OOBListener interface and manages
+// that listener.
+type endpointWeight struct {
 	// The following fields are immutable.
 	balancer.SubConn
 	logger *grpclog.PrefixLogger
@@ -474,6 +505,11 @@ type weightedSubConn struct {
 	// do not need a mutex.
 	connectivityState connectivity.State
 	stopORCAListener  func()
+	// The first SubConn for the endpoint that goes READY when the endpoint has
+	// no READY SubConns yet, cleared on that SubConn disconnecting (i.e. going
+	// out of READY). Represents what pick first will use as its picked SubConn
+	// for this endpoint.
+	pickedSC balancer.SubConn
 
 	// The following fields are accessed asynchronously and are protected by
 	// mu. Note that mu may not be held when calling into the stopORCAListener
@@ -487,11 +523,11 @@ type weightedSubConn struct {
 	cfg *lbConfig
 }
 
-func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
+func (w *endpointWeight) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
 	if w.logger.V(2) {
 		w.logger.Infof("Received load report for subchannel %v: %v", w.SubConn, load)
 	}
-	// Update weights of this subchannel according to the reported load
+	// Update weights of this endpoint according to the reported load.
 	utilization := load.ApplicationUtilization
 	if utilization == 0 {
 		utilization = load.CpuUtilization
@@ -520,7 +556,7 @@ func (w *weightedSubConn) OnLoadReport(load *v3orcapb.OrcaLoadReport) {
 
 // updateConfig updates the parameters of the WRR policy and
 // stops/starts/restarts the ORCA OOB listener.
-func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
+func (w *endpointWeight) updateConfig(cfg *lbConfig) {
 	w.mu.Lock()
 	oldCfg := w.cfg
 	w.cfg = cfg
@@ -533,14 +569,12 @@ func (w *weightedSubConn) updateConfig(cfg *lbConfig) {
 		// load reporting disabled, OOBReportingPeriod is always 0.)
 		return
 	}
-	if w.connectivityState == connectivity.Ready {
-		// (Re)start the listener to use the new config's settings for OOB
-		// reporting.
-		w.updateORCAListener(cfg)
-	}
+	// (Re)start the listener to use the new config's settings for OOB
+	// reporting.
+	w.updateORCAListener(cfg)
 }
 
-func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) {
+func (w *endpointWeight) updateORCAListener(cfg *lbConfig) {
 	if w.stopORCAListener != nil {
 		w.stopORCAListener()
 	}
@@ -548,57 +582,22 @@ func (w *weightedSubConn) updateORCAListener(cfg *lbConfig) {
 		w.stopORCAListener = nil
 		return
 	}
+	if w.pickedSC == nil { // No picked SC for this endpoint yet, nothing to listen on.
+ return + } if w.logger.V(2) { - w.logger.Infof("Registering ORCA listener for %v with interval %v", w.SubConn, cfg.OOBReportingPeriod) + w.logger.Infof("Registering ORCA listener for %v with interval %v", w.pickedSC, cfg.OOBReportingPeriod) } opts := orca.OOBListenerOptions{ReportInterval: time.Duration(cfg.OOBReportingPeriod)} - w.stopORCAListener = orca.RegisterOOBListener(w.SubConn, w, opts) -} - -func (w *weightedSubConn) updateConnectivityState(cs connectivity.State) connectivity.State { - switch cs { - case connectivity.Idle: - // Always reconnect when idle. - w.SubConn.Connect() - case connectivity.Ready: - // If we transition back to READY state, reset nonEmptySince so that we - // apply the blackout period after we start receiving load data. Also - // reset lastUpdated to trigger endpoint weight not yet usable in the - // case endpoint gets asked what weight it is before receiving a new - // load report. Note that we cannot guarantee that we will never receive - // lingering callbacks for backend metric reports from the previous - // connection after the new connection has been established, but they - // should be masked by new backend metric reports from the new - // connection by the time the blackout period ends. - w.mu.Lock() - w.nonEmptySince = time.Time{} - w.lastUpdated = time.Time{} - cfg := w.cfg - w.mu.Unlock() - w.updateORCAListener(cfg) - } - - oldCS := w.connectivityState - - if oldCS == connectivity.TransientFailure && - (cs == connectivity.Connecting || cs == connectivity.Idle) { - // Once a subconn enters TRANSIENT_FAILURE, ignore subsequent IDLE or - // CONNECTING transitions to prevent the aggregated state from being - // always CONNECTING when many backends exist but are all down. - return oldCS - } - - w.connectivityState = cs - - return oldCS + w.stopORCAListener = orca.RegisterOOBListener(w.pickedSC, w, opts) } -// weight returns the current effective weight of the subconn, taking into +// weight returns the current effective weight of the endpoint, taking into // account the parameters. Returns 0 for blacked out or expired data, which // will cause the backend weight to be treated as the mean of the weights of the // other backends. If forScheduler is set to true, this function will emit // metrics through the metrics registry. -func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration, recordMetrics bool) (weight float64) { +func (w *endpointWeight) weight(now time.Time, weightExpirationPeriod, blackoutPeriod time.Duration, recordMetrics bool) (weight float64) { w.mu.Lock() defer w.mu.Unlock() @@ -608,7 +607,7 @@ func (w *weightedSubConn) weight(now time.Time, weightExpirationPeriod, blackout }() } - // The SubConn has not received a load report (i.e. just turned READY with + // The endpoint has not received a load report (i.e. just turned READY with // no load report). if w.lastUpdated.Equal(time.Time{}) { endpointWeightNotYetUsableMetric.Record(w.metricsRecorder, 1, w.target, w.locality) diff --git a/balancer/weightedroundrobin/balancer_test.go b/balancer/weightedroundrobin/balancer_test.go index 68d2d5a5c5c8..5e369780764e 100644 --- a/balancer/weightedroundrobin/balancer_test.go +++ b/balancer/weightedroundrobin/balancer_test.go @@ -460,6 +460,65 @@ func (s) TestBalancer_TwoAddresses_OOBThenPerCall(t *testing.T) { checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1}) } +// TestEndpoints_SharedAddress tests the case where two endpoints have the same +// address. 
+// The expected behavior is undefined; however, the program should not crash.
+func (s) TestEndpoints_SharedAddress(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	srv := startServer(t, reportCall)
+	sc := svcConfig(t, perCallConfig)
+	if err := srv.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+
+	endpointsSharedAddress := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: srv.Address}}}, {Addresses: []resolver.Address{{Addr: srv.Address}}}}
+	srv.R.UpdateState(resolver.State{Endpoints: endpointsSharedAddress})
+
+	// Make some RPCs and make sure nothing crashes. The RPCs should go to one
+	// of the endpoints' addresses; it is undefined which one will be chosen
+	// and load reporting might not work, but the RPCs should succeed.
+	for i := 0; i < 10; i++ {
+		if _, err := srv.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+			t.Fatalf("EmptyCall failed with err: %v", err)
+		}
+	}
+}
+
+// TestEndpoints_MultipleAddresses tests WRR on endpoints with numerous
+// addresses. It configures WRR with two endpoints, each with one bad address
+// followed by a good address. It configures two backends that each report per
+// call metrics, each corresponding to one of the endpoints' good addresses. It
+// then asserts that load is distributed as expected, corresponding to the call
+// metrics received.
+func (s) TestEndpoints_MultipleAddresses(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+	srv1 := startServer(t, reportCall)
+	srv2 := startServer(t, reportCall)
+
+	srv1.callMetrics.SetQPS(10.0)
+	srv1.callMetrics.SetApplicationUtilization(.1)
+
+	srv2.callMetrics.SetQPS(10.0)
+	srv2.callMetrics.SetApplicationUtilization(1.0)
+
+	sc := svcConfig(t, perCallConfig)
+	if err := srv1.StartClient(grpc.WithDefaultServiceConfig(sc)); err != nil {
+		t.Fatalf("Error starting client: %v", err)
+	}
+
+	twoEndpoints := []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: "bad-address-1"}, {Addr: srv1.Address}}}, {Addresses: []resolver.Address{{Addr: "bad-address-2"}, {Addr: srv2.Address}}}}
+	srv1.R.UpdateState(resolver.State{Endpoints: twoEndpoints})
+
+	// Call each backend once to ensure the weights have been received.
+	ensureReached(ctx, t, srv1.Client, 2)
+	// Wait for the weight update period to allow the new weights to be processed.
+	time.Sleep(weightUpdatePeriod)
+	checkWeights(ctx, t, srvWeight{srv1, 10}, srvWeight{srv2, 1})
+}
+
 // Tests two addresses with OOB ORCA reporting enabled and a non-zero error
 // penalty applied.
 func (s) TestBalancer_TwoAddresses_ErrorPenalty(t *testing.T) {
diff --git a/balancer/weightedroundrobin/metrics_test.go b/balancer/weightedroundrobin/metrics_test.go
index 9794a65e044f..79e4d0a145a0 100644
--- a/balancer/weightedroundrobin/metrics_test.go
+++ b/balancer/weightedroundrobin/metrics_test.go
@@ -109,7 +109,7 @@ func (s) TestWRR_Metrics_SubConnWeight(t *testing.T) {
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
 			tmr := stats.NewTestMetricsRecorder()
-			wsc := &weightedSubConn{
+			wsc := &endpointWeight{
 				metricsRecorder: tmr,
 				weightVal:       3,
 				lastUpdated:     test.lastUpdated,
@@ -137,7 +137,7 @@
 // fallback.
func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { tmr := stats.NewTestMetricsRecorder() - wsc := &weightedSubConn{ + ew := &endpointWeight{ metricsRecorder: tmr, weightVal: 0, } @@ -147,7 +147,7 @@ func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { BlackoutPeriod: iserviceconfig.Duration(10 * time.Second), WeightExpirationPeriod: iserviceconfig.Duration(3 * time.Minute), }, - subConns: []*weightedSubConn{wsc}, + weightedPickers: []pickerWeightedEndpoint{{weightedEndpoint: ew}}, metricsRecorder: tmr, } // There is only one SubConn, so no matter if the SubConn has a weight or @@ -160,12 +160,12 @@ func (s) TestWRR_Metrics_Scheduler_RR_Fallback(t *testing.T) { // With two SubConns, if neither of them have weights, it will also fallback // to round robin. - wsc2 := &weightedSubConn{ + ew2 := &endpointWeight{ target: "target", metricsRecorder: tmr, weightVal: 0, } - p.subConns = append(p.subConns, wsc2) + p.weightedPickers = append(p.weightedPickers, pickerWeightedEndpoint{weightedEndpoint: ew2}) p.regenerateScheduler() if got, _ := tmr.Metric("grpc.lb.wrr.rr_fallback"); got != 1 { t.Fatalf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.wrr.rr_fallback", got, 1) diff --git a/balancer/weightedroundrobin/scheduler.go b/balancer/weightedroundrobin/scheduler.go index 56aa15da10d2..7d3d6815eb7a 100644 --- a/balancer/weightedroundrobin/scheduler.go +++ b/balancer/weightedroundrobin/scheduler.go @@ -26,14 +26,14 @@ type scheduler interface { nextIndex() int } -// newScheduler uses scWeights to create a new scheduler for selecting subconns +// newScheduler uses scWeights to create a new scheduler for selecting endpoints // in a picker. It will return a round robin implementation if at least -// len(scWeights)-1 are zero or there is only a single subconn, otherwise it +// len(scWeights)-1 are zero or there is only a single endpoint, otherwise it // will return an Earliest Deadline First (EDF) scheduler implementation that -// selects the subchannels according to their weights. +// selects the endpoints according to their weights. func (p *picker) newScheduler(recordMetrics bool) scheduler { - scWeights := p.scWeights(recordMetrics) - n := len(scWeights) + epWeights := p.endpointWeights(recordMetrics) + n := len(epWeights) if n == 0 { return nil } @@ -46,7 +46,7 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler { sum := float64(0) numZero := 0 max := float64(0) - for _, w := range scWeights { + for _, w := range epWeights { sum += w if w > max { max = w @@ -68,7 +68,7 @@ func (p *picker) newScheduler(recordMetrics bool) scheduler { weights := make([]uint16, n) allEqual := true - for i, w := range scWeights { + for i, w := range epWeights { if w == 0 { // Backends with weight = 0 use the mean. 
weights[i] = mean From 8b70aeb896f52f99d43aa00bd196cffaf9e1db5e Mon Sep 17 00:00:00 2001 From: Purnesh Dixit Date: Mon, 25 Nov 2024 11:03:13 +0530 Subject: [PATCH 5/9] stats/opentelemetry: introduce tracing propagator and carrier (#7677) --- go.mod | 2 +- .../grpc_trace_bin_propagator.go | 119 ++++++++++ .../grpc_trace_bin_propagator_test.go | 219 ++++++++++++++++++ .../opentelemetry/internal/tracing/carrier.go | 131 +++++++++++ .../internal/tracing/carrier_test.go | 190 +++++++++++++++ 5 files changed, 660 insertions(+), 1 deletion(-) create mode 100644 stats/opentelemetry/grpc_trace_bin_propagator.go create mode 100644 stats/opentelemetry/grpc_trace_bin_propagator_test.go create mode 100644 stats/opentelemetry/internal/tracing/carrier.go create mode 100644 stats/opentelemetry/internal/tracing/carrier_test.go diff --git a/go.mod b/go.mod index 1bbd024d22c1..9b3d296cc882 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( go.opentelemetry.io/otel/metric v1.31.0 go.opentelemetry.io/otel/sdk v1.31.0 go.opentelemetry.io/otel/sdk/metric v1.31.0 + go.opentelemetry.io/otel/trace v1.31.0 golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.23.0 golang.org/x/sync v0.8.0 @@ -32,7 +33,6 @@ require ( github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect - go.opentelemetry.io/otel/trace v1.31.0 // indirect golang.org/x/text v0.19.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 // indirect ) diff --git a/stats/opentelemetry/grpc_trace_bin_propagator.go b/stats/opentelemetry/grpc_trace_bin_propagator.go new file mode 100644 index 000000000000..e8a3986d4f4a --- /dev/null +++ b/stats/opentelemetry/grpc_trace_bin_propagator.go @@ -0,0 +1,119 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package opentelemetry + +import ( + "context" + + otelpropagation "go.opentelemetry.io/otel/propagation" + oteltrace "go.opentelemetry.io/otel/trace" +) + +// gRPCTraceBinHeaderKey is the gRPC metadata header key `grpc-trace-bin` used +// to propagate trace context in binary format. +const grpcTraceBinHeaderKey = "grpc-trace-bin" + +// GRPCTraceBinPropagator is an OpenTelemetry TextMapPropagator which is used +// to extract and inject trace context data from and into headers exchanged by +// gRPC applications. It propagates trace data in binary format using the +// `grpc-trace-bin` header. +type GRPCTraceBinPropagator struct{} + +// Inject sets OpenTelemetry span context from the Context into the carrier as +// a `grpc-trace-bin` header if span context is valid. +// +// If span context is not valid, it returns without setting `grpc-trace-bin` +// header. 
+func (GRPCTraceBinPropagator) Inject(ctx context.Context, carrier otelpropagation.TextMapCarrier) { + sc := oteltrace.SpanFromContext(ctx) + if !sc.SpanContext().IsValid() { + return + } + + bd := toBinary(sc.SpanContext()) + carrier.Set(grpcTraceBinHeaderKey, string(bd)) +} + +// Extract reads OpenTelemetry span context from the `grpc-trace-bin` header of +// carrier into the provided context, if present. +// +// If a valid span context is retrieved from `grpc-trace-bin`, it returns a new +// context containing the extracted OpenTelemetry span context marked as +// remote. +// +// If `grpc-trace-bin` header is not present, it returns the context as is. +func (GRPCTraceBinPropagator) Extract(ctx context.Context, carrier otelpropagation.TextMapCarrier) context.Context { + h := carrier.Get(grpcTraceBinHeaderKey) + if h == "" { + return ctx + } + + sc, ok := fromBinary([]byte(h)) + if !ok { + return ctx + } + return oteltrace.ContextWithRemoteSpanContext(ctx, sc) +} + +// Fields returns the keys whose values are set with Inject. +// +// GRPCTraceBinPropagator always returns a slice containing only +// `grpc-trace-bin` key because it only sets the `grpc-trace-bin` header for +// propagating trace context. +func (GRPCTraceBinPropagator) Fields() []string { + return []string{grpcTraceBinHeaderKey} +} + +// toBinary returns the binary format representation of a SpanContext. +// +// If sc is the zero value, returns nil. +func toBinary(sc oteltrace.SpanContext) []byte { + if sc.Equal(oteltrace.SpanContext{}) { + return nil + } + var b [29]byte + traceID := oteltrace.TraceID(sc.TraceID()) + copy(b[2:18], traceID[:]) + b[18] = 1 + spanID := oteltrace.SpanID(sc.SpanID()) + copy(b[19:27], spanID[:]) + b[27] = 2 + b[28] = byte(oteltrace.TraceFlags(sc.TraceFlags())) + return b[:] +} + +// fromBinary returns the SpanContext represented by b with Remote set to true. +// +// It returns with zero value SpanContext and false, if any of the +// below condition is not satisfied: +// - Valid header: len(b) = 29 +// - Valid version: b[0] = 0 +// - Valid traceID prefixed with 0: b[1] = 0 +// - Valid spanID prefixed with 1: b[18] = 1 +// - Valid traceFlags prefixed with 2: b[27] = 2 +func fromBinary(b []byte) (oteltrace.SpanContext, bool) { + if len(b) != 29 || b[0] != 0 || b[1] != 0 || b[18] != 1 || b[27] != 2 { + return oteltrace.SpanContext{}, false + } + + return oteltrace.SpanContext{}.WithTraceID( + oteltrace.TraceID(b[2:18])).WithSpanID( + oteltrace.SpanID(b[19:27])).WithTraceFlags( + oteltrace.TraceFlags(b[28])).WithRemote(true), true +} diff --git a/stats/opentelemetry/grpc_trace_bin_propagator_test.go b/stats/opentelemetry/grpc_trace_bin_propagator_test.go new file mode 100644 index 000000000000..2d575af4a581 --- /dev/null +++ b/stats/opentelemetry/grpc_trace_bin_propagator_test.go @@ -0,0 +1,219 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+package opentelemetry
+
+import (
+	"context"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	oteltrace "go.opentelemetry.io/otel/trace"
+	"google.golang.org/grpc/metadata"
+	itracing "google.golang.org/grpc/stats/opentelemetry/internal/tracing"
+)
+
+var validSpanContext = oteltrace.SpanContext{}.WithTraceID(
+	oteltrace.TraceID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}).WithSpanID(
+	oteltrace.SpanID{17, 18, 19, 20, 21, 22, 23, 24}).WithTraceFlags(
+	oteltrace.TraceFlags(1))
+
+// TestInject_ValidSpanContext verifies that the GRPCTraceBinPropagator
+// correctly injects a valid OpenTelemetry span context as `grpc-trace-bin`
+// header in the provided carrier's context metadata.
+//
+// It verifies that if a valid span context is injected, the same span context
+// can be retrieved from the carrier's context metadata.
+func (s) TestInject_ValidSpanContext(t *testing.T) {
+	p := GRPCTraceBinPropagator{}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c := itracing.NewOutgoingCarrier(ctx)
+	ctx = oteltrace.ContextWithSpanContext(ctx, validSpanContext)
+
+	p.Inject(ctx, c)
+
+	md, _ := metadata.FromOutgoingContext(c.Context())
+	gotH := md.Get(grpcTraceBinHeaderKey)
+	if gotH[len(gotH)-1] == "" {
+		t.Fatalf("got empty value from Carrier's context metadata grpc-trace-bin header, want valid span context: %v", validSpanContext)
+	}
+	gotSC, ok := fromBinary([]byte(gotH[len(gotH)-1]))
+	if !ok {
+		t.Fatalf("got invalid span context %v from Carrier's context metadata grpc-trace-bin header, want valid span context: %v", gotSC, validSpanContext)
+	}
+	// fromBinary marks the extracted span context as remote, so compare
+	// against the injected span context with its remote flag set.
+	if want := validSpanContext.WithRemote(true); !cmp.Equal(want, gotSC) {
+		t.Fatalf("got span context = %v, want span context %v", gotSC, want)
+	}
+}
+
+// TestInject_InvalidSpanContext verifies that the GRPCTraceBinPropagator does
+// not inject an invalid OpenTelemetry span context as `grpc-trace-bin` header
+// in the provided carrier's context metadata.
+//
+// If an invalid span context is injected, it verifies that the
+// `grpc-trace-bin` header is not set in the carrier's context metadata.
+func (s) TestInject_InvalidSpanContext(t *testing.T) {
+	p := GRPCTraceBinPropagator{}
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	c := itracing.NewOutgoingCarrier(ctx)
+	ctx = oteltrace.ContextWithSpanContext(ctx, oteltrace.SpanContext{})
+
+	p.Inject(ctx, c)
+
+	md, _ := metadata.FromOutgoingContext(c.Context())
+	if gotH := md.Get(grpcTraceBinHeaderKey); len(gotH) > 0 {
+		t.Fatalf("got %v value from Carrier's context metadata grpc-trace-bin header, want empty", gotH)
+	}
+}
+
+// TestExtract verifies that the GRPCTraceBinPropagator correctly extracts
+// OpenTelemetry span context data from the provided context using the carrier.
+//
+// If a valid span context was injected, it verifies that the same trace span
+// context is extracted from the carrier's metadata for the `grpc-trace-bin`
+// header key.
+//
+// If an invalid span context was injected, it verifies that no valid trace
+// span context is extracted.
+func (s) TestExtract(t *testing.T) {
+	tests := []struct {
+		name   string
+		wantSC oteltrace.SpanContext // expected span context from carrier
+	}{
+		{
+			name:   "valid OpenTelemetry span context",
+			wantSC: validSpanContext.WithRemote(true),
+		},
+		{
+			name:   "invalid OpenTelemetry span context",
+			wantSC: oteltrace.SpanContext{},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			p := GRPCTraceBinPropagator{}
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			ctx = metadata.NewIncomingContext(ctx, metadata.MD{grpcTraceBinHeaderKey: []string{string(toBinary(test.wantSC))}})
+
+			c := itracing.NewIncomingCarrier(ctx)
+
+			tCtx := p.Extract(ctx, c)
+			got := oteltrace.SpanContextFromContext(tCtx)
+			if !got.Equal(test.wantSC) {
+				t.Fatalf("got span context: %v, want span context: %v", got, test.wantSC)
+			}
+		})
+	}
+}
+
+// TestToBinary verifies that the toBinary() function correctly serializes a
+// valid OpenTelemetry span context into its binary format representation. If
+// the span context is invalid, it verifies that the serialization is nil.
+func (s) TestToBinary(t *testing.T) {
+	tests := []struct {
+		name string
+		sc   oteltrace.SpanContext
+		want []byte
+	}{
+		{
+			name: "valid context",
+			sc:   validSpanContext,
+			want: toBinary(validSpanContext),
+		},
+		{
+			name: "zero value context",
+			sc:   oteltrace.SpanContext{},
+			want: nil,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			if got := toBinary(test.sc); !cmp.Equal(got, test.want) {
+				t.Fatalf("toBinary() = %v, want %v", got, test.want)
+			}
+		})
+	}
+}
+
+// TestFromBinary verifies that the fromBinary() function correctly
+// deserializes a binary format representation of a valid OpenTelemetry span
+// context into its corresponding span context format. If the span context's
+// binary representation is invalid, it verifies that deserialization yields a
+// zero value span context.
+func (s) TestFromBinary(t *testing.T) { + tests := []struct { + name string + b []byte + want oteltrace.SpanContext + ok bool + }{ + { + name: "valid", + b: []byte{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 17, 18, 19, 20, 21, 22, 23, 24, 2, 1}, + want: validSpanContext.WithRemote(true), + ok: true, + }, + { + name: "invalid length", + b: []byte{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 17, 18, 19, 20, 21, 22, 23, 24, 2}, + want: oteltrace.SpanContext{}, + ok: false, + }, + { + name: "invalid version", + b: []byte{1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 17, 18, 19, 20, 21, 22, 23, 24, 2, 1}, + want: oteltrace.SpanContext{}, + ok: false, + }, + { + name: "invalid traceID field ID", + b: []byte{0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 17, 18, 19, 20, 21, 22, 23, 24, 2, 1}, + want: oteltrace.SpanContext{}, + ok: false, + }, + { + name: "invalid spanID field ID", + b: []byte{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 17, 18, 19, 20, 21, 22, 23, 24, 2, 1}, + want: oteltrace.SpanContext{}, + ok: false, + }, + { + name: "invalid traceFlags field ID", + b: []byte{0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 17, 18, 19, 20, 21, 22, 23, 24, 1, 1}, + want: oteltrace.SpanContext{}, + ok: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, ok := fromBinary(test.b) + if ok != test.ok { + t.Fatalf("fromBinary() ok = %v, want %v", ok, test.ok) + return + } + if !got.Equal(test.want) { + t.Fatalf("fromBinary() got = %v, want %v", got, test.want) + } + }) + } +} diff --git a/stats/opentelemetry/internal/tracing/carrier.go b/stats/opentelemetry/internal/tracing/carrier.go new file mode 100644 index 000000000000..214102aaf97a --- /dev/null +++ b/stats/opentelemetry/internal/tracing/carrier.go @@ -0,0 +1,131 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package tracing implements the OpenTelemetry carrier for context propagation +// in gRPC tracing. +package tracing + +import ( + "context" + + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +var logger = grpclog.Component("otel-plugin") + +// IncomingCarrier is a TextMapCarrier that uses incoming `context.Context` to +// retrieve any propagated key-value pairs in text format. +type IncomingCarrier struct { + ctx context.Context +} + +// NewIncomingCarrier creates a new `IncomingCarrier` with the given context. +// The incoming carrier should be used with propagator's `Extract()` method in +// the incoming rpc path. +func NewIncomingCarrier(ctx context.Context) *IncomingCarrier { + return &IncomingCarrier{ctx: ctx} +} + +// Get returns the string value associated with the passed key from the +// carrier's incoming context metadata. +// +// It returns an empty string if the key is not present in the carrier's +// context or if the value associated with the key is empty. 
+//
+// If multiple values are present for a key, it returns the last one.
+func (c *IncomingCarrier) Get(key string) string {
+	values := metadata.ValueFromIncomingContext(c.ctx, key)
+	if len(values) == 0 {
+		return ""
+	}
+	return values[len(values)-1]
+}
+
+// Set just logs an error. It implements the `TextMapCarrier` interface but
+// should not be used with `IncomingCarrier`.
+func (c *IncomingCarrier) Set(string, string) {
+	logger.Error("Set() should not be used with IncomingCarrier.")
+}
+
+// Keys returns the keys stored in the carrier's context metadata. It returns
+// keys from incoming context metadata.
+func (c *IncomingCarrier) Keys() []string {
+	md, ok := metadata.FromIncomingContext(c.ctx)
+	if !ok {
+		return nil
+	}
+	keys := make([]string, 0, len(md))
+	for key := range md {
+		keys = append(keys, key)
+	}
+	return keys
+}
+
+// Context returns the underlying context associated with the
+// `IncomingCarrier`.
+func (c *IncomingCarrier) Context() context.Context {
+	return c.ctx
+}
+
+// OutgoingCarrier is a TextMapCarrier that uses outgoing `context.Context` to
+// store any propagated key-value pairs in text format.
+type OutgoingCarrier struct {
+	ctx context.Context
+}
+
+// NewOutgoingCarrier creates a new `OutgoingCarrier` with the given context.
+// The outgoing carrier should be used with propagator's `Inject()` method in
+// the outgoing rpc path.
+func NewOutgoingCarrier(ctx context.Context) *OutgoingCarrier {
+	return &OutgoingCarrier{ctx: ctx}
+}
+
+// Get just logs an error and returns an empty string. It implements the
+// `TextMapCarrier` interface but should not be used with `OutgoingCarrier`.
+func (c *OutgoingCarrier) Get(string) string {
+	logger.Error("Get() should not be used with `OutgoingCarrier`")
+	return ""
+}
+
+// Set stores the key-value pair in the carrier's outgoing context metadata.
+//
+// If the key already exists, the given value is appended to the end of the
+// list of values for that key.
+func (c *OutgoingCarrier) Set(key, value string) {
+	c.ctx = metadata.AppendToOutgoingContext(c.ctx, key, value)
+}
+
+// Keys returns the keys stored in the carrier's context metadata. It returns
+// keys from outgoing context metadata.
+func (c *OutgoingCarrier) Keys() []string {
+	md, ok := metadata.FromOutgoingContext(c.ctx)
+	if !ok {
+		return nil
+	}
+	keys := make([]string, 0, len(md))
+	for key := range md {
+		keys = append(keys, key)
+	}
+	return keys
+}
+
+// Context returns the underlying context associated with the
+// `OutgoingCarrier`.
+func (c *OutgoingCarrier) Context() context.Context {
+	return c.ctx
+}
diff --git a/stats/opentelemetry/internal/tracing/carrier_test.go b/stats/opentelemetry/internal/tracing/carrier_test.go
new file mode 100644
index 000000000000..a2e22beb08ac
--- /dev/null
+++ b/stats/opentelemetry/internal/tracing/carrier_test.go
@@ -0,0 +1,190 @@
+/*
+ *
+ * Copyright 2024 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package tracing
+
+import (
+	"context"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/google/go-cmp/cmp/cmpopts"
+	"google.golang.org/grpc/internal/grpctest"
+	"google.golang.org/grpc/metadata"
+)
+
+type s struct {
+	grpctest.Tester
+}
+
+func Test(t *testing.T) {
+	grpctest.RunSubTests(t, s{})
+}
+
+// TestIncomingCarrier verifies that `IncomingCarrier.Get()` returns the
+// correct value for the corresponding key in the carrier's context metadata,
+// if the key is present. If the key is not present, it verifies that an empty
+// string is returned.
+//
+// If multiple values are present for a key, it verifies that the last value
+// is returned.
+//
+// If a key ends with `-bin`, it verifies that a correct binary value is
+// returned in string format for the binary header.
+func (s) TestIncomingCarrier(t *testing.T) {
+	tests := []struct {
+		name     string
+		md       metadata.MD
+		key      string
+		want     string
+		wantKeys []string
+	}{
+		{
+			name:     "existing key",
+			md:       metadata.Pairs("key1", "value1"),
+			key:      "key1",
+			want:     "value1",
+			wantKeys: []string{"key1"},
+		},
+		{
+			name:     "non-existing key",
+			md:       metadata.Pairs("key1", "value1"),
+			key:      "key2",
+			want:     "",
+			wantKeys: []string{"key1"},
+		},
+		{
+			name:     "empty key",
+			md:       metadata.MD{},
+			key:      "key1",
+			want:     "",
+			wantKeys: []string{},
+		},
+		{
+			name:     "more than one key/value pair",
+			md:       metadata.MD{"key1": []string{"value1"}, "key2": []string{"value2"}},
+			key:      "key2",
+			want:     "value2",
+			wantKeys: []string{"key1", "key2"},
+		},
+		{
+			name:     "more than one value for a key",
+			md:       metadata.MD{"key1": []string{"value1", "value2"}},
+			key:      "key1",
+			want:     "value2",
+			wantKeys: []string{"key1"},
+		},
+		{
+			name:     "grpc-trace-bin key",
+			md:       metadata.Pairs("grpc-trace-bin", string([]byte{0x01, 0x02, 0x03})),
+			key:      "grpc-trace-bin",
+			want:     string([]byte{0x01, 0x02, 0x03}),
+			wantKeys: []string{"grpc-trace-bin"},
+		},
+		{
+			name:     "grpc-trace-bin key with another string key",
+			md:       metadata.MD{"key1": []string{"value1"}, "grpc-trace-bin": []string{string([]byte{0x01, 0x02, 0x03})}},
+			key:      "grpc-trace-bin",
+			want:     string([]byte{0x01, 0x02, 0x03}),
+			wantKeys: []string{"key1", "grpc-trace-bin"},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			c := NewIncomingCarrier(metadata.NewIncomingContext(ctx, test.md))
+			got := c.Get(test.key)
+			if got != test.want {
+				t.Fatalf("c.Get() = %s, want %s", got, test.want)
+			}
+			if gotKeys := c.Keys(); !cmp.Equal(test.wantKeys, gotKeys, cmpopts.SortSlices(func(a, b string) bool { return a < b })) {
+				t.Fatalf("c.Keys() = keys %v, want %v", gotKeys, test.wantKeys)
+			}
+		})
+	}
+}
+
+// TestOutgoingCarrier verifies that a key-value pair is set in the carrier's
+// context metadata using `OutgoingCarrier.Set()`. If the key is not present,
+// it verifies that the key-value pair is inserted. If the key is already
+// present, it verifies that the new value is appended at the end of the list
+// for the existing key.
+//
+// If a key ends with `-bin`, it verifies that a binary value is set for the
+// `-bin` header in string format.
+//
+// It also verifies that both existing and newly inserted keys are present in
+// the carrier's context using `OutgoingCarrier.Keys()`.
+func (s) TestOutgoingCarrier(t *testing.T) { + tests := []struct { + name string + initialMD metadata.MD + setKey string + setValue string + wantValue string // expected value of the set key + wantKeys []string + }{ + { + name: "new key", + initialMD: metadata.MD{}, + setKey: "key1", + setValue: "value1", + wantValue: "value1", + wantKeys: []string{"key1"}, + }, + { + name: "add to existing key", + initialMD: metadata.MD{"key1": []string{"oldvalue"}}, + setKey: "key1", + setValue: "newvalue", + wantValue: "newvalue", + wantKeys: []string{"key1"}, + }, + { + name: "new key with different existing key", + initialMD: metadata.MD{"key2": []string{"value2"}}, + setKey: "key1", + setValue: "value1", + wantValue: "value1", + wantKeys: []string{"key2", "key1"}, + }, + { + name: "grpc-trace-bin binary key", + initialMD: metadata.MD{"key1": []string{"value1"}}, + setKey: "grpc-trace-bin", + setValue: string([]byte{0x01, 0x02, 0x03}), + wantValue: string([]byte{0x01, 0x02, 0x03}), + wantKeys: []string{"key1", "grpc-trace-bin"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := NewOutgoingCarrier(metadata.NewOutgoingContext(ctx, test.initialMD)) + c.Set(test.setKey, test.setValue) + if gotKeys := c.Keys(); !cmp.Equal(test.wantKeys, gotKeys, cmpopts.SortSlices(func(a, b string) bool { return a < b })) { + t.Fatalf("c.Keys() = keys %v, want %v", gotKeys, test.wantKeys) + } + if md, ok := metadata.FromOutgoingContext(c.Context()); ok && md.Get(test.setKey)[len(md.Get(test.setKey))-1] != test.wantValue { + t.Fatalf("got value %s, want %s, for key %s", md.Get(test.setKey)[len(md.Get(test.setKey))-1], test.wantValue, test.setKey) + } + }) + } +} From dcba136b362e8a8096b3986e700047f8ca6302ac Mon Sep 17 00:00:00 2001 From: janardhanvissa <47281167+janardhanvissa@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:57:01 +0530 Subject: [PATCH 6/9] test/xds: remove redundant server when using stubserver in tests (#7846) --- ...ds_client_ignore_resource_deletion_test.go | 12 +--- .../xds_server_certificate_providers_test.go | 52 ++++----------- test/xds/xds_server_integration_test.go | 8 +-- test/xds/xds_server_serving_mode_test.go | 66 +++++++------------ test/xds/xds_server_test.go | 60 +++++------------ 5 files changed, 58 insertions(+), 140 deletions(-) diff --git a/test/xds/xds_client_ignore_resource_deletion_test.go b/test/xds/xds_client_ignore_resource_deletion_test.go index a8078cd206fb..b85ec16aef44 100644 --- a/test/xds/xds_client_ignore_resource_deletion_test.go +++ b/test/xds/xds_client_ignore_resource_deletion_test.go @@ -310,6 +310,7 @@ func setupGRPCServerWithModeChangeChannelAndServe(t *testing.T, bootstrapContent updateCh <- args.Mode }) stub := &stubserver.StubServer{ + Listener: lis, EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, @@ -321,17 +322,10 @@ func setupGRPCServerWithModeChangeChannelAndServe(t *testing.T, bootstrapContent if err != nil { t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) } - t.Cleanup(server.Stop) - stub.S = server - stubserver.StartTestService(t, stub) + t.Cleanup(stub.S.Stop) - // Serve. 
- go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + stubserver.StartTestService(t, stub) return updateCh } diff --git a/test/xds/xds_server_certificate_providers_test.go b/test/xds/xds_server_certificate_providers_test.go index f277db1376ed..9fcf6f49cf99 100644 --- a/test/xds/xds_server_certificate_providers_test.go +++ b/test/xds/xds_server_certificate_providers_test.go @@ -158,27 +158,21 @@ func (s) TestServerSideXDS_WithNoCertificateProvidersInBootstrap_Failure(t *test close(servingModeCh) } }) - server, err := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bs)) - if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) - } - defer server.Stop() - stub := &stubserver.StubServer{} - stub.S = server - stubserver.StartTestService(t, stub) - - // Create a local listener and pass it to Serve(). + // Create a local listener and assign it to the stub server. lis, err := testutils.LocalTCPListener() if err != nil { t.Fatalf("testutils.LocalTCPListener() failed: %v", err) } - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + stub := &stubserver.StubServer{ + Listener: lis, + } + if stub.S, err = xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bs)); err != nil { + t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) + } + defer stub.S.Stop() + stubserver.StartTestService(t, stub) // Create an inbound xDS listener resource for the server side that contains // mTLS security configuration. Since the received certificate provider @@ -288,30 +282,10 @@ func (s) TestServerSideXDS_WithValidAndInvalidSecurityConfiguration(t *testing.T } }) - stub := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - server, err := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) - if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) - } - defer server.Stop() - - stub.S = server - stubserver.StartTestService(t, stub) - - go func() { - if err := server.Serve(lis1); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() - go func() { - if err := server.Serve(lis2); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + stub1 := createStubServer(t, lis1, creds, modeChangeOpt, bootstrapContents) + defer stub1.S.Stop() + stub2 := createStubServer(t, lis2, creds, modeChangeOpt, bootstrapContents) + defer stub2.S.Stop() // Create inbound xDS listener resources for the server side that contains // mTLS security configuration. 
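The refactors in this patch share one pattern: the test listener is assigned to stubserver.StubServer.Listener, and stubserver.StartTestService serves on it, which removes the per-test `go server.Serve(lis)` goroutines. A minimal sketch of the pattern (the `creds`, `modeChangeOpt`, and `bootstrapContents` names stand in for the per-test values shown in these diffs):

	// Bind the stub server to an existing listener; with Listener set,
	// StartTestService serves on lis directly, so no Serve goroutine is needed.
	stub := &stubserver.StubServer{
		Listener: lis,
		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
			return &testpb.Empty{}, nil
		},
	}
	var err error
	if stub.S, err = xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)); err != nil {
		t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err)
	}
	defer stub.S.Stop()
	stubserver.StartTestService(t, stub)
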
diff --git a/test/xds/xds_server_integration_test.go b/test/xds/xds_server_integration_test.go index eacc6463c395..054e1fc7b0f0 100644 --- a/test/xds/xds_server_integration_test.go +++ b/test/xds/xds_server_integration_test.go @@ -111,12 +111,10 @@ func setupGRPCServer(t *testing.T, bootstrapContents []byte) (net.Listener, func }, } - server, err := xds.NewGRPCServer(grpc.Creds(creds), testModeChangeServerOption(t), xds.BootstrapContentsForTesting(bootstrapContents)) - if err != nil { + if stub.S, err = xds.NewGRPCServer(grpc.Creds(creds), testModeChangeServerOption(t), xds.BootstrapContentsForTesting(bootstrapContents)); err != nil { t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) } - stub.S = server stubserver.StartTestService(t, stub) // Create a local listener and pass it to Serve(). @@ -131,7 +129,7 @@ func setupGRPCServer(t *testing.T, bootstrapContents []byte) (net.Listener, func } go func() { - if err := server.Serve(readyLis); err != nil { + if err := stub.S.Serve(readyLis); err != nil { t.Errorf("Serve() failed: %v", err) } }() @@ -144,7 +142,7 @@ func setupGRPCServer(t *testing.T, bootstrapContents []byte) (net.Listener, func } return lis, func() { - server.Stop() + stub.S.Stop() } } diff --git a/test/xds/xds_server_serving_mode_test.go b/test/xds/xds_server_serving_mode_test.go index 3ed6750a6353..0299d6954ddb 100644 --- a/test/xds/xds_server_serving_mode_test.go +++ b/test/xds/xds_server_serving_mode_test.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" xdscreds "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal/stubserver" @@ -65,19 +66,8 @@ func (s) TestServerSideXDS_RedundantUpdateSuppression(t *testing.T) { // Initialize a test gRPC server, assign it to the stub server, and start // the test service. - stub := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - server, err := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) - if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) - } - defer server.Stop() - - stub.S = server - stubserver.StartTestService(t, stub) + stub := createStubServer(t, lis, creds, modeChangeOpt, bootstrapContents) + defer stub.S.Stop() // Setup the management server to respond with the listener resources. host, port, err := hostPortFromListener(lis) @@ -95,12 +85,6 @@ func (s) TestServerSideXDS_RedundantUpdateSuppression(t *testing.T) { t.Fatal(err) } - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() - // Wait for the listener to move to "serving" mode. select { case <-ctx.Done(): @@ -217,19 +201,10 @@ func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { // Initialize a test gRPC server, assign it to the stub server, and start // the test service. 
- stub := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - server, err := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) - if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) - } - defer server.Stop() - - stub.S = server - stubserver.StartTestService(t, stub) + stub1 := createStubServer(t, lis1, creds, modeChangeOpt, bootstrapContents) + defer stub1.S.Stop() + stub2 := createStubServer(t, lis2, creds, modeChangeOpt, bootstrapContents) + defer stub2.S.Stop() // Setup the management server to respond with server-side Listener // resources for both listeners. @@ -251,17 +226,6 @@ func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { t.Fatal(err) } - go func() { - if err := server.Serve(lis1); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() - go func() { - if err := server.Serve(lis2); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() - // Wait for both listeners to move to "serving" mode. select { case <-ctx.Done(): @@ -384,6 +348,22 @@ func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { waitForSuccessfulRPC(ctx, t, cc2) } +func createStubServer(t *testing.T, lis net.Listener, creds credentials.TransportCredentials, modeChangeOpt grpc.ServerOption, bootstrapContents []byte) *stubserver.StubServer { + stub := &stubserver.StubServer{ + Listener: lis, + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + server, err := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + if err != nil { + t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) + } + stub.S = server + stubserver.StartTestService(t, stub) + return stub +} + func waitForSuccessfulRPC(ctx context.Context, t *testing.T, cc *grpc.ClientConn) { t.Helper() diff --git a/test/xds/xds_server_test.go b/test/xds/xds_server_test.go index 6912757e5e13..bee9d401423b 100644 --- a/test/xds/xds_server_test.go +++ b/test/xds/xds_server_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" + xdscreds "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/stubserver" "google.golang.org/grpc/internal/testutils" @@ -93,26 +94,15 @@ func (s) TestServeLDSRDS(t *testing.T) { serving.Fire() } }) - - stub := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - server, err := xds.NewGRPCServer(grpc.Creds(insecure.NewCredentials()), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + // Configure xDS credentials with an insecure fallback to be used on the + // server-side. 
+ creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) + t.Fatalf("failed to create server credentials: %v", err) } - defer server.Stop() - - stub.S = server - stubserver.StartTestService(t, stub) + stub := createStubServer(t, lis, creds, modeChangeOpt, bootstrapContents) + defer stub.S.Stop() - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() select { case <-ctx.Done(): t.Fatal("timeout waiting for the xDS Server to go Serving") @@ -210,25 +200,15 @@ func (s) TestRDSNack(t *testing.T) { } }) - stub := &stubserver.StubServer{ - EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - return &testpb.Empty{}, nil - }, - } - server, err := xds.NewGRPCServer(grpc.Creds(insecure.NewCredentials()), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) + // Configure xDS credentials with an insecure fallback to be used on the + // server-side. + creds, err := xdscreds.NewServerCredentials(xdscreds.ServerOptions{FallbackCreds: insecure.NewCredentials()}) if err != nil { - t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) + t.Fatalf("failed to create server credentials: %v", err) } - defer server.Stop() - stub.S = server - stubserver.StartTestService(t, stub) - - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() + stub := createStubServer(t, lis, creds, modeChangeOpt, bootstrapContents) + defer stub.S.Stop() cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -278,6 +258,7 @@ func (s) TestMultipleUpdatesImmediatelySwitch(t *testing.T) { } stub := &stubserver.StubServer{ + Listener: lis, EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil }, @@ -291,21 +272,12 @@ func (s) TestMultipleUpdatesImmediatelySwitch(t *testing.T) { }, } - server, err := xds.NewGRPCServer(grpc.Creds(insecure.NewCredentials()), testModeChangeServerOption(t), xds.BootstrapContentsForTesting(bootstrapContents)) - if err != nil { + if stub.S, err = xds.NewGRPCServer(grpc.Creds(insecure.NewCredentials()), testModeChangeServerOption(t), xds.BootstrapContentsForTesting(bootstrapContents)); err != nil { t.Fatalf("Failed to create an xDS enabled gRPC server: %v", err) } - defer server.Stop() - - stub.S = server + defer stub.S.Stop() stubserver.StartTestService(t, stub) - go func() { - if err := server.Serve(lis); err != nil { - t.Errorf("Serve() failed: %v", err) - } - }() - cc, err := grpc.NewClient(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { t.Fatalf("failed to dial local test server: %v", err) From bb7ae0a2bf804286b3012e300020ad04b2995719 Mon Sep 17 00:00:00 2001 From: Robert O Butts Date: Tue, 26 Nov 2024 18:37:38 +0000 Subject: [PATCH 7/9] Change logger to avoid Printf when disabled (#7471) --- grpclog/internal/loggerv2.go | 107 ++++++++--- grpclog/internal/loggerv2_test.go | 305 +++++++++++++++++++++++++++++- 2 files changed, 381 insertions(+), 31 deletions(-) diff --git a/grpclog/internal/loggerv2.go b/grpclog/internal/loggerv2.go index 07df71e98a87..ed90060c3cba 100644 --- a/grpclog/internal/loggerv2.go +++ b/grpclog/internal/loggerv2.go @@ -101,6 +101,22 @@ var severityName = []string{ fatalLog: "FATAL", } +// sprintf is fmt.Sprintf. 
+// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily. +var sprintf = fmt.Sprintf + +// sprint is fmt.Sprint. +// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily. +var sprint = fmt.Sprint + +// sprintln is fmt.Sprintln. +// These vars exist to make it possible to test that expensive format calls aren't made unnecessarily. +var sprintln = fmt.Sprintln + +// exit is os.Exit. +// This var exists to make it possible to test functions calling os.Exit. +var exit = os.Exit + // loggerT is the default logger used by grpclog. type loggerT struct { m []*log.Logger @@ -111,7 +127,7 @@ type loggerT struct { func (g *loggerT) output(severity int, s string) { sevStr := severityName[severity] if !g.jsonFormat { - g.m[severity].Output(2, fmt.Sprintf("%v: %v", sevStr, s)) + g.m[severity].Output(2, sevStr+": "+s) return } // TODO: we can also include the logging component, but that needs more @@ -123,55 +139,79 @@ func (g *loggerT) output(severity int, s string) { g.m[severity].Output(2, string(b)) } +func (g *loggerT) printf(severity int, format string, args ...any) { + // Note the discard check is duplicated in each print func, rather than in + // output, to avoid the expensive Sprint calls. + // De-duplicating this by moving to output would be a significant performance regression! + if lg := g.m[severity]; lg.Writer() == io.Discard { + return + } + g.output(severity, sprintf(format, args...)) +} + +func (g *loggerT) print(severity int, v ...any) { + if lg := g.m[severity]; lg.Writer() == io.Discard { + return + } + g.output(severity, sprint(v...)) +} + +func (g *loggerT) println(severity int, v ...any) { + if lg := g.m[severity]; lg.Writer() == io.Discard { + return + } + g.output(severity, sprintln(v...)) +} + func (g *loggerT) Info(args ...any) { - g.output(infoLog, fmt.Sprint(args...)) + g.print(infoLog, args...) } func (g *loggerT) Infoln(args ...any) { - g.output(infoLog, fmt.Sprintln(args...)) + g.println(infoLog, args...) } func (g *loggerT) Infof(format string, args ...any) { - g.output(infoLog, fmt.Sprintf(format, args...)) + g.printf(infoLog, format, args...) } func (g *loggerT) Warning(args ...any) { - g.output(warningLog, fmt.Sprint(args...)) + g.print(warningLog, args...) } func (g *loggerT) Warningln(args ...any) { - g.output(warningLog, fmt.Sprintln(args...)) + g.println(warningLog, args...) } func (g *loggerT) Warningf(format string, args ...any) { - g.output(warningLog, fmt.Sprintf(format, args...)) + g.printf(warningLog, format, args...) } func (g *loggerT) Error(args ...any) { - g.output(errorLog, fmt.Sprint(args...)) + g.print(errorLog, args...) } func (g *loggerT) Errorln(args ...any) { - g.output(errorLog, fmt.Sprintln(args...)) + g.println(errorLog, args...) } func (g *loggerT) Errorf(format string, args ...any) { - g.output(errorLog, fmt.Sprintf(format, args...)) + g.printf(errorLog, format, args...) } func (g *loggerT) Fatal(args ...any) { - g.output(fatalLog, fmt.Sprint(args...)) - os.Exit(1) + g.print(fatalLog, args...) + exit(1) } func (g *loggerT) Fatalln(args ...any) { - g.output(fatalLog, fmt.Sprintln(args...)) - os.Exit(1) + g.println(fatalLog, args...) + exit(1) } func (g *loggerT) Fatalf(format string, args ...any) { - g.output(fatalLog, fmt.Sprintf(format, args...)) - os.Exit(1) + g.printf(fatalLog, format, args...) 
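+	// exit is a package-level var (os.Exit by default, see above) so tests
+	// can intercept the process exit.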
+	exit(1)
 }
 
 func (g *loggerT) V(l int) bool {
@@ -186,19 +226,42 @@ type LoggerV2Config struct {
 	FormatJSON bool
 }
 
+// combineLoggers returns a combined logger for both higher & lower severity logs,
+// or only one if the other is io.Discard.
+//
+// This uses io.Discard instead of io.MultiWriter when all loggers
+// are set to io.Discard. Both this package and the standard log package have
+// significant optimizations for io.Discard, which io.MultiWriter lacks (as of
+// this writing).
+func combineLoggers(lower, higher io.Writer) io.Writer {
+	if lower == io.Discard {
+		return higher
+	}
+	if higher == io.Discard {
+		return lower
+	}
+	return io.MultiWriter(lower, higher)
+}
+
 // NewLoggerV2 creates a new LoggerV2 instance with the provided configuration.
 // The infoW, warningW, and errorW writers are used to write log messages of
 // different severity levels.
 func NewLoggerV2(infoW, warningW, errorW io.Writer, c LoggerV2Config) LoggerV2 {
-	var m []*log.Logger
 	flag := log.LstdFlags
 	if c.FormatJSON {
 		flag = 0
 	}
-	m = append(m, log.New(infoW, "", flag))
-	m = append(m, log.New(io.MultiWriter(infoW, warningW), "", flag))
-	ew := io.MultiWriter(infoW, warningW, errorW) // ew will be used for error and fatal.
-	m = append(m, log.New(ew, "", flag))
-	m = append(m, log.New(ew, "", flag))
+
+	warningW = combineLoggers(infoW, warningW)
+	errorW = combineLoggers(errorW, warningW)
+
+	fatalW := errorW
+
+	m := []*log.Logger{
+		log.New(infoW, "", flag),
+		log.New(warningW, "", flag),
+		log.New(errorW, "", flag),
+		log.New(fatalW, "", flag),
+	}
 	return &loggerT{m: m, v: c.Verbosity, jsonFormat: c.FormatJSON}
 }
diff --git a/grpclog/internal/loggerv2_test.go b/grpclog/internal/loggerv2_test.go
index b22ecbde82a0..3369448bef5f 100644
--- a/grpclog/internal/loggerv2_test.go
+++ b/grpclog/internal/loggerv2_test.go
@@ -20,11 +20,86 @@ package internal
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
+	"io"
+	"os"
+	"reflect"
 	"regexp"
+	"strings"
 	"testing"
 )
 
+// logFuncStr is a string used via checkLogContainsFuncStr to test the
+// logger output.
+const logFuncStr = "called-func"
+
+func makeSprintfErr(t *testing.T) func(format string, a ...any) string {
+	return func(string, ...any) string {
+		t.Errorf("got: sprintf called on io.Discard logger, want: expensive sprintf to not be called for io.Discard")
+		return ""
+	}
+}
+
+func makeSprintErr(t *testing.T) func(a ...any) string {
+	return func(...any) string {
+		t.Errorf("got: sprint called on io.Discard logger, want: expensive sprint to not be called for io.Discard")
+		return ""
+	}
+}
+
+// checkLogContainsFuncStr checks that the logger buffer logBuf contains
+// logFuncStr.
+func checkLogContainsFuncStr(t *testing.T, logBuf []byte) {
+	if !bytes.Contains(logBuf, []byte(logFuncStr)) {
+		t.Errorf("got '%v', want logger func to be called and print '%v'", string(logBuf), logFuncStr)
+	}
+}
+
+// checkBufferWasWrittenAsExpected checks that the log buffer buf was written as expected,
+// per the discard, logType, msg, and isJSON arguments.
+func checkBufferWasWrittenAsExpected(t *testing.T, buf *bytes.Buffer, discard bool, logType string, msg string, isJSON bool) { + bts, err := buf.ReadBytes('\n') + if discard { + if err == nil { + t.Fatalf("got '%v', want discard %v to not write", string(bts), logType) + } else if err != io.EOF { + t.Fatalf("got '%v', want discard %v buffer to be EOF", err, logType) + } + } else { + if err != nil { + t.Fatalf("got '%v', want non-discard %v to not error", err, logType) + } else if !bytes.Contains(bts, []byte(msg)) { + t.Fatalf("got '%v', want non-discard %v buffer contain message '%v'", string(bts), logType, msg) + } + if isJSON { + obj := map[string]string{} + if err := json.Unmarshal(bts, &obj); err != nil { + t.Fatalf("got '%v', want non-discard json %v to unmarshal", err, logType) + } else if _, ok := obj["severity"]; !ok { + t.Fatalf("got '%v', want non-discard json %v to have severity field", "missing severity", logType) + + } else if jsonMsg, ok := obj["message"]; !ok { + t.Fatalf("got '%v', want non-discard json %v to have message field", "missing message", logType) + + } else if !strings.Contains(jsonMsg, msg) { + t.Fatalf("got '%v', want non-discard json %v buffer contain message '%v'", string(bts), logType, msg) + } + } + } +} + +// check if b is in the format of: +// +// 2017/04/07 14:55:42 WARNING: WARNING +func checkLogForSeverity(s int, b []byte) error { + expected := regexp.MustCompile(fmt.Sprintf(`^[0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s: %s\n$`, severityName[s], severityName[s])) + if m := expected.Match(b); !m { + return fmt.Errorf("got: %v, want string in format of: %v", string(b), severityName[s]+": 2016/10/05 17:09:26 "+severityName[s]) + } + return nil +} + func TestLoggerV2Severity(t *testing.T) { buffers := []*bytes.Buffer{new(bytes.Buffer), new(bytes.Buffer), new(bytes.Buffer)} l := NewLoggerV2(buffers[infoLog], buffers[warningLog], buffers[errorLog], LoggerV2Config{}) @@ -42,7 +117,7 @@ func TestLoggerV2Severity(t *testing.T) { for j := i; j < fatalLog; j++ { b, err := buf.ReadBytes('\n') if err != nil { - t.Fatal(err) + t.Fatalf("level %d: %v", j, err) } if err := checkLogForSeverity(j, b); err != nil { t.Fatal(err) @@ -51,13 +126,225 @@ func TestLoggerV2Severity(t *testing.T) { } } -// check if b is in the format of: -// -// 2017/04/07 14:55:42 WARNING: WARNING -func checkLogForSeverity(s int, b []byte) error { - expected := regexp.MustCompile(fmt.Sprintf(`^[0-9]{4}/[0-9]{2}/[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2} %s: %s\n$`, severityName[s], severityName[s])) - if m := expected.Match(b); !m { - return fmt.Errorf("got: %v, want string in format of: %v", string(b), severityName[s]+": 2016/10/05 17:09:26 "+severityName[s]) +// TestLoggerV2PrintFuncDiscardOnlyInfo ensures that logs at the INFO level are +// discarded when set to io.Discard, while logs at other levels (WARN, ERROR) +// are still printed. It does this by using a custom error function that raises +// an error if the logger attempts to print at the INFO level, ensuring early +// return when io.Discard is used. 
+func TestLoggerV2PrintFuncDiscardOnlyInfo(t *testing.T) { + buffers := []*bytes.Buffer{nil, new(bytes.Buffer), new(bytes.Buffer)} + logger := NewLoggerV2(io.Discard, buffers[warningLog], buffers[errorLog], LoggerV2Config{}) + loggerTp := logger.(*loggerT) + + // test that output doesn't call expensive printf funcs on an io.Discard logger + sprintf = makeSprintfErr(t) + sprint = makeSprintErr(t) + sprintln = makeSprintErr(t) + + loggerTp.output(infoLog, "something") + + sprintf = fmt.Sprintf + sprint = fmt.Sprint + sprintln = fmt.Sprintln + + loggerTp.output(errorLog, logFuncStr) + warnB, err := buffers[warningLog].ReadBytes('\n') + if err != nil { + t.Fatalf("level %v: %v", warningLog, err) + } + checkLogContainsFuncStr(t, warnB) + + errB, err := buffers[errorLog].ReadBytes('\n') + if err != nil { + t.Fatalf("level %v: %v", errorLog, err) + } + checkLogContainsFuncStr(t, errB) +} + +func TestLoggerV2PrintFuncNoDiscard(t *testing.T) { + buffers := []*bytes.Buffer{new(bytes.Buffer), new(bytes.Buffer), new(bytes.Buffer)} + logger := NewLoggerV2(buffers[infoLog], buffers[warningLog], buffers[errorLog], LoggerV2Config{}) + loggerTp := logger.(*loggerT) + + loggerTp.output(errorLog, logFuncStr) + + infoB, err := buffers[infoLog].ReadBytes('\n') + if err != nil { + t.Fatalf("level %v: %v", infoLog, err) + } + checkLogContainsFuncStr(t, infoB) + + warnB, err := buffers[warningLog].ReadBytes('\n') + if err != nil { + t.Fatalf("level %v: %v", warningLog, err) + } + checkLogContainsFuncStr(t, warnB) + + errB, err := buffers[errorLog].ReadBytes('\n') + if err != nil { + t.Fatalf("level %v: %v", errorLog, err) + } + checkLogContainsFuncStr(t, errB) +} + +// TestLoggerV2PrintFuncAllDiscard tests that discard loggers don't log. +func TestLoggerV2PrintFuncAllDiscard(t *testing.T) { + logger := NewLoggerV2(io.Discard, io.Discard, io.Discard, LoggerV2Config{}) + loggerTp := logger.(*loggerT) + + sprintf = makeSprintfErr(t) + sprint = makeSprintErr(t) + sprintln = makeSprintErr(t) + + // test that printFunc doesn't call the log func on discard loggers + // makeLogFuncErr will fail the test if it's called + loggerTp.output(infoLog, logFuncStr) + loggerTp.output(warningLog, logFuncStr) + loggerTp.output(errorLog, logFuncStr) + + sprintf = fmt.Sprintf + sprint = fmt.Sprint + sprintln = fmt.Sprintln +} + +func TestLoggerV2PrintFuncAllCombinations(t *testing.T) { + const ( + print int = iota + printf + println + ) + + type testDiscard struct { + discardInf bool + discardWarn bool + discardErr bool + + printType int + formatJSON bool + } + + discardName := func(td testDiscard) string { + strs := []string{} + if td.discardInf { + strs = append(strs, "discardInfo") + } + if td.discardWarn { + strs = append(strs, "discardWarn") + } + if td.discardErr { + strs = append(strs, "discardErr") + } + if len(strs) == 0 { + strs = append(strs, "noDiscard") + } + return strings.Join(strs, " ") + } + var printName = []string{ + print: "print", + printf: "printf", + println: "println", + } + var jsonName = map[bool]string{ + true: "json", + false: "noJson", + } + + discardTests := []testDiscard{} + for _, di := range []bool{true, false} { + for _, dw := range []bool{true, false} { + for _, de := range []bool{true, false} { + for _, pt := range []int{print, printf, println} { + for _, fj := range []bool{true, false} { + discardTests = append(discardTests, testDiscard{discardInf: di, discardWarn: dw, discardErr: de, printType: pt, formatJSON: fj}) + } + } + } + } + } + + for _, discardTest := range discardTests { + testName := 
fmt.Sprintf("%v %v %v", jsonName[discardTest.formatJSON], printName[discardTest.printType], discardName(discardTest)) + t.Run(testName, func(t *testing.T) { + cfg := LoggerV2Config{FormatJSON: discardTest.formatJSON} + + buffers := []*bytes.Buffer{new(bytes.Buffer), new(bytes.Buffer), new(bytes.Buffer)} + writers := []io.Writer{buffers[infoLog], buffers[warningLog], buffers[errorLog]} + if discardTest.discardInf { + writers[infoLog] = io.Discard + } + if discardTest.discardWarn { + writers[warningLog] = io.Discard + } + if discardTest.discardErr { + writers[errorLog] = io.Discard + } + logger := NewLoggerV2(writers[infoLog], writers[warningLog], writers[errorLog], cfg) + + msgInf := "someinfo" + msgWarn := "somewarn" + msgErr := "someerr" + if discardTest.printType == print { + logger.Info(msgInf) + logger.Warning(msgWarn) + logger.Error(msgErr) + } else if discardTest.printType == printf { + logger.Infof("%v", msgInf) + logger.Warningf("%v", msgWarn) + logger.Errorf("%v", msgErr) + } else if discardTest.printType == println { + logger.Infoln(msgInf) + logger.Warningln(msgWarn) + logger.Errorln(msgErr) + } + + // verify the test Discard, log type, message, and json arguments were logged as-expected + + checkBufferWasWrittenAsExpected(t, buffers[infoLog], discardTest.discardInf, "info", msgInf, cfg.FormatJSON) + checkBufferWasWrittenAsExpected(t, buffers[warningLog], discardTest.discardWarn, "warning", msgWarn, cfg.FormatJSON) + checkBufferWasWrittenAsExpected(t, buffers[errorLog], discardTest.discardErr, "error", msgErr, cfg.FormatJSON) + }) + } +} + +func TestLoggerV2Fatal(t *testing.T) { + const ( + print = "print" + println = "println" + printf = "printf" + ) + printFuncs := []string{print, println, printf} + + exitCalls := []int{} + + if reflect.ValueOf(exit).Pointer() != reflect.ValueOf(os.Exit).Pointer() { + t.Error("got: exit isn't os.Exit, want exit var to be os.Exit") + } + + exit = func(code int) { + exitCalls = append(exitCalls, code) + } + defer func() { + exit = os.Exit + }() + + for _, printFunc := range printFuncs { + buffers := []*bytes.Buffer{new(bytes.Buffer), new(bytes.Buffer), new(bytes.Buffer)} + logger := NewLoggerV2(buffers[infoLog], buffers[warningLog], buffers[errorLog], LoggerV2Config{}) + switch printFunc { + case print: + logger.Fatal("fatal") + case println: + logger.Fatalln("fatalln") + case printf: + logger.Fatalf("fatalf %d", 42) + default: + t.Errorf("unknown print type '%v'", printFunc) + } + + checkBufferWasWrittenAsExpected(t, buffers[errorLog], false, "fatal", "fatal", false) + if len(exitCalls) == 0 { + t.Error("got: no exit call, want fatal log to call exit") + } + exitCalls = []int{} } - return nil } From 967ba461405304ca9acc47c40eb55a4e87abb514 Mon Sep 17 00:00:00 2001 From: Zach Reyes <39203661+zasweq@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:56:48 -0500 Subject: [PATCH 8/9] balancer/pickfirst: Add pick first metrics (#7839) --- balancer/pickfirst/pickfirst_test.go | 5 +- .../pickfirst/pickfirstleaf/metrics_test.go | 273 ++++++++++++++++++ .../pickfirst/pickfirstleaf/pickfirstleaf.go | 57 +++- .../pickfirstleaf/pickfirstleaf_ext_test.go | 48 ++- .../pickfirstleaf/pickfirstleaf_test.go | 3 +- gcp/observability/go.sum | 1 + internal/balancergroup/balancergroup_test.go | 2 + .../testutils/stats/test_metrics_recorder.go | 17 ++ interop/observability/go.sum | 1 + security/advancedtls/examples/go.sum | 16 + security/advancedtls/go.sum | 16 + stats/opencensus/go.sum | 8 + .../clustermanager/clustermanager_test.go | 2 + 13 files changed, 434 
insertions(+), 15 deletions(-) create mode 100644 balancer/pickfirst/pickfirstleaf/metrics_test.go diff --git a/balancer/pickfirst/pickfirst_test.go b/balancer/pickfirst/pickfirst_test.go index 0b360c3b31ed..1da680fb4cf7 100644 --- a/balancer/pickfirst/pickfirst_test.go +++ b/balancer/pickfirst/pickfirst_test.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/stats" "google.golang.org/grpc/resolver" ) @@ -55,7 +56,7 @@ func (s) TestPickFirst_InitialResolverError(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() cc := testutils.NewBalancerClientConn(t) - bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{}) + bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}}) defer bal.Close() bal.ResolverError(errors.New("resolution failed: test error")) @@ -88,7 +89,7 @@ func (s) TestPickFirst_ResolverErrorinTF(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() cc := testutils.NewBalancerClientConn(t) - bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{}) + bal := balancer.Get(Name).Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}}) defer bal.Close() // After sending a valid update, the LB policy should report CONNECTING. diff --git a/balancer/pickfirst/pickfirstleaf/metrics_test.go b/balancer/pickfirst/pickfirstleaf/metrics_test.go new file mode 100644 index 000000000000..90beca6adc42 --- /dev/null +++ b/balancer/pickfirst/pickfirstleaf/metrics_test.go @@ -0,0 +1,273 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package pickfirstleaf_test + +import ( + "context" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/internal/testutils/stats" + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/resolver/manual" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/stats/opentelemetry" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + "go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest" +) + +var pfConfig string + +func init() { + pfConfig = fmt.Sprintf(`{ + "loadBalancingConfig": [ + { + %q: { + } + } + ] + }`, pickfirstleaf.Name) +} + +// TestPickFirstMetrics tests pick first metrics. 
It configures a pick first +// balancer, causes it to connect and then disconnect, and expects the +// subsequent metrics to emit from that. +func (s) TestPickFirstMetrics(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + ss.StartServer() + defer ss.Stop() + + sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig) + + r := manual.NewBuilderWithScheme("whatever") + r.InitialState(resolver.State{ + ServiceConfig: sc, + Addresses: []resolver.Address{{Addr: ss.Address}}}, + ) + + tmr := stats.NewTestMetricsRecorder() + cc, err := grpc.NewClient(r.Scheme()+":///", grpc.WithStatsHandler(tmr), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) + if err != nil { + t.Fatalf("NewClient() failed with error: %v", err) + } + defer cc.Close() + + tsc := testgrpc.NewTestServiceClient(cc) + if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } + + if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 1 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 1) + } + if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 0 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 0) + } + if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0) + } + + ss.Stop() + testutils.AwaitState(ctx, t, cc, connectivity.Idle) + if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 1 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 1) + } +} + +// TestPickFirstMetricsFailure tests the connection attempts failed metric. It +// configures a channel and scenario that causes a pick first connection attempt +// to fail, and then expects that metric to emit. 
+func (s) TestPickFirstMetricsFailure(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig) + + r := manual.NewBuilderWithScheme("whatever") + r.InitialState(resolver.State{ + ServiceConfig: sc, + Addresses: []resolver.Address{{Addr: "bad address"}}}, + ) + grpcTarget := r.Scheme() + ":///" + tmr := stats.NewTestMetricsRecorder() + cc, err := grpc.NewClient(grpcTarget, grpc.WithStatsHandler(tmr), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) + if err != nil { + t.Fatalf("NewClient() failed with error: %v", err) + } + defer cc.Close() + + tsc := testgrpc.NewTestServiceClient(cc) + if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err == nil { + t.Fatalf("EmptyCall() passed when expected to fail") + } + + if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 0 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 0) + } + if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 1 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 1) + } + if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 { + t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0) + } +} + +// TestPickFirstMetricsE2E tests the pick first metrics end to end. It +// configures a channel with an OpenTelemetry plugin, induces all 3 pick first +// metrics to emit, and makes sure the correct OpenTelemetry metrics atoms emit. +func (s) TestPickFirstMetricsE2E(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + ss := &stubserver.StubServer{ + EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + ss.StartServer() + defer ss.Stop() + + sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(pfConfig) + r := manual.NewBuilderWithScheme("whatever") + r.InitialState(resolver.State{ + ServiceConfig: sc, + Addresses: []resolver.Address{{Addr: "bad address"}}}, + ) // Will trigger connection failed. + + grpcTarget := r.Scheme() + ":///" + reader := metric.NewManualReader() + provider := metric.NewMeterProvider(metric.WithReader(reader)) + mo := opentelemetry.MetricsOptions{ + MeterProvider: provider, + Metrics: opentelemetry.DefaultMetrics().Add("grpc.lb.pick_first.disconnections", "grpc.lb.pick_first.connection_attempts_succeeded", "grpc.lb.pick_first.connection_attempts_failed"), + } + + cc, err := grpc.NewClient(grpcTarget, opentelemetry.DialOption(opentelemetry.Options{MetricsOptions: mo}), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(r)) + if err != nil { + t.Fatalf("NewClient() failed with error: %v", err) + } + defer cc.Close() + + tsc := testgrpc.NewTestServiceClient(cc) + if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}); err == nil { + t.Fatalf("EmptyCall() passed when expected to fail") + } + + r.UpdateState(resolver.State{ + ServiceConfig: sc, + Addresses: []resolver.Address{{Addr: ss.Address}}, + }) // Will trigger successful connection metric. 
+	if _, err := tsc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
+		t.Fatalf("EmptyCall() failed: %v", err)
+	}
+
+	// Stop the server; this should trigger a disconnect, which will
+	// eventually emit the disconnection metric before the ClientConn goes
+	// IDLE.
+	ss.Stop()
+	testutils.AwaitState(ctx, t, cc, connectivity.Idle)
+	wantMetrics := []metricdata.Metrics{
+		{
+			Name:        "grpc.lb.pick_first.connection_attempts_succeeded",
+			Description: "EXPERIMENTAL. Number of successful connection attempts.",
+			Unit:        "attempt",
+			Data: metricdata.Sum[int64]{
+				DataPoints: []metricdata.DataPoint[int64]{
+					{
+						Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
+						Value:      1,
+					},
+				},
+				Temporality: metricdata.CumulativeTemporality,
+				IsMonotonic: true,
+			},
+		},
+		{
+			Name:        "grpc.lb.pick_first.connection_attempts_failed",
+			Description: "EXPERIMENTAL. Number of failed connection attempts.",
+			Unit:        "attempt",
+			Data: metricdata.Sum[int64]{
+				DataPoints: []metricdata.DataPoint[int64]{
+					{
+						Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
+						Value:      1,
+					},
+				},
+				Temporality: metricdata.CumulativeTemporality,
+				IsMonotonic: true,
+			},
+		},
+		{
+			Name:        "grpc.lb.pick_first.disconnections",
+			Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
+			Unit:        "disconnection",
+			Data: metricdata.Sum[int64]{
+				DataPoints: []metricdata.DataPoint[int64]{
+					{
+						Attributes: attribute.NewSet(attribute.String("grpc.target", grpcTarget)),
+						Value:      1,
+					},
+				},
+				Temporality: metricdata.CumulativeTemporality,
+				IsMonotonic: true,
+			},
+		},
+	}
+
+	gotMetrics := metricsDataFromReader(ctx, reader)
+	for _, metric := range wantMetrics {
+		val, ok := gotMetrics[metric.Name]
+		if !ok {
+			t.Fatalf("Metric %v not present in recorded metrics", metric.Name)
+		}
+		if !metricdatatest.AssertEqual(t, metric, val, metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreExemplars()) {
+			t.Fatalf("Metrics data type not equal for metric: %v", metric.Name)
+		}
+	}
+}
+
+func metricsDataFromReader(ctx context.Context, reader *metric.ManualReader) map[string]metricdata.Metrics {
+	rm := &metricdata.ResourceMetrics{}
+	reader.Collect(ctx, rm)
+	gotMetrics := map[string]metricdata.Metrics{}
+	for _, sm := range rm.ScopeMetrics {
+		for _, m := range sm.Metrics {
+			gotMetrics[m.Name] = m
+		}
+	}
+	return gotMetrics
+}
diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go
index aaec87497fd4..1ebf7cea5e94 100644
--- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go
+++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go
@@ -36,6 +36,7 @@ import (
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/balancer/pickfirst/internal"
 	"google.golang.org/grpc/connectivity"
+	expstats "google.golang.org/grpc/experimental/stats"
 	"google.golang.org/grpc/grpclog"
 	"google.golang.org/grpc/internal/envconfig"
 	internalgrpclog "google.golang.org/grpc/internal/grpclog"
@@ -57,7 +58,28 @@ var (
 	// Name is the name of the pick_first_leaf balancer.
 	// It is changed to "pick_first" in init() if this balancer is to be
 	// registered as the default pickfirst.
-	Name = "pick_first_leaf"
+	Name                 = "pick_first_leaf"
+	disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
+		Name:        "grpc.lb.pick_first.disconnections",
+		Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.",
+		Unit:        "disconnection",
+		Labels:      []string{"grpc.target"},
+		Default:     false,
+	})
+	connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
+		Name:        "grpc.lb.pick_first.connection_attempts_succeeded",
+		Description: "EXPERIMENTAL. Number of successful connection attempts.",
+		Unit:        "attempt",
+		Labels:      []string{"grpc.target"},
+		Default:     false,
+	})
+	connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{
+		Name:        "grpc.lb.pick_first.connection_attempts_failed",
+		Description: "EXPERIMENTAL. Number of failed connection attempts.",
+		Unit:        "attempt",
+		Labels:      []string{"grpc.target"},
+		Default:     false,
+	})
 )
 
 const (
@@ -80,9 +102,12 @@ const (
 
 type pickfirstBuilder struct{}
 
-func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
+func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer {
 	b := &pickfirstBalancer{
-		cc:          cc,
+		cc:              cc,
+		target:          bo.Target.String(),
+		metricsRecorder: bo.MetricsRecorder, // ClientConn will always create a Metrics Recorder.
+
 		addressList: addressList{},
 		subConns:    resolver.NewAddressMap(),
 		state:       connectivity.Connecting,
@@ -147,8 +172,10 @@ func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) {
 type pickfirstBalancer struct {
 	// The following fields are initialized at build time and read-only after
 	// that and therefore do not need to be guarded by a mutex.
-	logger *internalgrpclog.PrefixLogger
-	cc     balancer.ClientConn
+	logger          *internalgrpclog.PrefixLogger
+	cc              balancer.ClientConn
+	target          string
+	metricsRecorder expstats.MetricsRecorder // guaranteed to be non-nil
 
 	// The mutex is used to ensure synchronization of updates triggered
 	// from the idle picker and the already serialized resolver,
@@ -532,10 +559,6 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub
 	b.mu.Lock()
 	defer b.mu.Unlock()
 	oldState := sd.state
-	// Record a connection attempt when exiting CONNECTING.
-	if newState.ConnectivityState == connectivity.TransientFailure {
-		sd.connectionFailedInFirstPass = true
-	}
 	sd.state = newState.ConnectivityState
 	// Previously relevant SubConns can still callback with state updates.
 	// To prevent pickers from returning these obsolete SubConns, this logic
@@ -548,7 +571,14 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub
 		return
 	}
 
+	// Record a connection attempt when exiting CONNECTING.
+	if newState.ConnectivityState == connectivity.TransientFailure {
+		sd.connectionFailedInFirstPass = true
+		connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target)
+	}
+
 	if newState.ConnectivityState == connectivity.Ready {
+		connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
 		b.shutdownRemainingLocked(sd)
 		if !b.addressList.seekTo(sd.addr) {
 			// This should not fail as we should have only one SubConn after
@@ -575,6 +605,15 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.Sub
 		// the first address when the picker is used.
 		b.shutdownRemainingLocked(sd)
 		b.state = connectivity.Idle
+		// A READY SubConn can be interleaved between CONNECTING and IDLE; we
+		// need to account for that.
+		if oldState == connectivity.Connecting {
+			// A known issue (https://github.com/grpc/grpc-go/issues/7862)
+			// causes a race that prevents the READY state change notification.
+			// This works around it.
+			connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target)
+		}
+		disconnectionsMetric.Record(b.metricsRecorder, 1, b.target)
 		b.addressList.reset()
 		b.cc.UpdateState(balancer.State{
 			ConnectivityState: connectivity.Idle,
diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go
index bf957f98b119..007157249689 100644
--- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go
+++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_ext_test.go
@@ -39,6 +39,7 @@ import (
 	"google.golang.org/grpc/internal/stubserver"
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/internal/testutils/pickfirst"
+	"google.golang.org/grpc/internal/testutils/stats"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/resolver/manual"
 	"google.golang.org/grpc/status"
@@ -863,10 +864,12 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_AfterEndOfList(t *testing.T) {
 	triggerTimer, timeAfter := mockTimer()
 	pfinternal.TimeAfterFunc = timeAfter
 
+	tmr := stats.NewTestMetricsRecorder()
 	dialer := testutils.NewBlockingDialer()
 	opts := []grpc.DialOption{
 		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
 		grpc.WithContextDialer(dialer.DialContext),
+		grpc.WithStatsHandler(tmr),
 	}
 	cc, rb, bm := setupPickFirstLeaf(t, 3, opts...)
 	addrs := bm.resolverAddrs()
@@ -906,6 +909,7 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_AfterEndOfList(t *testing.T) {
 
 	// First SubConn Fails.
 	holds[0].Fail(fmt.Errorf("test error"))
+	tmr.WaitForInt64CountIncr(ctx, 1)
 
 	// No TF should be reported until the first pass is complete.
 	shortCtx, shortCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
@@ -916,11 +920,24 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_AfterEndOfList(t *testing.T) {
 	shortCtx, shortCancel = context.WithTimeout(ctx, defaultTestShortTimeout)
 	defer shortCancel()
 	holds[2].Fail(fmt.Errorf("test error"))
+	tmr.WaitForInt64CountIncr(ctx, 1)
 	testutils.AwaitNotState(shortCtx, t, cc, connectivity.TransientFailure)
 
 	// Last SubConn fails, this should result in a TF update.
 	holds[1].Fail(fmt.Errorf("test error"))
+	tmr.WaitForInt64CountIncr(ctx, 1)
 	testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
+
+	// Only connection attempt failures are expected in this test.
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 0 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 0)
+	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 1 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 1)
+	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0)
+	}
 }
 
 // Test verifies that pickfirst attempts to connect to the second backend once
@@ -936,10 +953,12 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TriggerConnectionDelay(t *testing.T) {
 	triggerTimer, timeAfter := mockTimer()
 	pfinternal.TimeAfterFunc = timeAfter
 
+	tmr := stats.NewTestMetricsRecorder()
 	dialer := testutils.NewBlockingDialer()
 	opts := []grpc.DialOption{
 		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
 		grpc.WithContextDialer(dialer.DialContext),
+		grpc.WithStatsHandler(tmr),
 	}
 	cc, rb, bm := setupPickFirstLeaf(t, 2, opts...)
 	addrs := bm.resolverAddrs()
@@ -968,6 +987,17 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TriggerConnectionDelay(t *testing.T) {
 	// that the channel becomes READY.
 	holds[1].Resume()
 	testutils.AwaitState(ctx, t, cc, connectivity.Ready)
+
+	// Only connection attempt successes are expected in this test.
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 1 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 1)
+	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 0 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 0)
+	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0)
+	}
 }
 
 // Test tests the pickfirst balancer by causing a SubConn to fail and then
@@ -983,10 +1013,12 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_ThenTimerFires(t *testing.T) {
 	triggerTimer, timeAfter := mockTimer()
 	pfinternal.TimeAfterFunc = timeAfter
 
+	tmr := stats.NewTestMetricsRecorder()
 	dialer := testutils.NewBlockingDialer()
 	opts := []grpc.DialOption{
 		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, pickfirstleaf.Name)),
 		grpc.WithContextDialer(dialer.DialContext),
+		grpc.WithStatsHandler(tmr),
 	}
 	cc, rb, bm := setupPickFirstLeaf(t, 3, opts...)
 	addrs := bm.resolverAddrs()
@@ -1014,6 +1046,9 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_ThenTimerFires(t *testing.T) {
 	if holds[1].Wait(ctx) != true {
 		t.Fatalf("Timeout waiting for server %d with address %q to be contacted", 1, addrs[1])
 	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_failed"); got != 1 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_failed", got, 1)
+	}
 	if holds[2].IsStarted() != false {
 		t.Fatalf("Server %d with address %q contacted unexpectedly", 2, addrs[2])
 	}
@@ -1030,13 +1065,20 @@ func (s) TestPickFirstLeaf_HappyEyeballs_TF_ThenTimerFires(t *testing.T) {
 	// that the channel becomes READY.
 	holds[1].Resume()
 	testutils.AwaitState(ctx, t, cc, connectivity.Ready)
+
+	if got, _ := tmr.Metric("grpc.lb.pick_first.connection_attempts_succeeded"); got != 1 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.connection_attempts_succeeded", got, 1)
+	}
+	if got, _ := tmr.Metric("grpc.lb.pick_first.disconnections"); got != 0 {
+		t.Errorf("Unexpected data for metric %v, got: %v, want: %v", "grpc.lb.pick_first.disconnections", got, 0)
+	}
 }
 
 func (s) TestPickFirstLeaf_InterleavingIPV4Preffered(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 	cc := testutils.NewBalancerClientConn(t)
-	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{})
+	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}})
 	defer bal.Close()
 	ccState := balancer.ClientConnState{
 		ResolverState: resolver.State{
@@ -1082,7 +1124,7 @@ func (s) TestPickFirstLeaf_InterleavingIPv6Preffered(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 	cc := testutils.NewBalancerClientConn(t)
-	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{})
+	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}})
 	defer bal.Close()
 	ccState := balancer.ClientConnState{
 		ResolverState: resolver.State{
@@ -1126,7 +1168,7 @@ func (s) TestPickFirstLeaf_InterleavingUnknownPreffered(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 	cc := testutils.NewBalancerClientConn(t)
-	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{})
+	bal := balancer.Get(pickfirstleaf.Name).Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}})
 	defer bal.Close()
 	ccState := balancer.ClientConnState{
 		ResolverState: resolver.State{
diff --git a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go
index 71984a238cd5..f269a71a7a97 100644
--- a/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go
+++ b/balancer/pickfirst/pickfirstleaf/pickfirstleaf_test.go
@@ -29,6 +29,7 @@ import (
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/testutils"
+	"google.golang.org/grpc/internal/testutils/stats"
 	"google.golang.org/grpc/resolver"
 )
 
@@ -195,7 +196,7 @@ func (s) TestPickFirstLeaf_TFPickerUpdate(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
 	cc := testutils.NewBalancerClientConn(t)
-	bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{})
+	bal := pickfirstBuilder{}.Build(cc, balancer.BuildOptions{MetricsRecorder: &stats.NoopMetricsRecorder{}})
 	defer bal.Close()
 	ccState := balancer.ClientConnState{
 		ResolverState: resolver.State{
diff --git a/gcp/observability/go.sum b/gcp/observability/go.sum
index 30e984fb4343..472ac41d57ad 100644
--- a/gcp/observability/go.sum
+++ b/gcp/observability/go.sum
@@ -1107,6 +1107,7 @@ go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozR
 go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
 go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
 go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
+go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
 go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
 go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
 go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
diff --git a/internal/balancergroup/balancergroup_test.go b/internal/balancergroup/balancergroup_test.go
index c154c029d8f2..e49e8135a1b7 100644
--- a/internal/balancergroup/balancergroup_test.go
+++ b/internal/balancergroup/balancergroup_test.go
@@ -33,6 +33,7 @@ import (
 	"google.golang.org/grpc/internal/channelz"
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/testutils"
+	"google.golang.org/grpc/internal/testutils/stats"
 	"google.golang.org/grpc/resolver"
 )
 
@@ -603,6 +604,7 @@ func (s) TestBalancerGracefulSwitch(t *testing.T) {
 	childPolicyName := t.Name()
 	stub.Register(childPolicyName, stub.BalancerFuncs{
 		Init: func(bd *stub.BalancerData) {
+			bd.BuildOptions.MetricsRecorder = &stats.NoopMetricsRecorder{}
 			bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
 		},
 		Close: func(bd *stub.BalancerData) {
diff --git a/internal/testutils/stats/test_metrics_recorder.go b/internal/testutils/stats/test_metrics_recorder.go
index 72a20c1cbf44..e13013e38d53 100644
--- a/internal/testutils/stats/test_metrics_recorder.go
+++ b/internal/testutils/stats/test_metrics_recorder.go
@@ -63,6 +63,8 @@ func NewTestMetricsRecorder() *TestMetricsRecorder {
 // Metric returns the most recent data for a metric, and whether this recorder
 // has received data for a metric.
 func (r *TestMetricsRecorder) Metric(name string) (float64, bool) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
 	data, ok := r.data[estats.Metric(name)]
 	return data, ok
 }
@@ -102,6 +104,21 @@ func (r *TestMetricsRecorder) WaitForInt64Count(ctx context.Context, metricsData
 	return nil
 }
 
+// WaitForInt64CountIncr waits for an int64 count metric to be recorded and
+// verifies that the recorded increment matches the expected increment. It
+// returns an error if the wait times out or the wrong data is received.
+func (r *TestMetricsRecorder) WaitForInt64CountIncr(ctx context.Context, incrWant int64) error {
+	got, err := r.intCountCh.Receive(ctx)
+	if err != nil {
+		return fmt.Errorf("timeout waiting for int64Count")
+	}
+	metricsDataGot := got.(MetricsData)
+	if diff := cmp.Diff(metricsDataGot.IntIncr, incrWant); diff != "" {
+		return fmt.Errorf("int64count metricsData received unexpected value (-got, +want): %v", diff)
+	}
+	return nil
+}
+
 // RecordInt64Count sends the metrics data to the intCountCh channel and updates
 // the internal data map with the recorded value.
 func (r *TestMetricsRecorder) RecordInt64Count(handle *estats.Int64CountHandle, incr int64, labels ...string) {
diff --git a/interop/observability/go.sum b/interop/observability/go.sum
index a749b30fb223..4cdbd27b0fc9 100644
--- a/interop/observability/go.sum
+++ b/interop/observability/go.sum
@@ -1109,6 +1109,7 @@ go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozR
 go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
 go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
 go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
+go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
 go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
 go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
 go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
diff --git a/security/advancedtls/examples/go.sum b/security/advancedtls/examples/go.sum
index 9102af782ca0..2192e85919d7 100644
--- a/security/advancedtls/examples/go.sum
+++ b/security/advancedtls/examples/go.sum
@@ -1,7 +1,23 @@
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
+github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
 github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
 github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
+go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
+go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
+go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
+go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
+go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
+go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
+go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
+go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
+go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
 golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
 golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
 golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
diff --git a/security/advancedtls/go.sum b/security/advancedtls/go.sum
index 9102af782ca0..2192e85919d7 100644
--- a/security/advancedtls/go.sum
+++ b/security/advancedtls/go.sum
@@ -1,7 +1,23 @@
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
+github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
+github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
 github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
 github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
+go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
+go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
+go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
+go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
+go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
+go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
+go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
+go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
+go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
 golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
 golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
 golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
diff --git a/stats/opencensus/go.sum b/stats/opencensus/go.sum
index 447a12f3eb15..2e88e8bf1877 100644
--- a/stats/opencensus/go.sum
+++ b/stats/opencensus/go.sum
@@ -821,7 +821,9 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2
 github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
 github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
 github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
+github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
 github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
+github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
 github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
 github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M=
 github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M=
@@ -914,6 +916,7 @@ github.com/google/s2a-go v0.1.3/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkj
 github.com/google/s2a-go v0.1.4/go.mod h1:Ej+mSEMGRnqRzjc7VtF+jdBwYG5fuJfiZ8ELkjEwM0A=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/googleapis/enterprise-certificate-proxy v0.0.0-20220520183353-fd19c99a87aa/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
 github.com/googleapis/enterprise-certificate-proxy v0.1.0/go.mod h1:17drOmN3MwGY7t0e+Ei9b45FFGA3fBs3x36SsCg1hq8=
@@ -1037,10 +1040,15 @@ go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
 go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0=
 go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
 go.opentelemetry.io/contrib/detectors/gcp v1.31.0/go.mod h1:tzQL6E1l+iV44YFTkcAeNQqzXUiekSYP9jjJjXwEd00=
+go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
 go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
+go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
 go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY=
+go.opentelemetry.io/otel/sdk v1.31.0 h1:xLY3abVHYZ5HSfOg3l2E5LUj2Cwva5Y7yGxnSW9H5Gk=
 go.opentelemetry.io/otel/sdk v1.31.0/go.mod h1:TfRbMdhvxIIr/B2N2LQW2S5v9m3gOQ/08KsbbO5BPT0=
+go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
 go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
+go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
 go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
 go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
 go.opentelemetry.io/proto/otlp v0.15.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI405h3+duxN4U=
diff --git a/xds/internal/balancer/clustermanager/clustermanager_test.go b/xds/internal/balancer/clustermanager/clustermanager_test.go
index b606cb9e5e34..6ef8738dfcf4 100644
--- a/xds/internal/balancer/clustermanager/clustermanager_test.go
+++ b/xds/internal/balancer/clustermanager/clustermanager_test.go
@@ -34,6 +34,7 @@ import (
 	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/hierarchy"
 	"google.golang.org/grpc/internal/testutils"
+	"google.golang.org/grpc/internal/testutils/stats"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/status"
 )
@@ -643,6 +644,7 @@ func TestClusterGracefulSwitch(t *testing.T) {
 	childPolicyName := t.Name()
 	stub.Register(childPolicyName, stub.BalancerFuncs{
 		Init: func(bd *stub.BalancerData) {
+			bd.BuildOptions.MetricsRecorder = &stats.NoopMetricsRecorder{}
 			bd.Data = balancer.Get(pickfirst.Name).Build(bd.ClientConn, bd.BuildOptions)
 		},
 		Close: func(bd *stub.BalancerData) {

From 4c07bca27377feb808912b844b3fa95ad10f946b Mon Sep 17 00:00:00 2001
From: Ismail Gjevori
Date: Wed, 27 Nov 2024 00:08:44 +0100
Subject: [PATCH 9/9] stream: add jitter to retry backoff in accordance with
 gRFC A6 (#7869)

---
 service_config.go |  5 +++--
 stream.go         | 11 +++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/service_config.go b/service_config.go
index 2671c5ef69f0..7e83027d1994 100644
--- a/service_config.go
+++ b/service_config.go
@@ -168,6 +168,7 @@ func init() {
 		return parseServiceConfig(js, defaultMaxCallAttempts)
 	}
 }
+
 func parseServiceConfig(js string, maxAttempts int) *serviceconfig.ParseResult {
 	if len(js) == 0 {
 		return &serviceconfig.ParseResult{Err: fmt.Errorf("no JSON service config provided")}
@@ -297,7 +298,7 @@ func convertRetryPolicy(jrp *jsonRetryPolicy, maxAttempts int) (p *internalservi
 	return rp, nil
 }
 
-func min(a, b *int) *int {
+func minPointers(a, b *int) *int {
 	if *a < *b {
 		return a
 	}
@@ -309,7 +310,7 @@ func getMaxSize(mcMax, doptMax *int, defaultVal int) *int {
 		return &defaultVal
 	}
 	if mcMax != nil && doptMax != nil {
-		return min(mcMax, doptMax)
+		return minPointers(mcMax, doptMax)
 	}
 	if mcMax != nil {
 		return mcMax
diff --git a/stream.go b/stream.go
index 6d10d0ac8713..17e2267b3320 100644
--- a/stream.go
+++ b/stream.go
@@ -218,7 +218,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
 	var mc serviceconfig.MethodConfig
 	var onCommit func()
-	var newStream = func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
+	newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) {
 		return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, opts...)
 	}
 
@@ -708,11 +708,10 @@ func (a *csAttempt) shouldRetry(err error) (bool, error) {
 		cs.numRetriesSincePushback = 0
 	} else {
 		fact := math.Pow(rp.BackoffMultiplier, float64(cs.numRetriesSincePushback))
-		cur := float64(rp.InitialBackoff) * fact
-		if max := float64(rp.MaxBackoff); cur > max {
-			cur = max
-		}
-		dur = time.Duration(rand.Int64N(int64(cur)))
+		cur := min(float64(rp.InitialBackoff)*fact, float64(rp.MaxBackoff))
+		// Apply jitter by multiplying the backoff by a random factor in [0.8, 1.2).
+		cur *= 0.8 + 0.4*rand.Float64()
+		dur = time.Duration(int64(cur))
 		cs.numRetriesSincePushback++
 	}
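
For reference, here is a minimal standalone sketch of the jittered backoff computed in the patched shouldRetry above. The retryPolicy type and the example values are hypothetical stand-ins (the real fields live on internalserviceconfig.RetryPolicy); only the arithmetic mirrors the patch.

package main

import (
	"fmt"
	"math"
	"math/rand/v2"
	"time"
)

// retryPolicy is a hypothetical stand-in for the retry policy fields used by
// shouldRetry: the initial delay, the cap, and the exponential growth factor.
type retryPolicy struct {
	InitialBackoff    time.Duration
	MaxBackoff        time.Duration
	BackoffMultiplier float64
}

// backoff mirrors the patched computation: grow the delay exponentially with
// the number of retries since the last server pushback, cap it at MaxBackoff,
// then scale it by a random jitter factor in [0.8, 1.2) per gRFC A6.
func backoff(rp retryPolicy, retriesSincePushback int) time.Duration {
	fact := math.Pow(rp.BackoffMultiplier, float64(retriesSincePushback))
	cur := min(float64(rp.InitialBackoff)*fact, float64(rp.MaxBackoff))
	cur *= 0.8 + 0.4*rand.Float64() // jitter: uniform in [0.8, 1.2)
	return time.Duration(int64(cur))
}

func main() {
	rp := retryPolicy{
		InitialBackoff:    100 * time.Millisecond,
		MaxBackoff:        time.Second,
		BackoffMultiplier: 2.0,
	}
	for i := 0; i < 5; i++ {
		fmt.Printf("retry %d: backoff %v\n", i, backoff(rp, i))
	}
}

Note the behavioral change in the hunk above: the old code drew the delay uniformly from [0, cur) via rand.Int64N, whereas the new code keeps the delay close to the exponential target and perturbs it by at most 20% in either direction.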