diff --git a/integration_tests/pd_api_test.go b/integration_tests/pd_api_test.go index a990c9cad7..349cb3847a 100644 --- a/integration_tests/pd_api_test.go +++ b/integration_tests/pd_api_test.go @@ -206,7 +206,7 @@ func (s *apiTestSuite) TestInitClusterMinResolvedTSZero() { retryCount++ } // Make sure the store's min resolved ts is not initialized. - require.Equal(uint64(math.MaxUint64), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + require.Equal(uint64(0), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) require.NoError(failpoint.Disable("tikvclient/InjectMinResolvedTS")) // Try to get the minimum resolved timestamp of the cluster from PD. diff --git a/internal/logutil/log.go b/internal/logutil/log.go index 0eb789ddb1..234f3408a5 100644 --- a/internal/logutil/log.go +++ b/internal/logutil/log.go @@ -36,6 +36,7 @@ package logutil import ( "context" + "testing" "github.com/pingcap/log" "go.uber.org/zap" @@ -60,3 +61,11 @@ type ctxLogKeyType struct{} // CtxLogKey is the key to retrieve logger from context. // It can be assigned to another value. var CtxLogKey interface{} = ctxLogKeyType{} + +// AssertWarn panics when in testing mode, and logs a warning msg otherwise. +func AssertWarn(logger *zap.Logger, msg string, fields ...zap.Field) { + if testing.Testing() { + logger.Panic(msg, fields...) + } + logger.Warn(msg, fields...) +} diff --git a/tikv/kv.go b/tikv/kv.go index a60cbb4c78..089011fb5d 100644 --- a/tikv/kv.go +++ b/tikv/kv.go @@ -138,6 +138,21 @@ type KVStore struct { wg sync.WaitGroup close atomicutil.Bool gP Pool + + testingKnobs struct { + mockGetMinResolvedTSByStoresIDs atomic.Pointer[func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error)] + } +} + +func (s *KVStore) setGetMinResolvedTSByStoresIDs(f func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error)) { + s.testingKnobs.mockGetMinResolvedTSByStoresIDs.Store(&f) +} + +func (s *KVStore) getMinResolvedTSByStoresIDs(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + if f := s.testingKnobs.mockGetMinResolvedTSByStoresIDs.Load(); f != nil { + return (*f)(ctx, ids) + } + return s.pdHttpClient.GetMinResolvedTSByStoresIDs(ctx, ids) } // Go run the function in a separate goroutine. @@ -497,6 +512,15 @@ func (s *KVStore) GetMinSafeTS(txnScope string) uint64 { return 0 } +func (s *KVStore) setMinSafeTS(txnScope string, safeTS uint64) { + // ensure safeTS is not set to max uint64 + if safeTS == math.MaxUint64 { + logutil.AssertWarn(logutil.BgLogger(), "skip setting min-safe-ts to max uint64", zap.String("txnScope", txnScope), zap.Stack("stack")) + return + } + s.minSafeTS.Store(txnScope, safeTS) +} + // Ctx returns ctx. func (s *KVStore) Ctx() context.Context { return s.ctx @@ -532,14 +556,22 @@ func (s *KVStore) getSafeTS(storeID uint64) (bool, uint64) { // setSafeTS sets safeTs for store storeID, export for testing func (s *KVStore) setSafeTS(storeID, safeTS uint64) { + // ensure safeTS is not set to max uint64 + if safeTS == math.MaxUint64 { + logutil.AssertWarn(logutil.BgLogger(), "skip setting safe-ts to max uint64", zap.Uint64("storeID", storeID), zap.Stack("stack")) + return + } s.safeTSMap.Store(storeID, safeTS) } func (s *KVStore) updateMinSafeTS(txnScope string, storeIDs []uint64) { minSafeTS := uint64(math.MaxUint64) // when there is no store, return 0 in order to let minStartTS become startTS directly + // actually storeIDs won't be empty since updateMinSafeTS is only called by updateSafeTS and updateSafeTS builds + // txnScopeMap with non-empty values. here we check it to make the logic more robust. if len(storeIDs) < 1 { - s.minSafeTS.Store(txnScope, 0) + s.setMinSafeTS(txnScope, 0) + return } for _, store := range storeIDs { ok, safeTS := s.getSafeTS(store) @@ -551,7 +583,11 @@ func (s *KVStore) updateMinSafeTS(txnScope string, storeIDs []uint64) { minSafeTS = 0 } } - s.minSafeTS.Store(txnScope, minSafeTS) + // if minSafeTS is still math.MaxUint64, that means all store safe ts are 0, then we set minSafeTS to 0. + if minSafeTS == math.MaxUint64 { + minSafeTS = 0 + } + s.setMinSafeTS(txnScope, minSafeTS) } func (s *KVStore) safeTSUpdater() { @@ -591,11 +627,11 @@ func (s *KVStore) updateSafeTS(ctx context.Context) { storeMinResolvedTSs map[uint64]uint64 ) storeIDs := make([]string, len(stores)) - if s.pdHttpClient != nil { + if s.pdHttpClient != nil || s.testingKnobs.mockGetMinResolvedTSByStoresIDs.Load() != nil { for i, store := range stores { storeIDs[i] = strconv.FormatUint(store.StoreID(), 10) } - _, storeMinResolvedTSs, err = s.pdHttpClient.GetMinResolvedTSByStoresIDs(ctx, storeIDs) + _, storeMinResolvedTSs, err = s.getMinResolvedTSByStoresIDs(ctx, storeIDs) if err != nil { // If getting the minimum resolved timestamp from PD failed, log the error and need to get it from TiKV. logutil.BgLogger().Debug("get resolved TS from PD failed", zap.Error(err), zap.Any("stores", storeIDs)) @@ -612,8 +648,8 @@ func (s *KVStore) updateSafeTS(ctx context.Context) { defer wg.Done() var safeTS uint64 - // If getting the minimum resolved timestamp from PD failed or returned 0, try to get it from TiKV. - if storeMinResolvedTSs == nil || storeMinResolvedTSs[storeID] == 0 || err != nil { + // If getting the minimum resolved timestamp from PD failed or returned 0/MaxUint64, try to get it from TiKV. + if storeMinResolvedTSs == nil || !isValidSafeTS(storeMinResolvedTSs[storeID]) || err != nil { resp, err := tikvClient.SendRequest( ctx, storeAddr, tikvrpc.NewRequest( tikvrpc.CmdStoreSafeTS, &kvrpcpb.StoreSafeTSRequest{ @@ -675,21 +711,21 @@ var ( func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool { isGlobal := config.GetTxnScopeFromConfig() == oracle.GlobalTxnScope // Try to get the minimum resolved timestamp of the cluster from PD. - if s.pdHttpClient != nil && isGlobal { - clusterMinSafeTS, _, err := s.pdHttpClient.GetMinResolvedTSByStoresIDs(ctx, nil) + if (s.pdHttpClient != nil || s.testingKnobs.mockGetMinResolvedTSByStoresIDs.Load() != nil) && isGlobal { + clusterMinSafeTS, _, err := s.getMinResolvedTSByStoresIDs(ctx, nil) if err != nil { logutil.BgLogger().Debug("get resolved TS from PD failed", zap.Error(err)) - } else if clusterMinSafeTS != 0 { + } else if isValidSafeTS(clusterMinSafeTS) { // Update ts and metrics. preClusterMinSafeTS := s.GetMinSafeTS(oracle.GlobalTxnScope) - // If preClusterMinSafeTS is maxUint64, it means that the min safe ts has not been initialized. + // preClusterMinSafeTS is guaranteed to be less than math.MaxUint64 (by this method and setMinSafeTS) // related to https://github.com/tikv/client-go/issues/991 - if preClusterMinSafeTS != math.MaxUint64 && preClusterMinSafeTS > clusterMinSafeTS { + if preClusterMinSafeTS > clusterMinSafeTS { skipSafeTSUpdateCounter.Inc() preSafeTSTime := oracle.GetTimeFromTS(preClusterMinSafeTS) clusterMinSafeTSGap.Set(time.Since(preSafeTSTime).Seconds()) } else { - s.minSafeTS.Store(oracle.GlobalTxnScope, clusterMinSafeTS) + s.setMinSafeTS(oracle.GlobalTxnScope, clusterMinSafeTS) successSafeTSUpdateCounter.Inc() safeTSTime := oracle.GetTimeFromTS(clusterMinSafeTS) clusterMinSafeTSGap.Set(time.Since(safeTSTime).Seconds()) @@ -701,6 +737,10 @@ func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool { return false } +func isValidSafeTS(ts uint64) bool { + return ts != 0 && ts != math.MaxUint64 +} + // EnableResourceControl enables the resource control. func EnableResourceControl() { client.ResourceControlSwitch.Store(true) diff --git a/tikv/kv_test.go b/tikv/kv_test.go index 9f9af85006..54603dbd80 100644 --- a/tikv/kv_test.go +++ b/tikv/kv_test.go @@ -17,6 +17,8 @@ package tikv import ( "context" "fmt" + "math" + "strconv" "sync/atomic" "testing" "time" @@ -38,11 +40,10 @@ func TestKV(t *testing.T) { type testKVSuite struct { suite.Suite - store *KVStore - cluster *mocktikv.Cluster - tikvStoreID uint64 - tiflashStoreID uint64 - tiflashPeerStoreID uint64 + store *KVStore + cluster *mocktikv.Cluster + tikvStoreID uint64 + tiflashStoreID uint64 } func (s *testKVSuite) SetupTest() { @@ -51,6 +52,9 @@ func (s *testKVSuite) SetupTest() { testutils.BootstrapWithSingleStore(cluster) store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0) s.Require().Nil(err) + store.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + return 0, nil, nil + }) s.store = store s.cluster = cluster @@ -58,14 +62,18 @@ func (s *testKVSuite) SetupTest() { storeIDs, _, _, _ := mocktikv.BootstrapWithMultiStores(s.cluster, 2) s.tikvStoreID = storeIDs[0] s.tiflashStoreID = storeIDs[1] - tiflashPeerAddrID := cluster.AllocIDs(1) - s.tiflashPeerStoreID = tiflashPeerAddrID[0] - s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashPeerStoreID), &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) - s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, nil) var labels []*metapb.StoreLabel - labels = append(labels, &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) - s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashPeerStoreID), tikvrpc.TiFlash, 1, labels) + labels = append(cluster.GetStore(s.tikvStoreID).Labels, + &metapb.StoreLabel{Key: DCLabelKey, Value: "z1"}) + s.cluster.UpdateStorePeerAddr(s.tikvStoreID, s.storeAddr(s.tikvStoreID), labels...) + s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, labels) + + labels = append(cluster.GetStore(s.tiflashStoreID).Labels, + &metapb.StoreLabel{Key: DCLabelKey, Value: "z2"}, + &metapb.StoreLabel{Key: "engine", Value: "tiflash"}) + s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), labels...) + s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashStoreID), tikvrpc.TiFlash, 1, labels) } @@ -81,6 +89,18 @@ type storeSafeTsMockClient struct { Client requestCount int32 testSuite *testKVSuite + + tikvSafeTs uint64 + tiflashSafeTs uint64 +} + +func newStoreSafeTsMockClient(s *testKVSuite) *storeSafeTsMockClient { + return &storeSafeTsMockClient{ + Client: s.store.GetTiKVClient(), + testSuite: s, + tikvSafeTs: 100, + tiflashSafeTs: 80, + } } func (c *storeSafeTsMockClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { @@ -89,10 +109,10 @@ func (c *storeSafeTsMockClient) SendRequest(ctx context.Context, addr string, re } atomic.AddInt32(&c.requestCount, 1) resp := &tikvrpc.Response{} - if addr == c.testSuite.storeAddr(c.testSuite.tiflashPeerStoreID) { - resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 80} + if addr == c.testSuite.storeAddr(c.testSuite.tiflashStoreID) { + resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: c.tiflashSafeTs} } else { - resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 100} + resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: c.tikvSafeTs} } return resp, nil } @@ -105,22 +125,132 @@ func (c *storeSafeTsMockClient) CloseAddr(addr string) error { return c.Client.CloseAddr(addr) } -func (s *testKVSuite) TestMinSafeTs() { - mockClient := storeSafeTsMockClient{ - Client: s.store.GetTiKVClient(), - testSuite: s, - } - s.store.SetTiKVClient(&mockClient) - - // wait for updateMinSafeTS - var retryCount int - for s.store.GetMinSafeTS(oracle.GlobalTxnScope) != 80 { - time.Sleep(2 * time.Second) - if retryCount > 5 { - break - } - retryCount++ - } +func (s *testKVSuite) TestMinSafeTsFromStores() { + mockClient := newStoreSafeTsMockClient(s) + s.store.SetTiKVClient(mockClient) + + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == mockClient.tiflashSafeTs + }, 15*time.Second, time.Second) s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(2)) - s.Require().Equal(uint64(80), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.Require().Equal(mockClient.tiflashSafeTs, s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + ok, ts := s.store.getSafeTS(s.tikvStoreID) + s.Require().True(ok) + s.Require().Equal(mockClient.tikvSafeTs, ts) +} + +func (s *testKVSuite) TestMinSafeTsFromStoresWithAllZeros() { + // ref https://github.com/tikv/client-go/issues/1276 + mockClient := newStoreSafeTsMockClient(s) + mockClient.tikvSafeTs = 0 + mockClient.tiflashSafeTs = 0 + s.store.SetTiKVClient(mockClient) + + s.Eventually(func() bool { + return atomic.LoadInt32(&mockClient.requestCount) >= 4 + }, 15*time.Second, time.Second) + + s.Require().Equal(uint64(0), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromStoresWithSomeZeros() { + // ref https://github.com/tikv/tikv/issues/13675 & https://github.com/tikv/client-go/pull/615 + mockClient := newStoreSafeTsMockClient(s) + mockClient.tiflashSafeTs = 0 + s.store.SetTiKVClient(mockClient) + + s.Eventually(func() bool { + return atomic.LoadInt32(&mockClient.requestCount) >= 4 + }, 15*time.Second, time.Second) + + s.Require().Equal(mockClient.tikvSafeTs, s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromPD() { + mockClient := newStoreSafeTsMockClient(s) + s.store.SetTiKVClient(mockClient) + s.store.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + return 90, nil, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == 90 + }, 15*time.Second, time.Second) + s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0)) + s.Require().Equal(uint64(90), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromPDByStores() { + mockClient := newStoreSafeTsMockClient(s) + s.store.SetTiKVClient(mockClient) + s.store.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + k, _ := strconv.ParseUint(id, 10, 64) + m[k] = uint64(100) + k + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope) + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(100)+s.tikvStoreID + }, 15*time.Second, time.Second) + s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0)) + s.Require().Equal(uint64(100)+s.tikvStoreID, s.store.GetMinSafeTS(oracle.GlobalTxnScope)) +} + +func (s *testKVSuite) TestMinSafeTsFromMixed1() { + mockClient := newStoreSafeTsMockClient(s) + s.store.SetTiKVClient(mockClient) + s.store.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + k, _ := strconv.ParseUint(id, 10, 64) + if k == s.tiflashStoreID { + m[k] = 0 + } else { + m[k] = uint64(10) + } + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS("z1") + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(10) && s.store.GetMinSafeTS(oracle.GlobalTxnScope) == uint64(10) + }, 15*time.Second, time.Second) + s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z1")) + s.Require().Equal(mockClient.tiflashSafeTs, s.store.GetMinSafeTS("z2")) +} + +func (s *testKVSuite) TestMinSafeTsFromMixed2() { + mockClient := newStoreSafeTsMockClient(s) + s.store.SetTiKVClient(mockClient) + s.store.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []string) (uint64, map[uint64]uint64, error) { + m := make(map[uint64]uint64) + for _, id := range ids { + k, _ := strconv.ParseUint(id, 10, 64) + if k == s.tiflashStoreID { + m[k] = uint64(10) + } else { + m[k] = math.MaxUint64 + } + } + return math.MaxUint64, m, nil + }) + s.Eventually(func() bool { + ts := s.store.GetMinSafeTS("z2") + s.Require().False(math.MaxUint64 == ts) + return ts == uint64(10) && s.store.GetMinSafeTS(oracle.GlobalTxnScope) == uint64(10) + }, 15*time.Second, time.Second) + s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1)) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope)) + s.Require().Equal(mockClient.tikvSafeTs, s.store.GetMinSafeTS("z1")) + s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z2")) }