Skip to content

Commit

Permalink
xds: Use the connected address for locality (grpc#7357)
Browse files Browse the repository at this point in the history
  • Loading branch information
townba committed Jul 1, 2024
1 parent c9caa9e commit 08ebd15
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 33 deletions.
15 changes: 15 additions & 0 deletions balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,20 @@ func unregisterForTesting(name string) {
delete(m, name)
}

// getConnectedAddress reports the address stored on scs together with a flag
// that is true only when scs.ConnectivityState is connectivity.Ready. The
// returned address should be ignored when the flag is false.
func getConnectedAddress(scs SubConnState) (resolver.Address, bool) {
	ready := scs.ConnectivityState == connectivity.Ready
	return scs.connectedAddress, ready
}

// setConnectedAddress records addr as the connected address of the
// SubConnState pointed to by scs. scs must be non-nil.
func setConnectedAddress(scs *SubConnState, addr resolver.Address) {
	scs.connectedAddress = addr
}

// init publishes this package's unexported helpers through the hook
// variables declared in the internal package.
func init() {
	// Accessors for the unexported connected address of a SubConnState.
	internal.GetConnectedAddress = getConnectedAddress
	internal.SetConnectedAddress = setConnectedAddress

	internal.BalancerUnregister = unregisterForTesting
}

// Get returns the resolver builder registered with the given name.
Expand Down Expand Up @@ -410,6 +422,9 @@ type SubConnState struct {
// ConnectionError is set if the ConnectivityState is TransientFailure,
// describing the reason the SubConn failed. Otherwise, it is nil.
ConnectionError error
// connectedAddress contains the connected address when ConnectivityState is
// Ready. Otherwise, it is indeterminate.
connectedAddress resolver.Address
}

// ClientConnState describes the state of a ClientConn relevant to the
Expand Down
11 changes: 9 additions & 2 deletions balancer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
grpcinternal "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
Expand Down Expand Up @@ -252,15 +253,21 @@ type acBalancerWrapper struct {

// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
acbw.ccb.serializer.Schedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
}
// Even though it is optional for balancers, gracefulswitch ensures
// opts.StateListener is set, so this cannot ever be nil.
// TODO: delete this comment when UpdateSubConnState is removed.
acbw.stateListener(balancer.SubConnState{ConnectivityState: s, ConnectionError: err})
scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err}
if s == connectivity.Ready {
if SetConnectedAddress, ok := grpcinternal.SetConnectedAddress.(func(state *balancer.SubConnState, addr resolver.Address)); ok {
SetConnectedAddress(&scs, curAddr)
}
}
acbw.stateListener(scs)
})
}

Expand Down
59 changes: 39 additions & 20 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ import (

const (
// minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second
minConnectTimeout = 20 * time.Second
withBalancerAttributes = true
withoutBalancerAttributes = false
)

var (
Expand Down Expand Up @@ -812,16 +814,26 @@ func (cc *ClientConn) applyFailingLBLocked(sc *serviceconfig.ParseResult) {
cc.csMgr.updateState(connectivity.TransientFailure)
}

// Makes a copy of the input addresses slice and clears out the balancer
// attributes field. Addresses are passed during subconn creation and address
// update operations. In both cases, we will clear the balancer attributes by
// calling this function, and therefore we will be able to use the Equal method
// provided by the resolver.Address type for comparison.
func copyAddressesWithoutBalancerAttributes(in []resolver.Address) []resolver.Address {
// addressWithoutBalancerAttributes returns a copy of a whose
// BalancerAttributes field is nil. Every other field is carried over
// unchanged, and the caller's value is not modified.
func addressWithoutBalancerAttributes(a resolver.Address) resolver.Address {
	stripped := a
	stripped.BalancerAttributes = nil
	return stripped
}

// copyAddresses returns a newly allocated slice holding a copy of every
// address in in. The input slice is never modified.
//
// When includeBalancerAttributes is false, the BalancerAttributes field of
// each copied address is set to nil, which allows the copies to be compared
// with the Equal method provided by the resolver.Address type. When it is
// true, the addresses are copied as-is.
//
// Addresses are passed through this function during subconn creation and
// address update operations.
func copyAddresses(in []resolver.Address, includeBalancerAttributes bool) []resolver.Address {
	out := make([]resolver.Address, len(in))
	copy(out, in)
	if includeBalancerAttributes {
		return out
	}
	// Only the copies are touched; the caller's addresses keep their
	// balancer attributes.
	for i := range out {
		out[i].BalancerAttributes = nil
	}
	return out
}
Expand All @@ -837,7 +849,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer.
ac := &addrConn{
state: connectivity.Idle,
cc: cc,
addrs: copyAddressesWithoutBalancerAttributes(addrs),
addrs: copyAddresses(addrs, withBalancerAttributes),
scopts: opts,
dopts: cc.dopts,
channelz: channelz.RegisterSubChannel(cc.channelz, ""),
Expand Down Expand Up @@ -924,12 +936,18 @@ func (ac *addrConn) connect() error {
return nil
}

func equalAddresses(a, b []resolver.Address) bool {
// equalAddressIgnoreBalancerAttributes reports whether a and b have the same
// Addr, ServerName, Attributes and Metadata. BalancerAttributes is
// deliberately left out of the comparison.
func equalAddressIgnoreBalancerAttributes(a, b resolver.Address) bool {
	if a.Addr != b.Addr || a.ServerName != b.ServerName {
		return false
	}
	if !a.Attributes.Equal(b.Attributes) {
		return false
	}
	return a.Metadata == b.Metadata
}

func equalAddressesIgnoreBalancerAttributes(a, b []resolver.Address) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if !v.Equal(b[i]) {
if !equalAddressIgnoreBalancerAttributes(v, b[i]) {
return false
}
}
Expand All @@ -939,15 +957,15 @@ func equalAddresses(a, b []resolver.Address) bool {
// updateAddrs updates ac.addrs with the new addresses list and handles active
// connections or connection attempts.
func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
addrs = copyAddressesWithoutBalancerAttributes(addrs)
addrs = copyAddresses(addrs, withBalancerAttributes)
limit := len(addrs)
if limit > 5 {
limit = 5
}
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), addrs[:limit])
channelz.Infof(logger, ac.channelz, "addrConn: updateAddrs addrs (%d of %d): %v", limit, len(addrs), copyAddresses(addrs[:limit], withoutBalancerAttributes))

ac.mu.Lock()
if equalAddresses(ac.addrs, addrs) {
if equalAddressesIgnoreBalancerAttributes(ac.addrs, addrs) {
ac.mu.Unlock()
return
}
Expand All @@ -966,7 +984,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) {
// Try to find the connected address.
for _, a := range addrs {
a.ServerName = ac.cc.getServerName(a)
if a.Equal(ac.curAddr) {
if equalAddressIgnoreBalancerAttributes(a, ac.curAddr) {
// We are connected to a valid address, so do nothing but
// update the addresses.
ac.mu.Unlock()
Expand Down Expand Up @@ -1214,7 +1232,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, lastErr)
ac.acbw.updateState(s, ac.curAddr, lastErr)
}

// adjustParams updates parameters used to create transports upon
Expand Down Expand Up @@ -1347,6 +1365,7 @@ func (ac *addrConn) tryAllAddrs(ctx context.Context, addrs []resolver.Address, c
// new transport.
func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
addr.ServerName = ac.cc.getServerName(addr)
addrWithoutBalancerAttributes := addressWithoutBalancerAttributes(addr)
hctx, hcancel := context.WithCancel(ctx)

onClose := func(r transport.GoAwayReason) {
Expand Down Expand Up @@ -1381,14 +1400,14 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address,
defer cancel()
copts.ChannelzParent = ac.channelz

newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onClose)
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addrWithoutBalancerAttributes, copts, onClose)
if err != nil {
if logger.V(2) {
logger.Infof("Creating new client transport to %q: %v", addr, err)
logger.Infof("Creating new client transport to %q: %v", addrWithoutBalancerAttributes, err)
}
// newTr is either nil, or closed.
hcancel()
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
channelz.Warningf(logger, ac.channelz, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addrWithoutBalancerAttributes, err)
return err
}

Expand Down
7 changes: 7 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,13 @@ var (
// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n
// is the number of elements. swap swaps the elements with indexes i and j.
ShuffleAddressListForTesting any // func(n int, swap func(i, j int))

// GetConnectedAddress returns the connected address for a SubConnState and
// whether the address is valid based on the state.
GetConnectedAddress any // func(scs SubConnState) (resolver.Address, bool)

// SetConnectedAddress sets the connected address for a SubConnState.
SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address)
)

// HealthChecker defines the signature of the client-side LB channel health
Expand Down
20 changes: 17 additions & 3 deletions xds/internal/balancer/clusterimpl/clusterimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (

"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
grpcinternal "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/gracefulswitch"
"google.golang.org/grpc/internal/buffer"
"google.golang.org/grpc/internal/grpclog"
Expand Down Expand Up @@ -366,14 +367,27 @@ func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer
lID = xdsinternal.GetLocalityID(newAddrs[i])
}
var sc balancer.SubConn
ret := &scWrapper{}
oldListener := opts.StateListener
opts.StateListener = func(state balancer.SubConnState) { b.updateSubConnState(sc, state, oldListener) }
opts.StateListener = func(state balancer.SubConnState) {
b.updateSubConnState(sc, state, oldListener)
// Read connected address and call updateLocalityID() based on the connected address's locality.
// https://github.com/grpc/grpc-go/issues/7339
if GetConnectedAddress, ok := grpcinternal.GetConnectedAddress.(func(state balancer.SubConnState) (resolver.Address, bool)); ok {
if addr, ok := GetConnectedAddress(state); ok {
// TODO: Why is lID empty when running the test? The locality info is being lost somehow.
lID := xdsinternal.GetLocalityID(addr)
if !lID.Equal(xdsinternal.LocalityID{}) {
ret.updateLocalityID(lID)
}
}
}
}
sc, err := b.ClientConn.NewSubConn(newAddrs, opts)
if err != nil {
return nil, err
}
// Wrap this SubConn in a wrapper, and add it to the map.
ret := &scWrapper{SubConn: sc}
ret.SubConn = sc
ret.updateLocalityID(lID)
return ret, nil
}
Expand Down
11 changes: 3 additions & 8 deletions xds/internal/balancer/clusterimpl/tests/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,9 @@ func (s) TestLoadReportingPickFirstMultiLocality(t *testing.T) {
}
mgmtServer.LRSServer.LRSResponseChan <- &resp

// Wait for load to be reported for locality of server 2.
// We (incorrectly) wait for load report for region-2 because presently
// pickfirst always reports load for the locality of the last address in the
// subconn. This will be fixed by ensuring there is only one address per
// subconn.
// TODO(#7339): Change region to region-1 once fixed.
if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-2"); err != nil {
t.Fatalf("region-2 did not receive load due to error: %v", err)
// Wait for load to be reported for locality of server 1.
if err := waitForSuccessfulLoadReport(ctx, mgmtServer.LRSServer, "region-1"); err != nil {
t.Fatalf("Server 1 did not receive load due to error: %v", err)
}

// Stop server 1 and send one more rpc. Now the request should go to server 2.
Expand Down

0 comments on commit 08ebd15

Please sign in to comment.