From 81d8dea0ebc465dadc38c7797669e5c4dcad0359 Mon Sep 17 00:00:00 2001 From: zyguan Date: Thu, 28 Mar 2024 14:00:22 +0800 Subject: [PATCH] tikv: ensure safe-ts won't be max uint64 (#1250) * tikv: ensure safe-ts won't be max uint64 Signed-off-by: zyguan * fix a typo Signed-off-by: zyguan * fix lint issue Signed-off-by: zyguan * address the comment and fix test Signed-off-by: zyguan --------- Signed-off-by: zyguan --- integration_tests/pd_api_test.go | 5 +- tikv/kv.go | 35 +++++-- tikv/kv_test.go | 160 ++++++++++++++++++++++++++----- 3 files changed, 166 insertions(+), 34 deletions(-) diff --git a/integration_tests/pd_api_test.go b/integration_tests/pd_api_test.go index 2da4c4bfb..add750f56 100644 --- a/integration_tests/pd_api_test.go +++ b/integration_tests/pd_api_test.go @@ -17,7 +17,6 @@ package tikv_test import ( "context" "fmt" - "math" "strings" "sync/atomic" "testing" @@ -183,8 +182,8 @@ func (s *apiTestSuite) TestInitClusterMinResolvedTSZero() { // Try to get the minimum resolved timestamp of the cluster from TiKV. require.NoError(failpoint.Enable("tikvclient/InjectPDMinResolvedTS", `return(0)`)) // Make sure the store's min resolved ts is not initialized. - s.waitForMinSafeTS(oracle.GlobalTxnScope, math.MaxUint64) - require.Equal(uint64(math.MaxUint64), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.waitForMinSafeTS(oracle.GlobalTxnScope, 0) + require.Equal(uint64(0), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) require.NoError(failpoint.Disable("tikvclient/InjectPDMinResolvedTS")) // Try to get the minimum resolved timestamp of the cluster from PD. diff --git a/tikv/kv.go b/tikv/kv.go index 172f6f614..34b5fa24d 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -572,6 +572,15 @@ func (s *KVStore) GetMinSafeTS(txnScope string) uint64 { return 0 } +func (s *KVStore) setMinSafeTS(txnScope string, safeTS uint64) { + // ensure safeTS is not set to max uint64 + if safeTS == math.MaxUint64 { + logutil.BgLogger().Warn("skip setting min-safe-ts to max uint64", zap.String("txnScope", txnScope), zap.Stack("stack")) + return + } + s.minSafeTS.Store(txnScope, safeTS) +} + // Ctx returns ctx. func (s *KVStore) Ctx() context.Context { return s.ctx @@ -607,6 +616,11 @@ func (s *KVStore) getSafeTS(storeID uint64) (bool, uint64) { // setSafeTS sets safeTs for store storeID, export for testing func (s *KVStore) setSafeTS(storeID, safeTS uint64) { + // ensure safeTS is not set to max uint64 + if safeTS == math.MaxUint64 { + logutil.BgLogger().Warn("skip setting safe-ts to max uint64", zap.Uint64("storeID", storeID), zap.Stack("stack")) + return + } s.safeTSMap.Store(storeID, safeTS) } @@ -614,11 +628,12 @@ func (s *KVStore) updateMinSafeTS(txnScope string, storeIDs []uint64) { minSafeTS := uint64(math.MaxUint64) // when there is no store, return 0 in order to let minStartTS become startTS directly if len(storeIDs) < 1 { - s.minSafeTS.Store(txnScope, 0) + s.setMinSafeTS(txnScope, 0) } for _, store := range storeIDs { ok, safeTS := s.getSafeTS(store) if ok { + // safeTS is guaranteed to be less than math.MaxUint64 (by setSafeTS and its callers) if safeTS != 0 && safeTS < minSafeTS { minSafeTS = safeTS } @@ -626,7 +641,7 @@ func (s *KVStore) updateMinSafeTS(txnScope string, storeIDs []uint64) { minSafeTS = 0 } } - s.minSafeTS.Store(txnScope, minSafeTS) + s.setMinSafeTS(txnScope, minSafeTS) } func (s *KVStore) safeTSUpdater() { @@ -690,8 +705,8 @@ func (s *KVStore) updateSafeTS(ctx context.Context) { safeTS uint64 storeIDStr = strconv.FormatUint(storeID, 10) ) - // If getting the minimum resolved timestamp from PD failed or returned 0, try to get it from TiKV. - if storeMinResolvedTSs == nil || storeMinResolvedTSs[storeID] == 0 || err != nil { + // If getting the minimum resolved timestamp from PD failed or returned 0/MaxUint64, try to get it from TiKV. + if storeMinResolvedTSs == nil || !isValidSafeTS(storeMinResolvedTSs[storeID]) || err != nil { resp, err := tikvClient.SendRequest( ctx, storeAddr, tikvrpc.NewRequest( tikvrpc.CmdStoreSafeTS, &kvrpcpb.StoreSafeTSRequest{ @@ -785,17 +800,17 @@ func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool { clusterMinSafeTS, _, err := s.getMinResolvedTSByStoresIDs(ctx, nil) if err != nil { logutil.BgLogger().Debug("get resolved TS from PD failed", zap.Error(err)) - } else if clusterMinSafeTS != 0 { + } else if isValidSafeTS(clusterMinSafeTS) { // Update ts and metrics. preClusterMinSafeTS := s.GetMinSafeTS(oracle.GlobalTxnScope) - // If preClusterMinSafeTS is maxUint64, it means that the min safe ts has not been initialized. + // preClusterMinSafeTS is guaranteed to be less than math.MaxUint64 (by this method and setMinSafeTS) // related to https://github.com/tikv/client-go/issues/991 - if preClusterMinSafeTS != math.MaxUint64 && preClusterMinSafeTS > clusterMinSafeTS { + if preClusterMinSafeTS > clusterMinSafeTS { skipSafeTSUpdateCounter.Inc() preSafeTSTime := oracle.GetTimeFromTS(preClusterMinSafeTS) clusterMinSafeTSGap.Set(time.Since(preSafeTSTime).Seconds()) } else { - s.minSafeTS.Store(oracle.GlobalTxnScope, clusterMinSafeTS) + s.setMinSafeTS(oracle.GlobalTxnScope, clusterMinSafeTS) successSafeTSUpdateCounter.Inc() safeTSTime := oracle.GetTimeFromTS(clusterMinSafeTS) clusterMinSafeTSGap.Set(time.Since(safeTSTime).Seconds()) @@ -807,6 +822,10 @@ func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool { return false } +func isValidSafeTS(ts uint64) bool { + return ts != 0 && ts != math.MaxUint64 +} + // EnableResourceControl enables the resource control. func EnableResourceControl() { client.ResourceControlSwitch.Store(true) diff --git a/tikv/kv_test.go b/tikv/kv_test.go index 9f9af8500..edf6182ed 100644 --- a/tikv/kv_test.go +++ b/tikv/kv_test.go @@ -17,6 +17,7 @@ package tikv import ( "context" "fmt" + "math" "sync/atomic" "testing" "time" @@ -29,6 +30,7 @@ import ( "github.com/tikv/client-go/v2/testutils" "github.com/tikv/client-go/v2/tikvrpc" "github.com/tikv/client-go/v2/util" + pdhttp "github.com/tikv/pd/client/http" ) func TestKV(t *testing.T) { @@ -38,18 +40,27 @@ func TestKV(t *testing.T) { type testKVSuite struct { suite.Suite - store *KVStore - cluster *mocktikv.Cluster - tikvStoreID uint64 - tiflashStoreID uint64 - tiflashPeerStoreID uint64 + store *KVStore + cluster *mocktikv.Cluster + tikvStoreID uint64 + tiflashStoreID uint64 + + mockGetMinResolvedTSByStoresIDs atomic.Pointer[func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)] } func (s *testKVSuite) SetupTest() { client, cluster, pdClient, err := testutils.NewMockTiKV("", nil) s.Require().Nil(err) testutils.BootstrapWithSingleStore(cluster) - store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0) + s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) { + return 0, nil, nil + }) + store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0, Option(func(store *KVStore) { + store.pdHttpClient = &mockPDHTTPClient{ + Client: pdhttp.NewClientWithServiceDiscovery("test", nil), + mockGetMinResolvedTSByStoresIDs: &s.mockGetMinResolvedTSByStoresIDs, + } + })) s.Require().Nil(err) s.store = store @@ -58,14 +69,18 @@ func (s *testKVSuite) SetupTest() { storeIDs, _, _, _ := mocktikv.BootstrapWithMultiStores(s.cluster, 2) s.tikvStoreID = storeIDs[0] s.tiflashStoreID = storeIDs[1] - tiflashPeerAddrID := cluster.AllocIDs(1) - s.tiflashPeerStoreID = tiflashPeerAddrID[0] - s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashPeerStoreID), &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) - s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, nil) var labels []*metapb.StoreLabel - labels = append(labels, &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) - s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashPeerStoreID), tikvrpc.TiFlash, 1, labels) + labels = append(cluster.GetStore(s.tikvStoreID).Labels, + &metapb.StoreLabel{Key: DCLabelKey, Value: "z1"}) + s.cluster.UpdateStorePeerAddr(s.tikvStoreID, s.storeAddr(s.tikvStoreID), labels...) + s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, labels) + + labels = append(cluster.GetStore(s.tiflashStoreID).Labels, + &metapb.StoreLabel{Key: DCLabelKey, Value: "z2"}, + &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) + s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), labels...) + s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashStoreID), tikvrpc.TiFlash, 1, labels) } @@ -77,6 +92,10 @@ func (s *testKVSuite) storeAddr(id uint64) string { return fmt.Sprintf("store%d", id) } +func (s *testKVSuite) setGetMinResolvedTSByStoresIDs(f func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)) { + s.mockGetMinResolvedTSByStoresIDs.Store(&f) +} + type storeSafeTsMockClient struct { Client requestCount int32 @@ -89,7 +108,7 @@ func (c *storeSafeTsMockClient) SendRequest(ctx context.Context, addr string, re } atomic.AddInt32(&c.requestCount, 1) resp := &tikvrpc.Response{} - if addr == c.testSuite.storeAddr(c.testSuite.tiflashPeerStoreID) { + if addr == c.testSuite.storeAddr(c.testSuite.tiflashStoreID) { resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 80} } else { resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 100} @@ -105,22 +124,117 @@ func (c *storeSafeTsMockClient) CloseAddr(addr string) error { return c.Client.CloseAddr(addr) } -func (s *testKVSuite) TestMinSafeTs() { +type mockPDHTTPClient struct { + pdhttp.Client + mockGetMinResolvedTSByStoresIDs *atomic.Pointer[func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)] +} + +func (c *mockPDHTTPClient) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { + if f := c.mockGetMinResolvedTSByStoresIDs.Load(); f != nil { + return (*f)(ctx, storeIDs) + } + return c.Client.GetMinResolvedTSByStoresIDs(ctx, storeIDs) +} + +func (s *testKVSuite) TestMinSafeTsFromStores() { mockClient := storeSafeTsMockClient{ Client: s.store.GetTiKVClient(), testSuite: s, } s.store.SetTiKVClient(&mockClient) - // wait for updateMinSafeTS - var retryCount int - for s.store.GetMinSafeTS(oracle.GlobalTxnScope) != 80 { - time.Sleep(2 * time.Second) - if retryCount > 5 { - break - } - retryCount++ - } + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == 80 + }, 15*time.Second, time.Second) s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(2)) s.Require().Equal(uint64(80), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + ok, ts := s.store.getSafeTS(s.tikvStoreID) + s.Require().True(ok) + s.Require().Equal(uint64(100), ts) +} + +func (s *testKVSuite) TestMinSafeTsFromPD() { + mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s} + s.store.SetTiKVClient(&mockClient) + s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) { + return 90, nil, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == 90 + }, 15*time.Second, time.Second) + s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0)) + s.Require().Equal(uint64(90), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromPDByStores() { + mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s} + s.store.SetTiKVClient(&mockClient) + s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + m[id] = uint64(100) + id + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(100)+s.tikvStoreID + }, 15*time.Second, time.Second) + s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0)) + s.Require().Equal(uint64(100)+s.tikvStoreID, s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromMixed1() { + mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s} + s.store.SetTiKVClient(&mockClient) + s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + if id == s.tiflashStoreID { + m[id] = 0 + } else { + m[id] = uint64(10) + } + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS("z1") + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(10) + }, 15*time.Second, time.Second) + s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z1")) + s.Require().Equal(uint64(80), s.store.GetMinSafeTS("z2")) +} + +func (s *testKVSuite) TestMinSafeTsFromMixed2() { + mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s} + s.store.SetTiKVClient(&mockClient) + s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + if id == s.tiflashStoreID { + m[id] = uint64(10) + } else { + m[id] = math.MaxUint64 + } + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS("z2") + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(10) + }, 15*time.Second, time.Second) + s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.Require().Equal(uint64(100), s.store.GetMinSafeTS("z1")) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z2")) }