From 7e4058c977e4c59ef5bafc76b7ffe8107c567815 Mon Sep 17 00:00:00 2001 From: okJiang <819421878@qq.com> Date: Mon, 25 Nov 2024 17:14:05 +0800 Subject: [PATCH] unify rate limit middleware Signed-off-by: okJiang <819421878@qq.com> --- server/gc_service.go | 64 +++- server/grpc_service.go | 769 ++++++++++++++++++++--------------------- server/middleware.go | 31 +- 3 files changed, 447 insertions(+), 417 deletions(-) diff --git a/server/gc_service.go b/server/gc_service.go index 227af2c5ec8..c88a0395db6 100644 --- a/server/gc_service.go +++ b/server/gc_service.go @@ -35,10 +35,20 @@ import ( // GetGCSafePointV2 return gc safe point for the given keyspace. func (s *GrpcServer) GetGCSafePointV2(ctx context.Context, request *pdpb.GetGCSafePointV2Request) (*pdpb.GetGCSafePointV2Response, error) { - if rsp, err := s.unaryMiddleware(ctx, request, "GetGCSafePointV2"); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "GetGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.GetGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetGCSafePointV2Response), nil + } } safePoint, err := s.safePointV2Manager.LoadGCSafePoint(request.GetKeyspaceId()) @@ -57,10 +67,20 @@ func (s *GrpcServer) GetGCSafePointV2(ctx context.Context, request *pdpb.GetGCSa // UpdateGCSafePointV2 update gc safe point for the given keyspace. func (s *GrpcServer) UpdateGCSafePointV2(ctx context.Context, request *pdpb.UpdateGCSafePointV2Request) (*pdpb.UpdateGCSafePointV2Response, error) { - if rsp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePointV2"); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.UpdateGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateGCSafePointV2Response), nil + } } newSafePoint := request.GetSafePoint() @@ -91,10 +111,20 @@ func (s *GrpcServer) UpdateGCSafePointV2(ctx context.Context, request *pdpb.Upda // UpdateServiceSafePointV2 update service safe point for the given keyspace. func (s *GrpcServer) UpdateServiceSafePointV2(ctx context.Context, request *pdpb.UpdateServiceSafePointV2Request) (*pdpb.UpdateServiceSafePointV2Response, error) { - if rsp, err := s.unaryMiddleware(ctx, request, "UpdateServiceSafePointV2"); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateServiceSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateServiceSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.UpdateServiceSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateServiceSafePointV2Response), nil + } } nowTSO, err := s.getGlobalTSO(ctx) @@ -185,10 +215,20 @@ func (s *GrpcServer) WatchGCSafePointV2(request *pdpb.WatchGCSafePointV2Request, // GetAllGCSafePointV2 return all gc safe point v2. func (s *GrpcServer) GetAllGCSafePointV2(ctx context.Context, request *pdpb.GetAllGCSafePointV2Request) (*pdpb.GetAllGCSafePointV2Response, error) { - if rsp, err := s.unaryMiddleware(ctx, request, "GetAllGCSafePointV2"); err != nil { + if midResp, err := s.unaryMiddleware(ctx, request, "GetAllGCSafePointV2"); err != nil { return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetAllGCSafePointV2Response), err + } else if midResp != nil { + if midResp.header != nil { + return &pdpb.GetAllGCSafePointV2Response{ + Header: midResp.header, + }, nil + } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetAllGCSafePointV2Response), nil + } } startkey := keypath.GCSafePointV2Prefix() diff --git a/server/grpc_service.go b/server/grpc_service.go index 34778ac6fe7..d73fa8334c9 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -263,22 +263,20 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest func (s *GrpcServer) GetMinTS( ctx context.Context, request *pdpb.GetMinTSRequest, ) (*pdpb.GetMinTSResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetMinTS"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetMinTSResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "GetMinTS"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetMinTSResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetMinTSResponse), nil + } } var ( @@ -600,21 +598,20 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { // Bootstrap implements gRPC PDServer. func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapRequest) (*pdpb.BootstrapResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "Bootstrap"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.BootstrapResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - if rsp, err := s.unaryMiddleware(ctx, request, "Bootstrap"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.BootstrapResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.BootstrapResponse), nil + } } rc := s.GetRaftCluster() @@ -641,22 +638,20 @@ func (s *GrpcServer) Bootstrap(ctx context.Context, request *pdpb.BootstrapReque // IsBootstrapped implements gRPC PDServer. func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstrappedRequest) (*pdpb.IsBootstrappedResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "IsBootstrapped"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.IsBootstrappedResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "IsBootstrapped"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.IsBootstrappedResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.IsBootstrappedResponse), nil + } } rc := s.GetRaftCluster() @@ -668,22 +663,20 @@ func (s *GrpcServer) IsBootstrapped(ctx context.Context, request *pdpb.IsBootstr // AllocID implements gRPC PDServer. func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) (*pdpb.AllocIDResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AllocID"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AllocIDResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "AllocID"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AllocIDResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AllocIDResponse), nil + } } // We can use an allocator for all types ID allocation. @@ -701,18 +694,23 @@ func (s *GrpcServer) AllocID(ctx context.Context, request *pdpb.AllocIDRequest) } // IsSnapshotRecovering implements gRPC PDServer. -func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { +func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, req *pdpb.IsSnapshotRecoveringRequest) (*pdpb.IsSnapshotRecoveringResponse, error) { + if midResp, err := s.unaryMiddleware(ctx, req, "IsSnapshotRecovering"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.IsSnapshotRecoveringResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.IsSnapshotRecoveringResponse), nil + } } + // recovering mark is stored in etcd directly, there's no need to forward. marked, err := s.Server.IsSnapshotRecovering(ctx) if err != nil { @@ -728,22 +726,22 @@ func (s *GrpcServer) IsSnapshotRecovering(ctx context.Context, _ *pdpb.IsSnapsho // GetStore implements gRPC PDServer. func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest) (*pdpb.GetStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetStore"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetStoreResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetStoreResponse), nil + } } - if rsp, err := s.unaryMiddleware(ctx, request, "GetStore"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetStoreResponse), err - } + rc := s.GetRaftCluster() if rc == nil { return &pdpb.GetStoreResponse{Header: notBootstrappedHeader()}, nil @@ -781,22 +779,20 @@ func checkStore(rc *cluster.RaftCluster, storeID uint64) *pdpb.Error { // PutStore implements gRPC PDServer. func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest) (*pdpb.PutStoreResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "PutStore"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.PutStoreResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "PutStore"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.PutStoreResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.PutStoreResponse), nil + } } rc := s.GetRaftCluster() @@ -836,21 +832,20 @@ func (s *GrpcServer) PutStore(ctx context.Context, request *pdpb.PutStoreRequest // GetAllStores implements gRPC PDServer. func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStoresRequest) (*pdpb.GetAllStoresResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetAllStores"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetAllStoresResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - if rsp, err := s.unaryMiddleware(ctx, request, "GetAllStores"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetAllStoresResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetAllStoresResponse), nil + } } rc := s.GetRaftCluster() @@ -878,21 +873,20 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore // StoreHeartbeat implements gRPC PDServer. func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHeartbeatRequest) (*pdpb.StoreHeartbeatResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "StoreHearbeat"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.StoreHeartbeatResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: midResp.header, }, nil } - } - if rsp, err := s.unaryMiddleware(ctx, request, "StoreHearbeat"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.StoreHeartbeatResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.StoreHeartbeatResponse), nil + } } if request.GetStats() == nil { @@ -1377,26 +1371,25 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error // GetRegion implements gRPC PDServer. func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - if rsp, err := s.unaryMiddleware(ctx, request, "GetRegion"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), nil - } failpoint.Inject("delayProcess", nil) var ( - rc = s.GetRaftCluster() + rc *cluster.RaftCluster followerHandle = !s.member.IsLeader() region *core.RegionInfo ) @@ -1404,6 +1397,7 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque return &pdpb.GetRegionResponse{Header: notBootstrappedHeader()}, nil } if followerHandle { + rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } @@ -1413,6 +1407,10 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } } else { + rc = s.GetRaftCluster() + if rc == nil { + return &pdpb.GetRegionResponse{Header: notBootstrappedHeader()}, nil + } region = rc.GetRegionByKey(request.GetRegionKey()) if region == nil { log.Warn("leader get region nil", zap.String("key", string(request.GetRegionKey()))) @@ -1437,27 +1435,27 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque // GetPrevRegion implements gRPC PDServer func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetPrevRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - followerHandle := new(bool) - if rsp, err := s.unaryMiddleware(ctx, request, "GetPrevRegion"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), err - } - - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { // no need to check running status rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { @@ -1472,14 +1470,14 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR region := rc.GetPrevRegionByKey(request.GetRegionKey()) if region == nil { - if *followerHandle { + if followerHandle { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } return &pdpb.GetRegionResponse{Header: wrapHeader()}, nil } var buckets *metapb.Buckets // FIXME: If the bucket is disabled dynamically, the bucket information is returned unexpectedly - if !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + if !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() } return &pdpb.GetRegionResponse{ @@ -1494,27 +1492,27 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR // GetRegionByID implements gRPC PDServer. func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionByIDRequest) (*pdpb.GetRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetRegionByID"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetRegionResponse), nil + } } - followerHandle := new(bool) - if rsp, err := s.unaryMiddleware(ctx, request, "GetRegionByID"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetRegionResponse), err - } - - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil @@ -1527,18 +1525,18 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB } region := rc.GetRegion(request.GetRegionId()) failpoint.Inject("followerHandleError", func() { - if *followerHandle { + if followerHandle { region = nil } }) if region == nil { - if *followerHandle { + if followerHandle { return &pdpb.GetRegionResponse{Header: regionNotFound()}, nil } return &pdpb.GetRegionResponse{Header: wrapHeader()}, nil } var buckets *metapb.Buckets - if !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { + if !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() && request.GetNeedBuckets() { buckets = region.GetBuckets() } return &pdpb.GetRegionResponse{ @@ -1554,27 +1552,27 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB // Deprecated: use BatchScanRegions instead. // ScanRegions implements gRPC PDServer. func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsRequest) (*pdpb.ScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ScanRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ScanRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ScanRegionsResponse), nil + } } - followerHandle := new(bool) - if rsp, err := s.unaryMiddleware(ctx, request, "ScanRegions"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ScanRegionsResponse), nil - } - - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.ScanRegionsResponse{Header: regionNotFound()}, nil @@ -1586,7 +1584,7 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR } } regions := rc.ScanRegions(request.GetStartKey(), request.GetEndKey(), int(request.GetLimit())) - if *followerHandle && len(regions) == 0 { + if followerHandle && len(regions) == 0 { return &pdpb.ScanRegionsResponse{Header: regionNotFound()}, nil } resp := &pdpb.ScanRegionsResponse{Header: wrapHeader()} @@ -1610,27 +1608,27 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR // BatchScanRegions implements gRPC PDServer. func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchScanRegionsRequest) (*pdpb.BatchScanRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "BatchScanRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.BatchScanRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.BatchScanRegionsResponse), nil + } } - followerHandle := new(bool) - if rsp, err := s.unaryMiddleware(ctx, request, "BatchScanRegions"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.BatchScanRegionsResponse), nil - } - - var rc *cluster.RaftCluster - if *followerHandle { + var ( + rc *cluster.RaftCluster + followerHandle = !s.member.IsLeader() + ) + if followerHandle { rc = s.cluster if !rc.GetRegionSyncer().IsRunning() { return &pdpb.BatchScanRegionsResponse{Header: regionNotFound()}, nil @@ -1641,7 +1639,7 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc return &pdpb.BatchScanRegionsResponse{Header: notBootstrappedHeader()}, nil } } - needBucket := request.GetNeedBuckets() && !*followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() + needBucket := request.GetNeedBuckets() && !followerHandle && rc.GetStoreConfig().IsEnableRegionBucket() limit := request.GetLimit() // cast to core.KeyRanges and check the validation. keyRanges := core.NewKeyRangesWithSize(len(request.GetRanges())) @@ -1691,7 +1689,7 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc Buckets: buckets, }) } - if *followerHandle && len(regions) == 0 { + if followerHandle && len(regions) == 0 { return &pdpb.BatchScanRegionsResponse{Header: regionNotFound()}, nil } resp := &pdpb.BatchScanRegionsResponse{Header: wrapHeader(), Regions: regions} @@ -1700,21 +1698,20 @@ func (s *GrpcServer) BatchScanRegions(ctx context.Context, request *pdpb.BatchSc // AskSplit implements gRPC PDServer. func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest) (*pdpb.AskSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AskSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AskSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - if rsp, err := s.unaryMiddleware(ctx, request, "AskSplit"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AskSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AskSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1743,16 +1740,20 @@ func (s *GrpcServer) AskSplit(ctx context.Context, request *pdpb.AskSplitRequest // AskBatchSplit implements gRPC PDServer. func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSplitRequest) (*pdpb.AskBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "AskBatchSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.AskBatchSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.AskBatchSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1787,12 +1788,6 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp } } - if rsp, err := s.unaryMiddleware(ctx, request, "AskBatchSplit"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.AskBatchSplitResponse), err - } - if !versioninfo.IsFeatureSupported(rc.GetOpts().GetClusterVersion(), versioninfo.BatchSplit) { return &pdpb.AskBatchSplitResponse{Header: s.incompatibleVersion("batch_split")}, nil } @@ -1817,22 +1812,20 @@ func (s *GrpcServer) AskBatchSplit(ctx context.Context, request *pdpb.AskBatchSp // ReportSplit implements gRPC PDServer. func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitRequest) (*pdpb.ReportSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "ReportSplit"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1853,22 +1846,20 @@ func (s *GrpcServer) ReportSplit(ctx context.Context, request *pdpb.ReportSplitR // ReportBatchSplit implements gRPC PDServer. func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportBatchSplitRequest) (*pdpb.ReportBatchSplitResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportBatchSplit"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportBatchSplitResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "ReportBatchSplit"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportBatchSplitResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportBatchSplitResponse), nil + } } rc := s.GetRaftCluster() @@ -1890,22 +1881,20 @@ func (s *GrpcServer) ReportBatchSplit(ctx context.Context, request *pdpb.ReportB // GetClusterConfig implements gRPC PDServer. func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClusterConfigRequest) (*pdpb.GetClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetClusterConfig"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetClusterConfigResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "GetClusterConfig"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetClusterConfigResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetClusterConfigResponse), nil + } } rc := s.GetRaftCluster() @@ -1920,22 +1909,20 @@ func (s *GrpcServer) GetClusterConfig(ctx context.Context, request *pdpb.GetClus // PutClusterConfig implements gRPC PDServer. func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClusterConfigRequest) (*pdpb.PutClusterConfigResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "PutClusterConfig"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.PutClusterConfigResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "PutClusterConfig"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.PutClusterConfigResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.PutClusterConfigResponse), nil + } } rc := s.GetRaftCluster() @@ -1959,16 +1946,20 @@ func (s *GrpcServer) PutClusterConfig(ctx context.Context, request *pdpb.PutClus // ScatterRegion implements gRPC PDServer. func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterRegionRequest) (*pdpb.ScatterRegionResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ScatterRegion"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ScatterRegionResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ScatterRegionResponse), nil + } } rc := s.GetRaftCluster() @@ -2019,12 +2010,6 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg } } - if rsp, err := s.unaryMiddleware(ctx, request, "ScatterRegion"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ScatterRegionResponse), err - } - if len(request.GetRegionsId()) > 0 { percentage, err := scatterRegions(rc, request.GetRegionsId(), request.GetGroup(), int(request.GetRetryLimit()), request.GetSkipStoreLimit()) if err != nil { @@ -2070,24 +2055,21 @@ func (s *GrpcServer) ScatterRegion(ctx context.Context, request *pdpb.ScatterReg // GetGCSafePoint implements gRPC PDServer. func (s *GrpcServer) GetGCSafePoint(ctx context.Context, request *pdpb.GetGCSafePointRequest) (*pdpb.GetGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetGCSafePointResponse), nil + } } - - if rsp, err := s.unaryMiddleware(ctx, request, "GetGCSafePoint"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetGCSafePointResponse), err - } - rc := s.GetRaftCluster() if rc == nil { return &pdpb.GetGCSafePointResponse{Header: notBootstrappedHeader()}, nil @@ -2127,24 +2109,21 @@ func (s *GrpcServer) SyncRegions(stream pdpb.PD_SyncRegionsServer) error { // UpdateGCSafePoint implements gRPC PDServer. func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.UpdateGCSafePointRequest) (*pdpb.UpdateGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.UpdateGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateGCSafePointResponse), nil + } } - - if rsp, err := s.unaryMiddleware(ctx, request, "UpdateGCSafePoint"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateGCSafePointResponse), err - } - rc := s.GetRaftCluster() if rc == nil { return &pdpb.UpdateGCSafePointResponse{Header: notBootstrappedHeader()}, nil @@ -2174,22 +2153,20 @@ func (s *GrpcServer) UpdateGCSafePoint(ctx context.Context, request *pdpb.Update // UpdateServiceGCSafePoint update the safepoint for specific service func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb.UpdateServiceGCSafePointRequest) (*pdpb.UpdateServiceGCSafePointResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "UpdateServiceGCSafePoint"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.UpdateServiceGCSafePointResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "UpdateServiceGCSafePoint"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.UpdateServiceGCSafePointResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.UpdateServiceGCSafePointResponse), nil + } } rc := s.GetRaftCluster() @@ -2228,16 +2205,20 @@ func (s *GrpcServer) UpdateServiceGCSafePoint(ctx context.Context, request *pdpb // GetOperator gets information about the operator belonging to the specify region. func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorRequest) (*pdpb.GetOperatorResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetOperator"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetOperatorResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetOperatorResponse), nil + } } rc := s.GetRaftCluster() @@ -2272,12 +2253,6 @@ func (s *GrpcServer) GetOperator(ctx context.Context, request *pdpb.GetOperatorR } } - if rsp, err := s.unaryMiddleware(ctx, request, "GetOperator"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetOperatorResponse), err - } - opController := rc.GetOperatorController() requestID := request.GetRegionId() r := opController.GetOperatorStatus(requestID) @@ -2546,16 +2521,20 @@ func (s *GrpcServer) SyncMaxTS(_ context.Context, request *pdpb.SyncMaxTSRequest // SplitRegions split regions by the given split keys func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegionsRequest) (*pdpb.SplitRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SplitRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SplitRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SplitRegionsResponse), nil + } } rc := s.GetRaftCluster() @@ -2591,12 +2570,6 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion } } - if rsp, err := s.unaryMiddleware(ctx, request, "SplitRegions"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SplitRegionsResponse), err - } - finishedPercentage, newRegionIDs := rc.GetRegionSplitter().SplitRegions(ctx, request.GetSplitKeys(), int(request.GetRetryLimit())) return &pdpb.SplitRegionsResponse{ Header: wrapHeader(), @@ -2609,22 +2582,20 @@ func (s *GrpcServer) SplitRegions(ctx context.Context, request *pdpb.SplitRegion // Only regions which split successfully will be scattered. // scatterFinishedPercentage indicates the percentage of successfully split regions that are scattered. func (s *GrpcServer) SplitAndScatterRegions(ctx context.Context, request *pdpb.SplitAndScatterRegionsRequest) (*pdpb.SplitAndScatterRegionsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SplitAndScatterRegions"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SplitAndScatterRegionsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "SplitAndScatterRegions"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SplitAndScatterRegionsResponse), err + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SplitAndScatterRegionsResponse), nil + } } rc := s.GetRaftCluster() if rc == nil { @@ -2935,22 +2906,20 @@ func (s *GrpcServer) handleDamagedStore(stats *pdpb.StoreStats) { // ReportMinResolvedTS implements gRPC PDServer. func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.ReportMinResolvedTsRequest) (*pdpb.ReportMinResolvedTsResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "ReportMinResolvedTS"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.ReportMinResolvedTsResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "ReportMinResolvedTS"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.ReportMinResolvedTsResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.ReportMinResolvedTsResponse), nil + } } rc := s.GetRaftCluster() @@ -2973,22 +2942,20 @@ func (s *GrpcServer) ReportMinResolvedTS(ctx context.Context, request *pdpb.Repo // SetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.SetExternalTimestampRequest) (*pdpb.SetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "SetExternalTimestamp"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.SetExternalTimestampResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "SetExternalTimestamp"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.SetExternalTimestampResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.SetExternalTimestampResponse), nil + } } nowTSO, err := s.getGlobalTSO(ctx) @@ -3009,22 +2976,20 @@ func (s *GrpcServer) SetExternalTimestamp(ctx context.Context, request *pdpb.Set // GetExternalTimestamp implements gRPC PDServer. func (s *GrpcServer) GetExternalTimestamp(ctx context.Context, request *pdpb.GetExternalTimestampRequest) (*pdpb.GetExternalTimestampResponse, error) { - if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { - fName := currentFunction() - limiter := s.GetGRPCRateLimiter() - if done, err := limiter.Allow(fName); err == nil { - defer done() - } else { + if midResp, err := s.unaryMiddleware(ctx, request, "GetExternalTimestamp"); err != nil { + return nil, err + } else if midResp != nil { + if midResp.header != nil { return &pdpb.GetExternalTimestampResponse{ - Header: wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), + Header: midResp.header, }, nil } - } - - if rsp, err := s.unaryMiddleware(ctx, request, "GetExternalTimestamp"); err != nil { - return nil, err - } else if rsp != nil { - return rsp.(*pdpb.GetExternalTimestampResponse), nil + if midResp.deferFunc != nil { + defer midResp.deferFunc() + } + if midResp.resp != nil { + return midResp.resp.(*pdpb.GetExternalTimestampResponse), nil + } } timestamp := s.GetExternalTS() diff --git a/server/middleware.go b/server/middleware.go index 9ca6bf12583..667c9da6bea 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -144,15 +144,40 @@ var allowFollowerMethods = map[string]struct{}{ "BatchScanRegions": {}, } -func (s *GrpcServer) unaryMiddleware(ctx context.Context, req request, methodName string) (rsp any, err error) { +var notRateLimitMethods = map[string]struct{}{ + "GetGCSafePointV2": {}, + "UpdateGCSafePointV2": {}, + "UpdateServiceSafePointV2": {}, + "GetAllGCSafePointV2": {}, +} + +type middlewareResponse struct { + resp any + header *pdpb.ResponseHeader + deferFunc func() +} + +func (s *GrpcServer) unaryMiddleware(ctx context.Context, req request, methodName string) (rsp *middlewareResponse, err error) { + midResp := &middlewareResponse{} + _, ok := notRateLimitMethods[methodName] + if !ok && s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { + limiter := s.GetGRPCRateLimiter() + if done, err := limiter.Allow(methodName); err != nil { + midResp.header = wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()) + return midResp, nil + } else { + midResp.deferFunc = done + } + } resp, err := s.unaryFollowerMiddleware(ctx, req, forwardFns[methodName]) if resp != nil || err != nil { - return resp, err + midResp.resp = resp + return midResp, err } if err := s.validateRoleInRequest(ctx, req.GetHeader(), methodName); err != nil { return nil, err } - return nil, nil + return midResp, nil } // unaryFollowerMiddleware forward the request to the leader if the request is