diff --git a/client/client.go b/client/client.go index 2fc9bd3ef0d..7aa28cbc0cd 100644 --- a/client/client.go +++ b/client/client.go @@ -25,8 +25,6 @@ import ( "github.com/opentracing/opentracing-go" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -42,7 +40,6 @@ import ( "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" "github.com/tikv/pd/client/pkg/caller" - cb "github.com/tikv/pd/client/pkg/circuitbreaker" "github.com/tikv/pd/client/pkg/utils/tlsutil" sd "github.com/tikv/pd/client/servicediscovery" ) @@ -460,12 +457,6 @@ func (c *client) UpdateOption(option opt.DynamicOption, value any) error { return errors.New("[pd] invalid value type for TSOClientRPCConcurrency option, it should be int") } c.inner.option.SetTSOClientRPCConcurrency(value) - case opt.RegionMetadataCircuitBreakerSettings: - applySettingsChange, ok := value.(func(config *cb.Settings)) - if !ok { - return errors.New("[pd] invalid value type for RegionMetadataCircuitBreakerSettings option, it should be pd.Settings") - } - c.inner.regionMetaCircuitBreaker.ChangeSettings(applySettingsChange) default: return errors.New("[pd] unsupported client option") } @@ -660,13 +651,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - region, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) - failpoint.Inject("triggerCircuitBreaker", func() { - err = status.Error(codes.ResourceExhausted, "resource exhausted") - }) - return region, isOverloaded(err), err - }) + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegion(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -706,10 +691,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) - return resp, isOverloaded(err), err - }) + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetPrevRegion(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { @@ -749,10 +731,8 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if serviceClient == nil { return nil, errs.ErrClientGetProtoClient } - resp, err := c.inner.regionMetaCircuitBreaker.Execute(func() (*pdpb.GetRegionResponse, cb.Overloading, error) { - resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) - return resp, isOverloaded(err), err - }) + + resp, err := pdpb.NewPDClient(serviceClient.GetClientConn()).GetRegionByID(cctx, req) if serviceClient.NeedRetry(resp.GetHeader().GetError(), err) { protoClient, cctx := c.getClientAndContext(ctx) if protoClient == nil { diff --git a/client/inner_client.go b/client/inner_client.go index 91f999dd3b5..404cbcf0b80 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -8,8 +8,6 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/pdpb" @@ -19,7 +17,6 @@ import ( "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" "github.com/tikv/pd/client/opt" - cb "github.com/tikv/pd/client/pkg/circuitbreaker" sd "github.com/tikv/pd/client/servicediscovery" ) @@ -29,11 +26,10 @@ const ( ) type innerClient struct { - keyspaceID uint32 - svrUrls []string - pdSvcDiscovery sd.ServiceDiscovery - tokenDispatcher *tokenDispatcher - regionMetaCircuitBreaker *cb.CircuitBreaker[*pdpb.GetRegionResponse] + keyspaceID uint32 + svrUrls []string + pdSvcDiscovery sd.ServiceDiscovery + tokenDispatcher *tokenDispatcher // For service mode switching. serviceModeKeeper @@ -59,7 +55,6 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } - c.regionMetaCircuitBreaker = cb.NewCircuitBreaker[*pdpb.GetRegionResponse]("region_meta", c.option.RegionMetaCircuitBreakerSettings) return nil } @@ -252,12 +247,3 @@ func (c *innerClient) dispatchTSORequestWithRetry(ctx context.Context) tso.TSFut } return req } - -func isOverloaded(err error) cb.Overloading { - switch status.Code(errors.Cause(err)) { - case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted: - return cb.Yes - default: - return cb.No - } -} diff --git a/client/opt/option.go b/client/opt/option.go index af95a225fab..2aa9be8ae7f 100644 --- a/client/opt/option.go +++ b/client/opt/option.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/errors" - cb "github.com/tikv/pd/client/pkg/circuitbreaker" "github.com/tikv/pd/client/pkg/retry" ) @@ -50,8 +49,6 @@ const ( EnableFollowerHandle // TSOClientRPCConcurrency controls the amount of ongoing TSO RPC requests at the same time in a single TSO client. TSOClientRPCConcurrency - // RegionMetadataCircuitBreakerSettings controls settings for circuit breaker for region metadata requests. - RegionMetadataCircuitBreakerSettings dynamicOptionCount ) @@ -72,18 +69,16 @@ type Option struct { // Dynamic options. dynamicOptions [dynamicOptionCount]atomic.Value - EnableTSOFollowerProxyCh chan struct{} - RegionMetaCircuitBreakerSettings cb.Settings + EnableTSOFollowerProxyCh chan struct{} } // NewOption creates a new PD client option with the default values set. func NewOption() *Option { co := &Option{ - Timeout: defaultPDTimeout, - MaxRetryTimes: maxInitClusterRetries, - EnableTSOFollowerProxyCh: make(chan struct{}, 1), - InitMetrics: true, - RegionMetaCircuitBreakerSettings: cb.AlwaysClosedSettings, + Timeout: defaultPDTimeout, + MaxRetryTimes: maxInitClusterRetries, + EnableTSOFollowerProxyCh: make(chan struct{}, 1), + InitMetrics: true, } co.dynamicOptions[MaxTSOBatchWaitInterval].Store(defaultMaxTSOBatchWaitInterval) @@ -154,11 +149,6 @@ func (o *Option) GetTSOClientRPCConcurrency() int { return o.dynamicOptions[TSOClientRPCConcurrency].Load().(int) } -// GetRegionMetadataCircuitBreakerSettings gets circuit breaker settings for PD region metadata calls. -func (o *Option) GetRegionMetadataCircuitBreakerSettings() cb.Settings { - return o.dynamicOptions[RegionMetadataCircuitBreakerSettings].Load().(cb.Settings) -} - // ClientOption configures client. type ClientOption func(*Option) @@ -213,13 +203,6 @@ func WithInitMetricsOption(initMetrics bool) ClientOption { } } -// WithRegionMetaCircuitBreaker configures the client with circuit breaker for region meta calls -func WithRegionMetaCircuitBreaker(config cb.Settings) ClientOption { - return func(op *Option) { - op.RegionMetaCircuitBreakerSettings = config - } -} - // WithBackoffer configures the client with backoffer. func WithBackoffer(bo *retry.Backoffer) ClientOption { return func(op *Option) { diff --git a/client/pkg/circuitbreaker/circuit_breaker.go b/client/pkg/circuitbreaker/circuit_breaker.go index 2c65f4f1965..0acee5d5c8d 100644 --- a/client/pkg/circuitbreaker/circuit_breaker.go +++ b/client/pkg/circuitbreaker/circuit_breaker.go @@ -14,6 +14,7 @@ package circuitbreaker import ( + "context" "fmt" "strings" "sync" @@ -62,12 +63,12 @@ var AlwaysClosedSettings = Settings{ } // CircuitBreaker is a state machine to prevent sending requests that are likely to fail. -type CircuitBreaker[T any] struct { +type CircuitBreaker struct { config *Settings name string mutex sync.Mutex - state *State[T] + state *State successCounter prometheus.Counter errorCounter prometheus.Counter @@ -102,8 +103,8 @@ func (s StateType) String() string { var replacer = strings.NewReplacer(" ", "_", "-", "_") // NewCircuitBreaker returns a new CircuitBreaker configured with the given Settings. -func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { - cb := new(CircuitBreaker[T]) +func NewCircuitBreaker(name string, st Settings) *CircuitBreaker { + cb := new(CircuitBreaker) cb.name = name cb.config = &st cb.state = cb.newState(time.Now(), StateClosed) @@ -118,7 +119,7 @@ func NewCircuitBreaker[T any](name string, st Settings) *CircuitBreaker[T] { // ChangeSettings changes the CircuitBreaker settings. // The changes will be reflected only in the next evaluation window. -func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { +func (cb *CircuitBreaker) ChangeSettings(apply func(config *Settings)) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -129,12 +130,11 @@ func (cb *CircuitBreaker[T]) ChangeSettings(apply func(config *Settings)) { // Execute calls the given function if the CircuitBreaker is closed and returns the result of execution. // Execute returns an error instantly if the CircuitBreaker is open. // https://github.com/tikv/rfcs/blob/master/text/0115-circuit-breaker.md -func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, error) { +func (cb *CircuitBreaker) Execute(call func() (Overloading, error)) error { state, err := cb.onRequest() if err != nil { cb.fastFailCounter.Inc() - var defaultValue T - return defaultValue, err + return err } defer func() { @@ -146,13 +146,13 @@ func (cb *CircuitBreaker[T]) Execute(call func() (T, Overloading, error)) (T, er } }() - result, overloaded, err := call() + overloaded, err := call() cb.emitMetric(overloaded, err) cb.onResult(state, overloaded) - return result, err + return err } -func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { +func (cb *CircuitBreaker) onRequest() (*State, error) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -161,7 +161,7 @@ func (cb *CircuitBreaker[T]) onRequest() (*State[T], error) { return state, err } -func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { +func (cb *CircuitBreaker) onResult(state *State, overloaded Overloading) { cb.mutex.Lock() defer cb.mutex.Unlock() @@ -170,7 +170,7 @@ func (cb *CircuitBreaker[T]) onResult(state *State[T], overloaded Overloading) { state.onResult(overloaded) } -func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { +func (cb *CircuitBreaker) emitMetric(overloaded Overloading, err error) { switch overloaded { case No: cb.successCounter.Inc() @@ -185,9 +185,9 @@ func (cb *CircuitBreaker[T]) emitMetric(overloaded Overloading, err error) { } // State represents the state of CircuitBreaker. -type State[T any] struct { +type State struct { stateType StateType - cb *CircuitBreaker[T] + cb *CircuitBreaker end time.Time pendingCount uint32 @@ -196,7 +196,7 @@ type State[T any] struct { } // newState creates a new State with the given configuration and reset all success/failure counters. -func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State[T] { +func (cb *CircuitBreaker) newState(now time.Time, stateType StateType) *State { var end time.Time var pendingCount uint32 switch stateType { @@ -211,7 +211,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State default: panic("unknown state") } - return &State[T]{ + return &State{ cb: cb, stateType: stateType, pendingCount: pendingCount, @@ -227,7 +227,7 @@ func (cb *CircuitBreaker[T]) newState(now time.Time, stateType StateType) *State // Open state fails all request, it has a fixed duration of `Settings.CoolDownInterval` and always moves to HalfOpen state at the end of the interval. // HalfOpen state does not have a fixed duration and lasts till `Settings.HalfOpenSuccessCount` are evaluated. // If any of `Settings.HalfOpenSuccessCount` fails then it moves back to Open state, otherwise it moves to Closed state. -func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { +func (s *State) onRequest(cb *CircuitBreaker) (*State, error) { var now = time.Now() switch s.stateType { case StateClosed: @@ -299,7 +299,7 @@ func (s *State[T]) onRequest(cb *CircuitBreaker[T]) (*State[T], error) { } } -func (s *State[T]) onResult(overloaded Overloading) { +func (s *State) onResult(overloaded Overloading) { switch overloaded { case No: s.successCount++ @@ -309,3 +309,25 @@ func (s *State[T]) onResult(overloaded Overloading) { panic("unknown state") } } + +// Define context key type +type cbCtxKey struct{} + +// Key used to store circuit breaker +var CircuitBreakerKey = cbCtxKey{} + +// FromContext retrieves the circuit breaker from the context +func FromContext(ctx context.Context) *CircuitBreaker { + if ctx == nil { + return nil + } + if cb, ok := ctx.Value(CircuitBreakerKey).(*CircuitBreaker); ok { + return cb + } + return nil +} + +// WithCircuitBreaker stores the circuit breaker into a new context +func WithCircuitBreaker(ctx context.Context, cb *CircuitBreaker) context.Context { + return context.WithValue(ctx, CircuitBreakerKey, cb) +} diff --git a/client/pkg/circuitbreaker/circuit_breaker_test.go b/client/pkg/circuitbreaker/circuit_breaker_test.go index 07a3c06f86e..e62e55c1ab8 100644 --- a/client/pkg/circuitbreaker/circuit_breaker_test.go +++ b/client/pkg/circuitbreaker/circuit_breaker_test.go @@ -24,7 +24,7 @@ import ( ) // advance emulate the state machine clock moves forward by the given duration -func (cb *CircuitBreaker[T]) advance(duration time.Duration) { +func (cb *CircuitBreaker) advance(duration time.Duration) { cb.state.end = cb.state.end.Add(-duration - 1) } @@ -40,26 +40,24 @@ var minCountToOpen = int(settings.MinQPSForOpen * uint32(settings.ErrorRateWindo func TestCircuitBreakerExecuteWrapperReturnValues(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) originalError := errors.New("circuit breaker is open") - result, err := cb.Execute(func() (int, Overloading, error) { - return 42, No, originalError + err := cb.Execute(func() (Overloading, error) { + return No, originalError }) re.Equal(err, originalError) - re.Equal(42, result) // same by interpret the result as overloading error - result, err = cb.Execute(func() (int, Overloading, error) { - return 42, Yes, originalError + err = cb.Execute(func() (Overloading, error) { + return Yes, originalError }) re.Equal(err, originalError) - re.Equal(42, result) } func TestCircuitBreakerOpenState(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) driveQPS(cb, minCountToOpen, Yes, re) re.Equal(StateClosed, cb.state.stateType) assertSucceeds(cb, re) // no error till ErrorRateWindow is finished @@ -70,7 +68,7 @@ func TestCircuitBreakerOpenState(t *testing.T) { func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen/2, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -80,7 +78,7 @@ func TestCircuitBreakerCloseStateNotEnoughQPS(t *testing.T) { func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen/4, Yes, re) driveQPS(cb, minCountToOpen, No, re) @@ -91,7 +89,7 @@ func TestCircuitBreakerCloseStateNotEnoughErrorRate(t *testing.T) { func TestCircuitBreakerHalfOpenToClosed(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -109,7 +107,7 @@ func TestCircuitBreakerHalfOpenToClosed(t *testing.T) { func TestCircuitBreakerHalfOpenToOpen(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -118,8 +116,8 @@ func TestCircuitBreakerHalfOpenToOpen(t *testing.T) { cb.advance(settings.CoolDownInterval) assertSucceeds(cb, re) re.Equal(StateHalfOpen, cb.state.stateType) - _, err := cb.Execute(func() (int, Overloading, error) { - return 42, Yes, nil // this trip circuit breaker again + err := cb.Execute(func() (Overloading, error) { + return Yes, nil // this trip circuit breaker again }) re.NoError(err) re.Equal(StateHalfOpen, cb.state.stateType) @@ -149,10 +147,10 @@ func TestCircuitBreakerHalfOpenFailOverPendingCount(t *testing.T) { defer func() { end <- true }() - _, err := cb.Execute(func() (int, Overloading, error) { + err := cb.Execute(func() (Overloading, error) { start <- true <-wait - return 42, No, nil + return No, nil }) re.NoError(err) }() @@ -178,7 +176,7 @@ func TestCircuitBreakerHalfOpenFailOverPendingCount(t *testing.T) { func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", settings) + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) start := make(chan bool) @@ -188,10 +186,10 @@ func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) { defer func() { end <- true }() - _, err := cb.Execute(func() (int, Overloading, error) { + err := cb.Execute(func() (Overloading, error) { start <- true <-wait - return 42, No, nil + return No, nil }) re.NoError(err) }() @@ -214,7 +212,7 @@ func TestCircuitBreakerCountOnlyRequestsInSameWindow(t *testing.T) { func TestCircuitBreakerChangeSettings(t *testing.T) { re := require.New(t) - cb := NewCircuitBreaker[int]("test_cb", AlwaysClosedSettings) + cb := NewCircuitBreaker("test_cb", AlwaysClosedSettings) driveQPS(cb, int(AlwaysClosedSettings.MinQPSForOpen*uint32(AlwaysClosedSettings.ErrorRateWindow.Seconds())), Yes, re) cb.advance(AlwaysClosedSettings.ErrorRateWindow) assertSucceeds(cb, re) @@ -231,8 +229,8 @@ func TestCircuitBreakerChangeSettings(t *testing.T) { re.Equal(StateOpen, cb.state.stateType) } -func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker[int] { - cb := NewCircuitBreaker[int]("test_cb", settings) +func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreaker { + cb := NewCircuitBreaker("test_cb", settings) re.Equal(StateClosed, cb.state.stateType) driveQPS(cb, minCountToOpen, Yes, re) cb.advance(settings.ErrorRateWindow) @@ -242,29 +240,28 @@ func newCircuitBreakerMovedToHalfOpenState(re *require.Assertions) *CircuitBreak return cb } -func driveQPS(cb *CircuitBreaker[int], count int, overload Overloading, re *require.Assertions) { +func driveQPS(cb *CircuitBreaker, count int, overload Overloading, re *require.Assertions) { for range count { - _, err := cb.Execute(func() (int, Overloading, error) { - return 42, overload, nil + err := cb.Execute(func() (Overloading, error) { + return overload, nil }) re.NoError(err) } } -func assertFastFail(cb *CircuitBreaker[int], re *require.Assertions) { +func assertFastFail(cb *CircuitBreaker, re *require.Assertions) { var executed = false - _, err := cb.Execute(func() (int, Overloading, error) { + err := cb.Execute(func() (Overloading, error) { executed = true - return 42, No, nil + return No, nil }) re.Equal(err, errs.ErrCircuitBreakerOpen) re.False(executed) } -func assertSucceeds(cb *CircuitBreaker[int], re *require.Assertions) { - result, err := cb.Execute(func() (int, Overloading, error) { - return 42, No, nil +func assertSucceeds(cb *CircuitBreaker, re *require.Assertions) { + err := cb.Execute(func() (Overloading, error) { + return No, nil }) re.NoError(err) - re.Equal(42, result) } diff --git a/client/pkg/utils/grpcutil/grpcutil.go b/client/pkg/utils/grpcutil/grpcutil.go index b73d117fe84..235e1088747 100644 --- a/client/pkg/utils/grpcutil/grpcutil.go +++ b/client/pkg/utils/grpcutil/grpcutil.go @@ -24,15 +24,18 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/backoff" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/client/errs" + "github.com/tikv/pd/client/pkg/circuitbreaker" "github.com/tikv/pd/client/pkg/retry" ) @@ -71,6 +74,36 @@ func UnaryBackofferInterceptor() grpc.UnaryClientInterceptor { } } +// UnaryCircuitBreakerInterceptor is a gRPC interceptor that adds a circuit breaker to the call. +func UnaryCircuitBreakerInterceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + cb := circuitbreaker.FromContext(ctx) + if cb == nil { + return invoker(ctx, method, req, reply, cc, opts...) + } + err := cb.Execute(func() (circuitbreaker.Overloading, error) { + err := invoker(ctx, method, req, reply, cc, opts...) + failpoint.Inject("triggerCircuitBreaker", func() { + err = status.Error(codes.ResourceExhausted, "resource exhausted") + }) + return isOverloaded(err), err + }) + if err != nil { + return err + } + return nil + } +} + +func isOverloaded(err error) circuitbreaker.Overloading { + switch status.Code(errors.Cause(err)) { + case codes.DeadlineExceeded, codes.Unavailable, codes.ResourceExhausted: + return circuitbreaker.Yes + default: + return circuitbreaker.No + } +} + // GetClientConn returns a gRPC client connection. // creates a client connection to the given target. By default, it's // a non-blocking dial (the function won't wait for connections to be @@ -96,7 +129,10 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g } // Add backoffer interceptor - retryOpt := grpc.WithUnaryInterceptor(UnaryBackofferInterceptor()) + retryOpt := grpc.WithChainUnaryInterceptor(UnaryBackofferInterceptor()) + + // Add circuit breaker interceptor + cbOpt := grpc.WithChainUnaryInterceptor(UnaryCircuitBreakerInterceptor()) // Add retry related connection parameters backoffOpts := grpc.WithConnectParams(grpc.ConnectParams{ @@ -108,7 +144,7 @@ func GetClientConn(ctx context.Context, addr string, tlsCfg *tls.Config, do ...g }, }) - do = append(do, opt, retryOpt, backoffOpts) + do = append(do, opt, retryOpt, cbOpt, backoffOpts) cc, err := grpc.DialContext(ctx, u.Host, do...) if err != nil { return nil, errs.ErrGRPCDial.Wrap(err).GenWithStackByCause() diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 2018860130e..397e1079af3 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -2096,28 +2096,30 @@ func TestCircuitBreaker(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") @@ -2125,7 +2127,7 @@ func TestCircuitBreaker(t *testing.T) { time.Sleep(time.Second) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } @@ -2149,34 +2151,35 @@ func TestCircuitBreakerOpenAndChangeSettings(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") - cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) { + circuitBreaker.ChangeSettings(func(config *cb.Settings) { *config = cb.AlwaysClosedSettings }) - - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "ResourceExhausted") - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) } func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { @@ -2197,23 +2200,26 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { } endpoints := runServer(re, cluster) - cli := setupCli(ctx, re, endpoints, opt.WithRegionMetaCircuitBreaker(circuitBreakerSettings)) + + cli := setupCli(ctx, re, endpoints) defer cli.Close() + circuitBreaker := cb.NewCircuitBreaker("region_meta", circuitBreakerSettings) + ctx = cb.WithCircuitBreaker(ctx, circuitBreaker) for range 10 { - region, err := cli.GetRegion(context.TODO(), []byte("a")) + region, err := cli.GetRegion(ctx, []byte("a")) re.NoError(err) re.NotNil(region) } - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) } - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.Error(err) re.Contains(err.Error(), "circuit breaker is open") @@ -2221,9 +2227,9 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { defer os.RemoveAll(fname) // wait for cooldown time.Sleep(time.Second) - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) // trigger circuit breaker state to be half open - _, err = cli.GetRegion(context.TODO(), []byte("a")) + _, err = cli.GetRegion(ctx, []byte("a")) re.NoError(err) testutil.Eventually(re, func() bool { b, _ := os.ReadFile(fname) @@ -2233,17 +2239,16 @@ func TestCircuitBreakerHalfOpenAndChangeSettings(t *testing.T) { }) // The state is half open - re.NoError(failpoint.Enable("github.com/tikv/pd/client/triggerCircuitBreaker", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker", "return(true)")) // change settings to always closed - cli.UpdateOption(opt.RegionMetadataCircuitBreakerSettings, func(config *cb.Settings) { + circuitBreaker.ChangeSettings(func(config *cb.Settings) { *config = cb.AlwaysClosedSettings }) - // It won't be changed to open state. for range 100 { - _, err := cli.GetRegion(context.TODO(), []byte("a")) + _, err := cli.GetRegion(ctx, []byte("a")) re.Error(err) re.NotContains(err.Error(), "circuit breaker is open") } - re.NoError(failpoint.Disable("github.com/tikv/pd/client/triggerCircuitBreaker")) + re.NoError(failpoint.Disable("github.com/tikv/pd/client/pkg/utils/grpcutil/triggerCircuitBreaker")) }