Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to validate read ts for all RPC requests (#1513) #1546

Merged
merged 4 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/locate/region_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import (
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/oracle"
pd "github.com/tikv/pd/client"
)

Expand Down Expand Up @@ -1004,7 +1005,7 @@ func (s *testRegionCacheSuite) TestRegionEpochOnTiFlash() {
s.Equal(ctxTiFlash.Peer.Id, s.peer1)
ctxTiFlash.Peer.Role = metapb.PeerRole_Learner
r := ctxTiFlash.Meta
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
regionErr := &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{CurrentRegions: []*metapb.Region{r}}}
reqSend.onRegionError(s.bo, ctxTiFlash, nil, regionErr)

Expand Down Expand Up @@ -1640,7 +1641,7 @@ func (s *testRegionCacheSuite) TestShouldNotRetryFlashback() {
ctx, err := s.cache.GetTiKVRPCContext(retry.NewBackofferWithVars(context.Background(), 100, nil), loc.Region, kv.ReplicaReadLeader, 0)
s.NotNil(ctx)
s.NoError(err)
reqSend := NewRegionRequestSender(s.cache, nil)
reqSend := NewRegionRequestSender(s.cache, nil, oracle.NoopReadTSValidator{})
shouldRetry, err := reqSend.onRegionError(s.bo, ctx, nil, &errorpb.Error{FlashbackInProgress: &errorpb.FlashbackInProgress{}})
s.Error(err)
s.False(shouldRetry)
Expand Down
54 changes: 50 additions & 4 deletions internal/locate/region_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
"sync/atomic"
"time"

"github.com/tikv/client-go/v2/oracle"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -105,6 +106,7 @@ type RegionRequestSender struct {
regionCache *RegionCache
apiVersion kvrpcpb.APIVersion
client client.Client
readTSValidator oracle.ReadTSValidator
storeAddr string
rpcError error
replicaSelector *replicaSelector
Expand Down Expand Up @@ -193,11 +195,12 @@ func RecordRegionRequestRuntimeStats(stats map[tikvrpc.CmdType]*RPCRuntimeStats,
}

// NewRegionRequestSender creates a new sender.
func NewRegionRequestSender(regionCache *RegionCache, client client.Client) *RegionRequestSender {
func NewRegionRequestSender(regionCache *RegionCache, client client.Client, readTSValidator oracle.ReadTSValidator) *RegionRequestSender {
return &RegionRequestSender{
regionCache: regionCache,
apiVersion: regionCache.codec.GetAPIVersion(),
client: client,
regionCache: regionCache,
apiVersion: regionCache.codec.GetAPIVersion(),
client: client,
readTSValidator: readTSValidator,
}
}

Expand Down Expand Up @@ -1261,6 +1264,11 @@ func (s *RegionRequestSender) SendReqCtx(
}
}

if err = s.validateReadTS(bo.GetCtx(), req); err != nil {
logutil.Logger(bo.GetCtx()).Error("validate read ts failed for request", zap.Stringer("reqType", req.Type), zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("context", &req.Context), zap.Stack("stack"), zap.Error(err))
return nil, nil, 0, err
}

// If the MaxExecutionDurationMs is not set yet, we set it to be the RPC timeout duration
// so TiKV can give up the requests whose response TiDB cannot receive due to timeout.
if req.Context.MaxExecutionDurationMs == 0 {
Expand Down Expand Up @@ -2179,6 +2187,44 @@ func (s *RegionRequestSender) onRegionError(
return false, nil
}

func (s *RegionRequestSender) validateReadTS(ctx context.Context, req *tikvrpc.Request) error {
if req.StoreTp == tikvrpc.TiDB {
// Skip the checking if the store type is TiDB.
return nil
}

var readTS uint64
switch req.Type {
case tikvrpc.CmdGet, tikvrpc.CmdScan, tikvrpc.CmdBatchGet, tikvrpc.CmdCop, tikvrpc.CmdCopStream, tikvrpc.CmdBatchCop, tikvrpc.CmdScanLock:
readTS = req.GetStartTS()

// TODO: Check transactional write requests that has implicit read.
// case tikvrpc.CmdPessimisticLock:
// readTS = req.PessimisticLock().GetForUpdateTs()
// case tikvrpc.CmdPrewrite:
// inner := req.Prewrite()
// readTS = inner.GetForUpdateTs()
// if readTS == 0 {
// readTS = inner.GetStartVersion()
// }
// case tikvrpc.CmdCheckTxnStatus:
// inner := req.CheckTxnStatus()
// // TiKV uses the greater one of these three fields to update the max_ts.
// readTS = inner.GetLockTs()
// if inner.GetCurrentTs() != math.MaxUint64 && inner.GetCurrentTs() > readTS {
// readTS = inner.GetCurrentTs()
// }
// if inner.GetCallerStartTs() != math.MaxUint64 && inner.GetCallerStartTs() > readTS {
// readTS = inner.GetCallerStartTs()
// }
// case tikvrpc.CmdCheckSecondaryLocks, tikvrpc.CmdCleanup, tikvrpc.CmdBatchRollback:
// readTS = req.GetStartTS()
default:
return nil
}
return s.readTSValidator.ValidateReadTS(ctx, readTS, req.StaleRead, &oracle.Option{TxnScope: req.TxnScope})
}

type staleReadMetricsCollector struct {
}

Expand Down
8 changes: 6 additions & 2 deletions internal/locate/region_request3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"time"
"unsafe"

"github.com/pingcap/failpoint"
"github.com/pingcap/kvproto/pkg/errorpb"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/kvproto/pkg/metapb"
Expand Down Expand Up @@ -82,7 +83,9 @@ func (s *testRegionRequestToThreeStoresSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})

s.NoError(failpoint.Enable("tikvclient/doNotRecoverStoreHealthCheckPanic", "return"))
}

func (s *testRegionRequestToThreeStoresSuite) TearDownTest() {
Expand Down Expand Up @@ -147,7 +150,8 @@ func (s *testRegionRequestToThreeStoresSuite) loadAndGetLeaderStore() (*Store, s
}

func (s *testRegionRequestToThreeStoresSuite) TestForwarding() {
s.regionRequestSender.regionCache.enableForwarding = true
sender := NewRegionRequestSender(s.cache, s.regionRequestSender.client, oracle.NoopReadTSValidator{})
sender.regionCache.enableForwarding = true

// First get the leader's addr from region cache
leaderStore, leaderAddr := s.loadAndGetLeaderStore()
Expand Down
3 changes: 2 additions & 1 deletion internal/locate/region_request_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/kv"
"github.com/tikv/client-go/v2/metrics"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
)

Expand Down Expand Up @@ -76,7 +77,7 @@ func (s *testRegionCacheStaleReadSuite) SetupTest() {
s.cache = NewRegionCache(pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})
s.setClient()
s.injection = testRegionCacheFSMSuiteInjection{
unavailableStoreIDs: make(map[uint64]struct{}),
Expand Down
15 changes: 9 additions & 6 deletions internal/locate/region_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ import (
"github.com/tikv/client-go/v2/internal/client/mock_server"
"github.com/tikv/client-go/v2/internal/mockstore/mocktikv"
"github.com/tikv/client-go/v2/internal/retry"
"github.com/tikv/client-go/v2/oracle"
"github.com/tikv/client-go/v2/tikvrpc"
pd "github.com/tikv/pd/client"
pderr "github.com/tikv/pd/client/errs"
"google.golang.org/grpc"
)
Expand All @@ -75,6 +77,7 @@ type testRegionRequestToSingleStoreSuite struct {
store uint64
peer uint64
region uint64
pdCli pd.Client
cache *RegionCache
bo *retry.Backoffer
regionRequestSender *RegionRequestSender
Expand All @@ -85,11 +88,11 @@ func (s *testRegionRequestToSingleStoreSuite) SetupTest() {
s.mvccStore = mocktikv.MustNewMVCCStore()
s.cluster = mocktikv.NewCluster(s.mvccStore)
s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster)
pdCli := &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)}
s.cache = NewRegionCache(pdCli)
s.pdCli = &CodecPDClient{mocktikv.NewPDClient(s.cluster), apicodec.NewCodecV1(apicodec.ModeTxn)}
s.cache = NewRegionCache(s.pdCli)
s.bo = retry.NewNoopBackoff(context.Background())
client := mocktikv.NewRPCClient(s.cluster, s.mvccStore, nil)
s.regionRequestSender = NewRegionRequestSender(s.cache, client)
s.regionRequestSender = NewRegionRequestSender(s.cache, client, oracle.NoopReadTSValidator{})
}

func (s *testRegionRequestToSingleStoreSuite) TearDownTest() {
Expand Down Expand Up @@ -567,7 +570,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
}()

cli := client.NewRPCClient()
sender := NewRegionRequestSender(s.cache, cli)
sender := NewRegionRequestSender(s.cache, cli, oracle.NoopReadTSValidator{})
req := tikvrpc.NewRequest(tikvrpc.CmdRawPut, &kvrpcpb.RawPutRequest{
Key: []byte("key"),
Value: []byte("value"),
Expand All @@ -586,7 +589,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestNoReloadRegionForGrpcWhenCtxCa
Client: client.NewRPCClient(),
redirectAddr: addr,
}
sender = NewRegionRequestSender(s.cache, client1)
sender = NewRegionRequestSender(s.cache, client1, oracle.NoopReadTSValidator{})
sender.SendReq(s.bo, req, region.Region, 3*time.Second)

// cleanup
Expand Down Expand Up @@ -772,7 +775,7 @@ func (s *testRegionRequestToSingleStoreSuite) TestBatchClientSendLoopPanic() {
cancel()
}()
req := tikvrpc.NewRequest(tikvrpc.CmdCop, &coprocessor.Request{Data: []byte("a"), StartTs: 1})
regionRequestSender := NewRegionRequestSender(s.cache, fnClient)
regionRequestSender := NewRegionRequestSender(s.cache, fnClient, oracle.NoopReadTSValidator{})
regionRequestSender.regionCache.testingKnobs.mockRequestLiveness.Store((*livenessFunc)(&tf))
regionRequestSender.SendReq(bo, req, region.Region, client.ReadTimeoutShort)
}
Expand Down
38 changes: 34 additions & 4 deletions oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracle

import (
"context"
"fmt"
"time"
)

Expand Down Expand Up @@ -64,12 +65,17 @@ type Oracle interface {
GetExternalTimestamp(ctx context.Context) (uint64, error)
SetExternalTimestamp(ctx context.Context, ts uint64) error

// ValidateSnapshotReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform snapshot read, stale read,
// etc.
ReadTSValidator
}

// ReadTSValidator is the interface for providing the ability for verifying whether a timestamp is safe to be used
// for readings, as part of the `Oracle` interface.
type ReadTSValidator interface {
// ValidateReadTS verifies whether it can be guaranteed that the given readTS doesn't exceed the maximum ts
// that has been allocated by the oracle, so that it's safe to use this ts to perform read operations.
// Note that this method only checks the ts from the oracle's perspective. It doesn't check whether the snapshot
// has been GCed.
ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *Option) error
ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error
}

// Future is a future which promises to return a timestamp.
Expand Down Expand Up @@ -121,3 +127,27 @@ func GoTimeToTS(t time.Time) uint64 {
func GoTimeToLowerLimitStartTS(now time.Time, maxTxnTimeUse int64) uint64 {
return GoTimeToTS(now.Add(-time.Duration(maxTxnTimeUse) * time.Millisecond))
}

// NoopReadTSValidator is a dummy implementation of ReadTSValidator that always let the validation pass.
// Only use this when using RPCs that are not related to ts (e.g. rawkv), or in tests where `Oracle` is not available
// and the validation is not necessary.
type NoopReadTSValidator struct{}

func (NoopReadTSValidator) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *Option) error {
return nil
}

type ErrFutureTSRead struct {
ReadTS uint64
CurrentTS uint64
}

func (e ErrFutureTSRead) Error() string {
return fmt.Sprintf("cannot set read timestamp to a future time, readTS: %d, currentTS: %d", e.ReadTS, e.CurrentTS)
}

type ErrLatestStaleRead struct{}

func (ErrLatestStaleRead) Error() string {
return "cannot set read ts to max uint64 for stale read"
}
2 changes: 1 addition & 1 deletion oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@ func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{})
lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO])
lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts})
lastTSPointer.Store(&lastTSO{tso: ts, arrival: oracle.GetTimeFromTS(ts)})
}
}
15 changes: 13 additions & 2 deletions oracle/oracles/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracles

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -136,13 +137,23 @@ func (l *localOracle) GetExternalTimestamp(ctx context.Context) (uint64, error)
return l.getExternalTimestamp(ctx)
}

func (l *localOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
func (l *localOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := l.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return errors.Errorf("cannot set read timestamp to a future time")
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
return nil
}
19 changes: 17 additions & 2 deletions oracle/oracles/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package oracles

import (
"context"
"math"
"sync"
"time"

Expand Down Expand Up @@ -122,13 +123,27 @@ func (o *MockOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *or
return o.GetTimestampAsync(ctx, opt)
}

func (o *MockOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
func (o *MockOracle) SetLowResolutionTimestampUpdateInterval(time.Duration) error {
return nil
}

func (o *MockOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
if readTS == math.MaxUint64 {
if isStaleRead {
return oracle.ErrLatestStaleRead{}
}
return nil
}

currentTS, err := o.GetTimestamp(ctx, opt)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if currentTS < readTS {
return errors.Errorf("cannot set read timestamp to a future time")
return oracle.ErrFutureTSRead{
ReadTS: readTS,
CurrentTS: currentTS,
}
}
return nil
}
Expand Down
Loading
Loading