From 7899162e4223a6de67a724ad3d9503d59f62af28 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 19 Dec 2024 18:18:49 +0800 Subject: [PATCH] Implement the query region gRPC client Signed-off-by: JmPotato --- client/client.go | 25 +- client/clients/router/client.go | 435 ++++++++++++++++++ client/clients/router/request.go | 115 +++++ client/clients/tso/client.go | 4 +- client/errs/errno.go | 1 + client/go.mod | 2 +- client/go.sum | 4 +- client/inner_client.go | 3 + client/pkg/utils/timerutil/util.go | 32 ++ client/servicediscovery/service_discovery.go | 53 ++- .../servicediscovery/tso_service_discovery.go | 8 +- tests/integrations/client/client_test.go | 1 - 12 files changed, 634 insertions(+), 49 deletions(-) create mode 100644 client/clients/router/request.go create mode 100644 client/pkg/utils/timerutil/util.go diff --git a/client/client.go b/client/client.go index e3d3f4e5b14..8b21b17169e 100644 --- a/client/client.go +++ b/client/client.go @@ -570,23 +570,6 @@ func (c *client) GetMinTS(ctx context.Context) (physical int64, logical int64, e return minTS.Physical, minTS.Logical, nil } -func handleRegionResponse(res *pdpb.GetRegionResponse) *router.Region { - if res.Region == nil { - return nil - } - - r := &router.Region{ - Meta: res.Region, - Leader: res.Leader, - PendingPeers: res.PendingPeers, - Buckets: res.Buckets, - } - for _, s := range res.DownPeers { - r.DownPeers = append(r.DownPeers, s.Peer) - } - return r -} - // GetRegionFromMember implements the RPCClient interface. 
func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs []string, _ ...opt.GetRegionOption) (*router.Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { @@ -623,7 +606,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs errorMsg := fmt.Sprintf("[pd] can't get region info from member URLs: %+v", memberURLs) return nil, errors.WithStack(errors.New(errorMsg)) } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegion implements the RPCClient interface. @@ -663,7 +646,7 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegio if err = c.respForErr(metrics.CmdFailedDurationGetRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetPrevRegion implements the RPCClient interface. @@ -703,7 +686,7 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetR if err = c.respForErr(metrics.CmdFailedDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // GetRegionByID implements the RPCClient interface. @@ -744,7 +727,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt if err = c.respForErr(metrics.CmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil { return nil, err } - return handleRegionResponse(resp), nil + return router.ConvertToRegion(resp), nil } // ScanRegions implements the RPCClient interface. 
diff --git a/client/clients/router/client.go b/client/clients/router/client.go index 48cebfa950e..e1012701460 100644 --- a/client/clients/router/client.go +++ b/client/clients/router/client.go @@ -18,12 +18,31 @@ import ( "context" "encoding/hex" "net/url" + "runtime/trace" + "sync" + "sync/atomic" + "time" + + "github.com/opentracing/opentracing-go" + "go.uber.org/zap" + "google.golang.org/grpc" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/opt" + "github.com/tikv/pd/client/pkg/batch" + cctx "github.com/tikv/pd/client/pkg/connectionctx" + "github.com/tikv/pd/client/pkg/retry" + "github.com/tikv/pd/client/pkg/utils/timerutil" + sd "github.com/tikv/pd/client/servicediscovery" ) +// defaultMaxRouterRequestBatchSize is the default max size of the router request batch. +const defaultMaxRouterRequestBatchSize = 10000 + // Region contains information of a region's meta and its peers. type Region struct { Meta *metapb.Region @@ -33,6 +52,33 @@ type Region struct { Buckets *metapb.Buckets } +type regionResponse interface { + GetRegion() *metapb.Region + GetLeader() *metapb.Peer + GetDownPeers() []*pdpb.PeerStats + GetPendingPeers() []*metapb.Peer + GetBuckets() *metapb.Buckets +} + +// ConvertToRegion converts the region response to the region. +func ConvertToRegion(res regionResponse) *Region { + region := res.GetRegion() + if region == nil { + return nil + } + + r := &Region{ + Meta: region, + Leader: res.GetLeader(), + PendingPeers: res.GetPendingPeers(), + Buckets: res.GetBuckets(), + } + for _, s := range res.GetDownPeers() { + r.DownPeers = append(r.DownPeers, s.Peer) + } + return r +} + // KeyRange defines a range of keys in bytes. type KeyRange struct { StartKey []byte @@ -92,3 +138,392 @@ type Client interface { // The returned regions are flattened, even there are key ranges located in the same region, only one region will be returned. 
BatchScanRegions(ctx context.Context, keyRanges []KeyRange, limit int, opts ...opt.GetRegionOption) ([]*Region, error) } + +// Cli is the implementation of the router client. +type Cli struct { + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + option *opt.Option + + svcDiscovery sd.ServiceDiscovery + // leaderURL is the URL of the router leader. + leaderURL atomic.Value + // conCtxMgr is used to store the context of the router stream connection(s). + conCtxMgr *cctx.Manager[pdpb.PD_QueryRegionClient] + // updateConnectionCh is used to trigger the connection update actively. + updateConnectionCh chan struct{} + // bo is the backoffer for the router client. + bo *retry.Backoffer + + reqPool *sync.Pool + requestCh chan *Request + batchController *batch.Controller[*Request] +} + +// NewClient returns a new router client. +func NewClient( + ctx context.Context, + svcDiscovery sd.ServiceDiscovery, + option *opt.Option, +) *Cli { + ctx, cancel := context.WithCancel(ctx) + c := &Cli{ + ctx: ctx, + cancel: cancel, + svcDiscovery: svcDiscovery, + option: option, + conCtxMgr: cctx.NewManager[pdpb.PD_QueryRegionClient](), + updateConnectionCh: make(chan struct{}, 1), + bo: retry.InitialBackoffer( + sd.UpdateMemberBackOffBaseTime, sd.UpdateMemberTimeout, sd.UpdateMemberBackOffBaseTime), + reqPool: &sync.Pool{ + New: func() any { + return &Request{ + done: make(chan error, 1), + } + }, + }, + requestCh: make(chan *Request, defaultMaxRouterRequestBatchSize*2), + batchController: batch.NewController(defaultMaxRouterRequestBatchSize, requestFinisher(nil), nil), + } + c.leaderURL.Store(svcDiscovery.GetServingURL()) + + eventSrc := svcDiscovery.(sd.EventSource) + eventSrc.SetLeaderURLUpdatedCallback(c.updateLeaderURL) + + c.wg.Add(2) + go c.connectionDaemon() + go c.dispatcher() + + return c +} + +func (c *Cli) newRequest(ctx context.Context) *Request { + req := c.reqPool.Get().(*Request) + req.requestCtx = ctx + req.clientCtx = c.ctx + req.pool = c.reqPool + + 
// Clear the query fields left over from a previous use of this pooled request, otherwise a stale key/prevKey would take precedence over the field set by the caller. + req.key, req.prevKey, req.id, req.region = nil, nil, 0, nil + return req +} + +// requestFinisher returns the callback used to finish each collected request with either the given error or its region from resp. +// Key and previous-key requests are matched to resp by their position in the batch, so the closure keeps one cursor per kind. +func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request] { + var keyIdx, prevKeyIdx int + return func(_ int, req *Request, err error) { + requestCtx := req.requestCtx + defer trace.StartRegion(requestCtx, "pdclient.regionReqDone").End() + + if err != nil { + req.tryDone(err) + return + } + + // Use the nil-safe getters and bounds checks so that a nil or short response is reported as a not-found error instead of a panic. + var id uint64 + if req.key != nil { + if ids := resp.GetKeyIdMap(); keyIdx < len(ids) { + id = ids[keyIdx] + } + keyIdx++ + } else if req.prevKey != nil { + if ids := resp.GetPrevKeyIdMap(); prevKeyIdx < len(ids) { + id = ids[prevKeyIdx] + } + prevKeyIdx++ + } else if req.id != 0 { + id = req.id + } + region, ok := resp.GetRegionsById()[id] + if !ok { + err = errs.ErrClientRegionNotFound.FastGenByArgs(id) + } else { + req.region = ConvertToRegion(region) + } + req.tryDone(err) + } +} + +// cancelCollectedRequests finishes all the collected requests with the given error. +func (c *Cli) cancelCollectedRequests(err error) { + c.batchController.FinishCollectedRequests(requestFinisher(nil), err) +} + +// doneCollectedRequests finishes all the collected requests with the regions carried by resp. +func (c *Cli) doneCollectedRequests(resp *pdpb.QueryRegionResponse) { + c.batchController.FinishCollectedRequests(requestFinisher(resp), nil) +} + +// Close closes the router client. +func (c *Cli) Close() { + if c == nil { + return + } + log.Info("[router] closing router client") + + c.cancel() + c.wg.Wait() + + log.Info("[router] router client is closed") +} + +func (c *Cli) getLeaderURL() string { + url := c.leaderURL.Load() + if url == nil { + return "" + } + return url.(string) +} + +func (c *Cli) updateLeaderURL(url string) error { + oldURL := c.getLeaderURL() + if !c.leaderURL.CompareAndSwap(oldURL, url) { + return nil + } + c.scheduleUpdateConnection() + + log.Info("[router] switch the router leader serving url", + zap.String("old-url", oldURL), zap.String("new-url", url)) + return nil +} + +// getLeaderClientConn returns the leader gRPC client connection.
+func (c *Cli) getLeaderClientConn() (*grpc.ClientConn, string) { + url := c.getLeaderURL() + if len(url) == 0 { + c.svcDiscovery.ScheduleCheckMemberChanged() + return nil, "" + } + cc, ok := c.svcDiscovery.GetClientConns().Load(url) + if !ok { + return nil, url + } + return cc.(*grpc.ClientConn), url +} + +// scheduleUpdateConnection is used to schedule an update to the connection(s). +func (c *Cli) scheduleUpdateConnection() { + select { + case c.updateConnectionCh <- struct{}{}: + default: + } +} + +// connectionDaemon is used to update the router leader/primary/backup connection(s) in background. +// It aims to provide a seamless connection updating for the router client to keep providing the +// router service without interruption. +func (c *Cli) connectionDaemon() { + defer c.wg.Done() + updaterCtx, updaterCancel := context.WithCancel(c.ctx) + defer updaterCancel() + updateTicker := time.NewTicker(sd.MemberUpdateInterval) + defer updateTicker.Stop() + + log.Info("[router] connection daemon is started") + for { + c.updateConnection(updaterCtx) + select { + case <-updaterCtx.Done(): + log.Info("[router] connection daemon is exiting") + return + case <-updateTicker.C: + case <-c.updateConnectionCh: + } + } +} + +// updateConnection is used to get the leader client connection and update the connection context if it does not exist before. +func (c *Cli) updateConnection(ctx context.Context) { + cc, url := c.getLeaderClientConn() + if cc == nil || len(url) == 0 { + log.Warn("[router] got an invalid leader client connection", zap.String("url", url)) + return + } + if c.conCtxMgr.Exist(url) { + log.Debug("[router] the router leader remains unchanged", zap.String("url", url)) + return + } + stream, err := pdpb.NewPDClient(cc).QueryRegion(ctx) + if err != nil { + log.Error("[router] failed to create the router stream connection", errs.ZapError(err)) + // Never store a nil stream: it would make Exist(url) report the connection as ready. The daemon retries on the next tick or scheduled update. + return + } + c.conCtxMgr.Store(ctx, url, stream) + // TODO: support the forwarding mechanism for the router client.
+ // TODO: support sending the router requests to the follower nodes. +} + +// dispatcher is the background loop that fetches the pending requests in batch, chooses a ready stream connection and sends the batch through it. +func (c *Cli) dispatcher() { + defer c.wg.Done() + + var ( + stream pdpb.PD_QueryRegionClient + streamURL string + streamCtx context.Context + timeoutTimer *time.Timer + resetTimeoutTimer = func() { + if timeoutTimer == nil { + timeoutTimer = time.NewTimer(c.option.Timeout) + } else { + timerutil.SafeResetTimer(timeoutTimer, c.option.Timeout) + } + } + ctx, cancel = context.WithCancel(c.ctx) + ) + + log.Info("[router] dispatcher is started") + defer func() { + log.Info("[router] dispatcher is exiting") + cancel() + if timeoutTimer != nil { + timeoutTimer.Stop() + } + log.Info("[router] dispatcher exited") + }() +batchLoop: + for { + select { + case <-ctx.Done(): + return + default: + } + + // Step 1: Fetch the pending router requests in batch. + err := c.batchController.FetchPendingRequests(ctx, c.requestCh, nil, 0) + if err != nil { + if err == context.Canceled { + log.Info("[router] stop fetching the pending router requests due to context canceled") + } else { + log.Error("[router] failed to fetch the pending router requests", errs.ZapError(err)) + } + return + } + + // Step 2: Choose a stream connection to send the router request. + resetTimeoutTimer() + connectionCtxChoosingLoop: + for { + // Check if the dispatcher is canceled or the timeout timer is triggered. + select { + case <-ctx.Done(): + return + case <-timeoutTimer.C: + log.Error("[router] router stream connection is not ready until timeout, abort the batch") + c.svcDiscovery.ScheduleCheckMemberChanged() + // err is always nil here (the fetch above succeeded), so fail the batch with an explicit timeout error; a nil error would make the finisher read the nil response. + c.batchController.FinishCollectedRequests(requestFinisher(nil), context.DeadlineExceeded) + continue batchLoop + default: + } + // Choose a stream connection to send the router request later.
+ connectionCtx := c.conCtxMgr.GetConnectionCtx() + if connectionCtx == nil { + log.Info("[router] router stream connection is not ready") + c.scheduleUpdateConnection() + continue connectionCtxChoosingLoop + } + streamCtx, streamURL, stream = connectionCtx.Ctx, connectionCtx.StreamURL, connectionCtx.Stream + // Check if the stream connection is canceled. + select { + case <-streamCtx.Done(): + log.Info("[router] router stream connection is canceled", zap.String("stream-url", streamURL)) + c.conCtxMgr.Release(streamURL) + continue connectionCtxChoosingLoop + default: + } + // The stream connection is ready, break the loop. + break connectionCtxChoosingLoop + } + + // Step 3: Dispatch the router requests to the stream connection. + // TODO: timeout handling if the stream takes too long to process the requests. + err = c.processRequests(stream) + if err != nil { + if !c.handleProcessRequestError(ctx, streamURL, err) { + return + } + } + } +} + +func (c *Cli) processRequests(stream pdpb.PD_QueryRegionClient) error { + var ( + requests = c.batchController.GetCollectedRequests() + traceRegions = make([]*trace.Region, 0, len(requests)) + spans = make([]opentracing.Span, 0, len(requests)) + ) + for _, req := range requests { + traceRegions = append(traceRegions, trace.StartRegion(req.requestCtx, "pdclient.regionReqSend")) + if span := opentracing.SpanFromContext(req.requestCtx); span != nil && span.Tracer() != nil { + spans = append(spans, span.Tracer().StartSpan("pdclient.processRegionRequests", opentracing.ChildOf(span.Context()))) + } + } + defer func() { + for i := range spans { + spans[i].Finish() + } + for i := range traceRegions { + traceRegions[i].End() + } + }() + + queryReq := &pdpb.QueryRegionRequest{ + Header: &pdpb.RequestHeader{ + ClusterId: c.svcDiscovery.GetClusterID(), + }, + Keys: make([][]byte, 0, len(requests)), + PrevKeys: make([][]byte, 0, len(requests)), + Ids: make([]uint64, 0, len(requests)), + } + for _, req := range requests { + if 
!queryReq.NeedBuckets && req.needBuckets { + queryReq.NeedBuckets = true + } + if req.key != nil { + queryReq.Keys = append(queryReq.Keys, req.key) + } else if req.prevKey != nil { + queryReq.PrevKeys = append(queryReq.PrevKeys, req.prevKey) + } else if req.id != 0 { + queryReq.Ids = append(queryReq.Ids, req.id) + } else { + panic("invalid region query request received") + } + } + err := stream.Send(queryReq) + if err != nil { + return err + } + resp, err := stream.Recv() + if err != nil { + return err + } + c.doneCollectedRequests(resp) + return nil +} + +func (c *Cli) handleProcessRequestError( + ctx context.Context, + streamURL string, + err error, +) bool { + log.Error("[router] failed to process the router requests", + zap.String("stream-url", streamURL), + errs.ZapError(err)) + c.cancelCollectedRequests(err) + + select { + case <-ctx.Done(): + return false + default: + } + + // Delete the stream connection context. + c.conCtxMgr.Release(streamURL) + if errs.IsLeaderChange(err) { + // If the leader changes, we better call `CheckMemberChanged` blockingly to + // ensure the next round of router requests can be sent to the new leader. + if err := c.bo.Exec(ctx, c.svcDiscovery.CheckMemberChanged); err != nil { + select { + case <-ctx.Done(): + return false + default: + } + } + } else { + // For other errors, we can just schedule a member change check asynchronously. + c.svcDiscovery.ScheduleCheckMemberChanged() + } + + return true +} diff --git a/client/clients/router/request.go b/client/clients/router/request.go new file mode 100644 index 00000000000..2e1c2e97aa5 --- /dev/null +++ b/client/clients/router/request.go @@ -0,0 +1,115 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package router + +import ( + "context" + "runtime/trace" + "sync" + + "github.com/pingcap/errors" + + "github.com/tikv/pd/client/opt" +) + +// Request is a region info request. +type Request struct { + requestCtx context.Context + clientCtx context.Context + + key []byte + prevKey []byte + id uint64 + needBuckets bool + + done chan error + // region will be set after the request is done. + region *Region + + // Runtime fields. + pool *sync.Pool +} + +func (req *Request) tryDone(err error) { + select { + case req.done <- err: + default: + } +} + +func (req *Request) wait() (*Region, error) { + // TODO: introduce the metrics. + select { + case err := <-req.done: + defer req.pool.Put(req) + defer trace.StartRegion(req.requestCtx, "pdclient.regionReqDone").End() + if err != nil { + return nil, errors.WithStack(err) + } + return req.region, nil + case <-req.requestCtx.Done(): + return nil, errors.WithStack(req.requestCtx.Err()) + case <-req.clientCtx.Done(): + return nil, errors.WithStack(req.clientCtx.Err()) + } +} + +// GetRegion implements the Client interface. +func (c *Cli) GetRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.key = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetRegionFromMember implements the Client interface. 
+func (c *Cli) GetRegionFromMember(ctx context.Context, key []byte, _ []string, opts ...opt.GetRegionOption) (*Region, error) { + // Before we support the follower stream connection, this method is equivalent to `GetRegion`. + return c.GetRegion(ctx, key, opts...) +} + +// GetPrevRegion implements the Client interface. +func (c *Cli) GetPrevRegion(ctx context.Context, key []byte, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.prevKey = key + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} + +// GetRegionByID implements the Client interface. +func (c *Cli) GetRegionByID(ctx context.Context, regionID uint64, opts ...opt.GetRegionOption) (*Region, error) { + req := c.newRequest(ctx) + req.id = regionID + options := &opt.GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req.needBuckets = options.NeedBuckets + + c.requestCh <- req + return req.wait() +} diff --git a/client/clients/tso/client.go b/client/clients/tso/client.go index 7bc768ee21b..bd237ecfd15 100644 --- a/client/clients/tso/client.go +++ b/client/clients/tso/client.go @@ -116,8 +116,8 @@ func NewClient( }, } - eventSrc := svcDiscovery.(sd.TSOEventSource) - eventSrc.SetTSOLeaderURLUpdatedCallback(c.updateTSOLeaderURL) + eventSrc := svcDiscovery.(sd.EventSource) + eventSrc.SetLeaderURLUpdatedCallback(c.updateTSOLeaderURL) c.svcDiscovery.AddServiceURLsSwitchedCallback(c.scheduleUpdateTSOConnectionCtxs) return c diff --git a/client/errs/errno.go b/client/errs/errno.go index 99a426d0776..8f81d2d6777 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -70,6 +70,7 @@ var ( ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", 
errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ErrCircuitBreakerOpen = errors.Normalize("circuit breaker is open", errors.RFCCodeText("PD:client:ErrCircuitBreakerOpen")) + ErrClientRegionNotFound = errors.Normalize("region %d not found", errors.RFCCodeText("PD:client:ErrClientRegionNotFound")) ) // grpcutil errors diff --git a/client/go.mod b/client/go.mod index 78aef084ff7..a84bf303be1 100644 --- a/client/go.mod +++ b/client/go.mod @@ -10,7 +10,7 @@ require ( github.com/opentracing/opentracing-go v1.2.0 github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 - github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 + github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/prometheus/client_golang v1.20.5 github.com/stretchr/testify v1.9.0 diff --git a/client/go.sum b/client/go.sum index 4cca5ba3ad5..2873e4f550c 100644 --- a/client/go.sum +++ b/client/go.sum @@ -49,8 +49,8 @@ github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTm github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86/go.mod h1:exzhVYca3WRtd6gclGNErRWb1qEgff3LYta0LvRmON4= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037 h1:xYNSJjYNur4Dr5bV+9BXK9n5E0T1zlcAN25XX68+mOg= -github.com/pingcap/kvproto v0.0.0-20241120071417-b5b7843d9037/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1 h1:rTAyiswGyWSGHJVa4Mkhdi8YfGqfA4LrUVKsH9nrJ8E= +github.com/pingcap/kvproto v0.0.0-20250117122752-2b87602a94a1/go.mod h1:rXxWk2UnwfUhLXha1jxRWPADw9eMZGWEWCg92Tgmb/8= github.com/pingcap/log 
v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8IDP+SZrdhV1Kibl9KrHxJ9eciw= github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/client/inner_client.go b/client/inner_client.go index 8379b6a51a9..b883cfe8005 100644 --- a/client/inner_client.go +++ b/client/inner_client.go @@ -13,6 +13,7 @@ import ( "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" + "github.com/tikv/pd/client/clients/router" "github.com/tikv/pd/client/clients/tso" "github.com/tikv/pd/client/errs" "github.com/tikv/pd/client/metrics" @@ -31,6 +32,7 @@ type innerClient struct { serviceDiscovery sd.ServiceDiscovery tokenDispatcher *tokenDispatcher + routerClient *router.Cli // For service mode switching. serviceModeKeeper @@ -55,6 +57,7 @@ func (c *innerClient) init(updateKeyspaceIDCb sd.UpdateKeyspaceIDFunc) error { } return err } + c.routerClient = router.NewClient(c.ctx, c.serviceDiscovery, c.option) return nil } diff --git a/client/pkg/utils/timerutil/util.go b/client/pkg/utils/timerutil/util.go new file mode 100644 index 00000000000..7e24671a09e --- /dev/null +++ b/client/pkg/utils/timerutil/util.go @@ -0,0 +1,32 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timerutil + +import "time" + +// SafeResetTimer is used to reset timer safely. 
+// Before Go 1.23, the only safe way to use Reset was to call Timer.Stop and explicitly drain the timer first. +// We need be careful here, see more details in the comments of Timer.Reset. +// https://pkg.go.dev/time@master#Timer.Reset +func SafeResetTimer(t *time.Timer, d time.Duration) { + // Stop the timer if it's not stopped. + if !t.Stop() { + select { + case <-t.C: // try to drain from the channel + default: + } + } + t.Reset(d) +} diff --git a/client/servicediscovery/service_discovery.go b/client/servicediscovery/service_discovery.go index f5ac665b7cd..45ec90f7bc4 100644 --- a/client/servicediscovery/service_discovery.go +++ b/client/servicediscovery/service_discovery.go @@ -394,17 +394,17 @@ func (c *serviceBalancer) get() (ret ServiceClient) { // UpdateKeyspaceIDFunc is the function type for updating the keyspace ID. type UpdateKeyspaceIDFunc func() error -type tsoLeaderURLUpdatedFunc func(string) error +type leaderURLUpdatedFunc func(string) error -// TSOEventSource subscribes to events related to changes in the TSO leader/primary from the service discovery. -type TSOEventSource interface { - // SetTSOLeaderURLUpdatedCallback adds a callback which will be called when the TSO leader/primary is updated. - SetTSOLeaderURLUpdatedCallback(callback tsoLeaderURLUpdatedFunc) +// EventSource subscribes to events related to changes in the leader/primary from the service discovery. +type EventSource interface { + // SetLeaderURLUpdatedCallback adds a callback which will be called when the leader/primary is updated. 
+ SetLeaderURLUpdatedCallback(callback leaderURLUpdatedFunc) } var ( _ ServiceDiscovery = (*serviceDiscovery)(nil) - _ TSOEventSource = (*serviceDiscovery)(nil) + _ EventSource = (*serviceDiscovery)(nil) ) // serviceDiscovery is the service discovery client of PD/PD service which is quorum based @@ -433,8 +433,8 @@ type serviceDiscovery struct { // membersChangedCbs will be called after there is any membership change in the // leader and followers membersChangedCbs []func() - // tsoLeaderUpdatedCb will be called when the TSO leader is updated. - tsoLeaderUpdatedCb tsoLeaderURLUpdatedFunc + // leaderUpdatedCb will be called when the leader/primary is updated. + leaderUpdatedCb atomic.Value // Store as []leaderURLUpdatedFunc checkMembershipCh chan struct{} @@ -485,6 +485,25 @@ func NewServiceDiscovery( return pdsd } +// addLeaderUpdatedCb registers one more callback to be run when the leader URL is updated. +func (c *serviceDiscovery) addLeaderUpdatedCb(cb leaderURLUpdatedFunc) { + // Copy on write: never append into the backing array of the slice that has already been published to readers. + old, _ := c.leaderUpdatedCb.Load().([]leaderURLUpdatedFunc) + cbs := make([]leaderURLUpdatedFunc, 0, len(old)+1) + cbs = append(append(cbs, old...), cb) + c.leaderUpdatedCb.Store(cbs) +} + +// callLeaderUpdatedCb runs the registered callbacks in order with the new leader URL and stops at the first error. +func (c *serviceDiscovery) callLeaderUpdatedCb(url string) (err error) { + // The comma-ok assertion yields a nil slice when no callback has been registered yet, instead of panicking on a nil interface. + cbs, _ := c.leaderUpdatedCb.Load().([]leaderURLUpdatedFunc) + for _, cb := range cbs { + if cb == nil { + continue + } + if err = cb(url); err != nil { + return err + } + } + return nil +} + // Init initializes the service discovery. func (c *serviceDiscovery) Init() error { if c.isInitialized { @@ -803,15 +822,15 @@ func (c *serviceDiscovery) AddServiceURLsSwitchedCallback(callbacks ...func()) { c.membersChangedCbs = append(c.membersChangedCbs, callbacks...) } -// SetTSOLeaderURLUpdatedCallback adds a callback which will be called when the TSO leader is updated. -func (c *serviceDiscovery) SetTSOLeaderURLUpdatedCallback(callback tsoLeaderURLUpdatedFunc) { +// SetLeaderURLUpdatedCallback adds a callback which will be called when the PD leader is updated.
+func (c *serviceDiscovery) SetLeaderURLUpdatedCallback(callback leaderURLUpdatedFunc) { url := c.getLeaderURL() if len(url) > 0 { if err := callback(url); err != nil { - log.Error("[tso] failed to call back when tso leader url update", zap.String("url", url), errs.ZapError(err)) + log.Error("[pd] failed to call back when pd leader url update", zap.String("url", url), errs.ZapError(err)) } } - c.tsoLeaderUpdatedCb = callback + c.addLeaderUpdatedCb(callback) } // getLeaderURL returns the leader URL. @@ -980,23 +999,21 @@ func (c *serviceDiscovery) switchLeader(url string) (bool, error) { return false, nil } - newConn, err := c.GetOrCreateGRPCConn(url) + newConn, _ := c.GetOrCreateGRPCConn(url) // If gRPC connect is created successfully or leader is new, still saves. if url != oldLeader.GetURL() || newConn != nil { leaderClient := newPDServiceClient(url, url, newConn, true) c.leader.Store(leaderClient) } // Run callbacks - if c.tsoLeaderUpdatedCb != nil { - if err := c.tsoLeaderUpdatedCb(url); err != nil { - return true, err - } + if err := c.callLeaderUpdatedCb(url); err != nil { + return true, err } for _, cb := range c.leaderSwitchedCbs { cb() } log.Info("[pd] switch leader", zap.String("new-leader", url), zap.String("old-leader", oldLeader.GetURL())) - return true, err + return true, nil } func (c *serviceDiscovery) updateFollowers(members []*pdpb.Member, leaderID uint64, leaderURL string) (changed bool) { diff --git a/client/servicediscovery/tso_service_discovery.go b/client/servicediscovery/tso_service_discovery.go index 7734fd23107..e9acf10c8f8 100644 --- a/client/servicediscovery/tso_service_discovery.go +++ b/client/servicediscovery/tso_service_discovery.go @@ -57,7 +57,7 @@ const ( var ( _ ServiceDiscovery = (*tsoServiceDiscovery)(nil) - _ TSOEventSource = (*tsoServiceDiscovery)(nil) + _ EventSource = (*tsoServiceDiscovery)(nil) ) // keyspaceGroupSvcDiscovery is used for discovering the serving endpoints of the keyspace @@ -144,7 +144,7 @@ type 
tsoServiceDiscovery struct { clientConns sync.Map // Store as map[string]*grpc.ClientConn // tsoLeaderUpdatedCb will be called when the TSO leader is updated. - tsoLeaderUpdatedCb tsoLeaderURLUpdatedFunc + tsoLeaderUpdatedCb leaderURLUpdatedFunc checkMembershipCh chan struct{} @@ -369,8 +369,8 @@ func (*tsoServiceDiscovery) AddServingURLSwitchedCallback(...func()) {} // in a primary/secondary configured cluster is changed. func (*tsoServiceDiscovery) AddServiceURLsSwitchedCallback(...func()) {} -// SetTSOLeaderURLUpdatedCallback adds a callback which will be called when the TSO leader is updated. -func (c *tsoServiceDiscovery) SetTSOLeaderURLUpdatedCallback(callback tsoLeaderURLUpdatedFunc) { +// SetLeaderURLUpdatedCallback adds a callback which will be called when the TSO leader is updated. +func (c *tsoServiceDiscovery) SetLeaderURLUpdatedCallback(callback leaderURLUpdatedFunc) { url := c.getPrimaryURL() if len(url) > 0 { if err := callback(url); err != nil { diff --git a/tests/integrations/client/client_test.go b/tests/integrations/client/client_test.go index 0e0bf25d74d..f3f3fdc15ff 100644 --- a/tests/integrations/client/client_test.go +++ b/tests/integrations/client/client_test.go @@ -1203,7 +1203,6 @@ func (suite *clientTestSuite) TestGetPrevRegion() { err := suite.regionHeartbeat.Send(req) re.NoError(err) } - time.Sleep(500 * time.Millisecond) for i := range 20 { testutil.Eventually(re, func() bool { r, err := suite.client.GetPrevRegion(context.Background(), []byte{byte(i)})