Skip to content

Commit

Permalink
tikv: ensure safe-ts won't be max uint64 (#1250)
Browse files Browse the repository at this point in the history
* tikv: ensure safe-ts won't be max uint64

Signed-off-by: zyguan <[email protected]>

* fix a typo

Signed-off-by: zyguan <[email protected]>

* fix lint issue

Signed-off-by: zyguan <[email protected]>

* address the comment and fix test

Signed-off-by: zyguan <[email protected]>

---------

Signed-off-by: zyguan <[email protected]>
  • Loading branch information
zyguan authored Mar 28, 2024
1 parent 603dc7b commit 81d8dea
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 34 deletions.
5 changes: 2 additions & 3 deletions integration_tests/pd_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package tikv_test
import (
"context"
"fmt"
"math"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -183,8 +182,8 @@ func (s *apiTestSuite) TestInitClusterMinResolvedTSZero() {
// Try to get the minimum resolved timestamp of the cluster from TiKV.
require.NoError(failpoint.Enable("tikvclient/InjectPDMinResolvedTS", `return(0)`))
// Make sure the store's min resolved ts is not initialized.
s.waitForMinSafeTS(oracle.GlobalTxnScope, math.MaxUint64)
require.Equal(uint64(math.MaxUint64), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
s.waitForMinSafeTS(oracle.GlobalTxnScope, 0)
require.Equal(uint64(0), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
require.NoError(failpoint.Disable("tikvclient/InjectPDMinResolvedTS"))

// Try to get the minimum resolved timestamp of the cluster from PD.
Expand Down
35 changes: 27 additions & 8 deletions tikv/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,15 @@ func (s *KVStore) GetMinSafeTS(txnScope string) uint64 {
return 0
}

func (s *KVStore) setMinSafeTS(txnScope string, safeTS uint64) {
// ensure safeTS is not set to max uint64
if safeTS == math.MaxUint64 {
logutil.BgLogger().Warn("skip setting min-safe-ts to max uint64", zap.String("txnScope", txnScope), zap.Stack("stack"))
return
}
s.minSafeTS.Store(txnScope, safeTS)
}

// Ctx returns ctx.
func (s *KVStore) Ctx() context.Context {
return s.ctx
Expand Down Expand Up @@ -607,26 +616,32 @@ func (s *KVStore) getSafeTS(storeID uint64) (bool, uint64) {

// setSafeTS sets safeTs for store storeID, export for testing
func (s *KVStore) setSafeTS(storeID, safeTS uint64) {
// ensure safeTS is not set to max uint64
if safeTS == math.MaxUint64 {
logutil.BgLogger().Warn("skip setting safe-ts to max uint64", zap.Uint64("storeID", storeID), zap.Stack("stack"))
return
}
s.safeTSMap.Store(storeID, safeTS)
}

func (s *KVStore) updateMinSafeTS(txnScope string, storeIDs []uint64) {
minSafeTS := uint64(math.MaxUint64)
// when there is no store, return 0 in order to let minStartTS become startTS directly
if len(storeIDs) < 1 {
s.minSafeTS.Store(txnScope, 0)
s.setMinSafeTS(txnScope, 0)
}
for _, store := range storeIDs {
ok, safeTS := s.getSafeTS(store)
if ok {
// safeTS is guaranteed to be less than math.MaxUint64 (by setSafeTS and its callers)
if safeTS != 0 && safeTS < minSafeTS {
minSafeTS = safeTS
}
} else {
minSafeTS = 0
}
}
s.minSafeTS.Store(txnScope, minSafeTS)
s.setMinSafeTS(txnScope, minSafeTS)
}

func (s *KVStore) safeTSUpdater() {
Expand Down Expand Up @@ -690,8 +705,8 @@ func (s *KVStore) updateSafeTS(ctx context.Context) {
safeTS uint64
storeIDStr = strconv.FormatUint(storeID, 10)
)
// If getting the minimum resolved timestamp from PD failed or returned 0, try to get it from TiKV.
if storeMinResolvedTSs == nil || storeMinResolvedTSs[storeID] == 0 || err != nil {
// If getting the minimum resolved timestamp from PD failed or returned 0/MaxUint64, try to get it from TiKV.
if storeMinResolvedTSs == nil || !isValidSafeTS(storeMinResolvedTSs[storeID]) || err != nil {
resp, err := tikvClient.SendRequest(
ctx, storeAddr, tikvrpc.NewRequest(
tikvrpc.CmdStoreSafeTS, &kvrpcpb.StoreSafeTSRequest{
Expand Down Expand Up @@ -785,17 +800,17 @@ func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool {
clusterMinSafeTS, _, err := s.getMinResolvedTSByStoresIDs(ctx, nil)
if err != nil {
logutil.BgLogger().Debug("get resolved TS from PD failed", zap.Error(err))
} else if clusterMinSafeTS != 0 {
} else if isValidSafeTS(clusterMinSafeTS) {
// Update ts and metrics.
preClusterMinSafeTS := s.GetMinSafeTS(oracle.GlobalTxnScope)
// If preClusterMinSafeTS is maxUint64, it means that the min safe ts has not been initialized.
// preClusterMinSafeTS is guaranteed to be less than math.MaxUint64 (by this method and setMinSafeTS)
// related to https://github.com/tikv/client-go/issues/991
if preClusterMinSafeTS != math.MaxUint64 && preClusterMinSafeTS > clusterMinSafeTS {
if preClusterMinSafeTS > clusterMinSafeTS {
skipSafeTSUpdateCounter.Inc()
preSafeTSTime := oracle.GetTimeFromTS(preClusterMinSafeTS)
clusterMinSafeTSGap.Set(time.Since(preSafeTSTime).Seconds())
} else {
s.minSafeTS.Store(oracle.GlobalTxnScope, clusterMinSafeTS)
s.setMinSafeTS(oracle.GlobalTxnScope, clusterMinSafeTS)
successSafeTSUpdateCounter.Inc()
safeTSTime := oracle.GetTimeFromTS(clusterMinSafeTS)
clusterMinSafeTSGap.Set(time.Since(safeTSTime).Seconds())
Expand All @@ -807,6 +822,10 @@ func (s *KVStore) updateGlobalTxnScopeTSFromPD(ctx context.Context) bool {
return false
}

func isValidSafeTS(ts uint64) bool {
return ts != 0 && ts != math.MaxUint64
}

// EnableResourceControl enables the resource control.
func EnableResourceControl() {
client.ResourceControlSwitch.Store(true)
Expand Down
160 changes: 137 additions & 23 deletions tikv/kv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package tikv
import (
"context"
"fmt"
"math"
"sync/atomic"
"testing"
"time"
Expand All @@ -29,6 +30,7 @@ import (
"github.com/tikv/client-go/v2/testutils"
"github.com/tikv/client-go/v2/tikvrpc"
"github.com/tikv/client-go/v2/util"
pdhttp "github.com/tikv/pd/client/http"
)

func TestKV(t *testing.T) {
Expand All @@ -38,18 +40,27 @@ func TestKV(t *testing.T) {

type testKVSuite struct {
suite.Suite
store *KVStore
cluster *mocktikv.Cluster
tikvStoreID uint64
tiflashStoreID uint64
tiflashPeerStoreID uint64
store *KVStore
cluster *mocktikv.Cluster
tikvStoreID uint64
tiflashStoreID uint64

mockGetMinResolvedTSByStoresIDs atomic.Pointer[func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)]
}

func (s *testKVSuite) SetupTest() {
client, cluster, pdClient, err := testutils.NewMockTiKV("", nil)
s.Require().Nil(err)
testutils.BootstrapWithSingleStore(cluster)
store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0)
s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) {
return 0, nil, nil
})
store, err := NewTestTiKVStore(client, pdClient, nil, nil, 0, Option(func(store *KVStore) {
store.pdHttpClient = &mockPDHTTPClient{
Client: pdhttp.NewClientWithServiceDiscovery("test", nil),
mockGetMinResolvedTSByStoresIDs: &s.mockGetMinResolvedTSByStoresIDs,
}
}))
s.Require().Nil(err)

s.store = store
Expand All @@ -58,14 +69,18 @@ func (s *testKVSuite) SetupTest() {
storeIDs, _, _, _ := mocktikv.BootstrapWithMultiStores(s.cluster, 2)
s.tikvStoreID = storeIDs[0]
s.tiflashStoreID = storeIDs[1]
tiflashPeerAddrID := cluster.AllocIDs(1)
s.tiflashPeerStoreID = tiflashPeerAddrID[0]

s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashPeerStoreID), &metapb.StoreLabel{Key: "engine", Value: "tiflash"})
s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, nil)
var labels []*metapb.StoreLabel
labels = append(labels, &metapb.StoreLabel{Key: "engine", Value: "tiflash"})
s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashPeerStoreID), tikvrpc.TiFlash, 1, labels)
labels = append(cluster.GetStore(s.tikvStoreID).Labels,
&metapb.StoreLabel{Key: DCLabelKey, Value: "z1"})
s.cluster.UpdateStorePeerAddr(s.tikvStoreID, s.storeAddr(s.tikvStoreID), labels...)
s.store.regionCache.SetRegionCacheStore(s.tikvStoreID, s.storeAddr(s.tikvStoreID), s.storeAddr(s.tikvStoreID), tikvrpc.TiKV, 1, labels)

labels = append(cluster.GetStore(s.tiflashStoreID).Labels,
&metapb.StoreLabel{Key: DCLabelKey, Value: "z2"},
&metapb.StoreLabel{Key: "engine", Value: "tiflash"})
s.cluster.UpdateStorePeerAddr(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), labels...)
s.store.regionCache.SetRegionCacheStore(s.tiflashStoreID, s.storeAddr(s.tiflashStoreID), s.storeAddr(s.tiflashStoreID), tikvrpc.TiFlash, 1, labels)

}

Expand All @@ -77,6 +92,10 @@ func (s *testKVSuite) storeAddr(id uint64) string {
return fmt.Sprintf("store%d", id)
}

func (s *testKVSuite) setGetMinResolvedTSByStoresIDs(f func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)) {
s.mockGetMinResolvedTSByStoresIDs.Store(&f)
}

type storeSafeTsMockClient struct {
Client
requestCount int32
Expand All @@ -89,7 +108,7 @@ func (c *storeSafeTsMockClient) SendRequest(ctx context.Context, addr string, re
}
atomic.AddInt32(&c.requestCount, 1)
resp := &tikvrpc.Response{}
if addr == c.testSuite.storeAddr(c.testSuite.tiflashPeerStoreID) {
if addr == c.testSuite.storeAddr(c.testSuite.tiflashStoreID) {
resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 80}
} else {
resp.Resp = &kvrpcpb.StoreSafeTSResponse{SafeTs: 100}
Expand All @@ -105,22 +124,117 @@ func (c *storeSafeTsMockClient) CloseAddr(addr string) error {
return c.Client.CloseAddr(addr)
}

func (s *testKVSuite) TestMinSafeTs() {
type mockPDHTTPClient struct {
pdhttp.Client
mockGetMinResolvedTSByStoresIDs *atomic.Pointer[func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error)]
}

func (c *mockPDHTTPClient) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) {
if f := c.mockGetMinResolvedTSByStoresIDs.Load(); f != nil {
return (*f)(ctx, storeIDs)
}
return c.Client.GetMinResolvedTSByStoresIDs(ctx, storeIDs)
}

func (s *testKVSuite) TestMinSafeTsFromStores() {
mockClient := storeSafeTsMockClient{
Client: s.store.GetTiKVClient(),
testSuite: s,
}
s.store.SetTiKVClient(&mockClient)

// wait for updateMinSafeTS
var retryCount int
for s.store.GetMinSafeTS(oracle.GlobalTxnScope) != 80 {
time.Sleep(2 * time.Second)
if retryCount > 5 {
break
}
retryCount++
}
s.Eventually(func() bool {
ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope)
s.Require().False(math.MaxUint64 == ts)
return ts == 80
}, 15*time.Second, time.Second)
s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(2))
s.Require().Equal(uint64(80), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
ok, ts := s.store.getSafeTS(s.tikvStoreID)
s.Require().True(ok)
s.Require().Equal(uint64(100), ts)
}

func (s *testKVSuite) TestMinSafeTsFromPD() {
mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s}
s.store.SetTiKVClient(&mockClient)
s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) {
return 90, nil, nil
})
s.Eventually(func() bool {
ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope)
s.Require().False(math.MaxUint64 == ts)
return ts == 90
}, 15*time.Second, time.Second)
s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0))
s.Require().Equal(uint64(90), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
}

func (s *testKVSuite) TestMinSafeTsFromPDByStores() {
mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s}
s.store.SetTiKVClient(&mockClient)
s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) {
m := make(map[uint64]uint64)
for _, id := range ids {
m[id] = uint64(100) + id
}
return math.MaxUint64, m, nil
})
s.Eventually(func() bool {
ts := s.store.GetMinSafeTS(oracle.GlobalTxnScope)
s.Require().False(math.MaxUint64 == ts)
return ts == uint64(100)+s.tikvStoreID
}, 15*time.Second, time.Second)
s.Require().Equal(atomic.LoadInt32(&mockClient.requestCount), int32(0))
s.Require().Equal(uint64(100)+s.tikvStoreID, s.store.GetMinSafeTS(oracle.GlobalTxnScope))
}

func (s *testKVSuite) TestMinSafeTsFromMixed1() {
mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s}
s.store.SetTiKVClient(&mockClient)
s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) {
m := make(map[uint64]uint64)
for _, id := range ids {
if id == s.tiflashStoreID {
m[id] = 0
} else {
m[id] = uint64(10)
}
}
return math.MaxUint64, m, nil
})
s.Eventually(func() bool {
ts := s.store.GetMinSafeTS("z1")
s.Require().False(math.MaxUint64 == ts)
return ts == uint64(10)
}, 15*time.Second, time.Second)
s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1))
s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z1"))
s.Require().Equal(uint64(80), s.store.GetMinSafeTS("z2"))
}

func (s *testKVSuite) TestMinSafeTsFromMixed2() {
mockClient := storeSafeTsMockClient{Client: s.store.GetTiKVClient(), testSuite: s}
s.store.SetTiKVClient(&mockClient)
s.setGetMinResolvedTSByStoresIDs(func(ctx context.Context, ids []uint64) (uint64, map[uint64]uint64, error) {
m := make(map[uint64]uint64)
for _, id := range ids {
if id == s.tiflashStoreID {
m[id] = uint64(10)
} else {
m[id] = math.MaxUint64
}
}
return math.MaxUint64, m, nil
})
s.Eventually(func() bool {
ts := s.store.GetMinSafeTS("z2")
s.Require().False(math.MaxUint64 == ts)
return ts == uint64(10)
}, 15*time.Second, time.Second)
s.Require().GreaterOrEqual(atomic.LoadInt32(&mockClient.requestCount), int32(1))
s.Require().Equal(uint64(10), s.store.GetMinSafeTS(oracle.GlobalTxnScope))
s.Require().Equal(uint64(100), s.store.GetMinSafeTS("z1"))
s.Require().Equal(uint64(10), s.store.GetMinSafeTS("z2"))
}

0 comments on commit 81d8dea

Please sign in to comment.