From 5b804aa6c1393207e582b60baa847f243dac0775 Mon Sep 17 00:00:00 2001 From: zyguan Date: Wed, 20 Mar 2024 17:38:54 +0800 Subject: [PATCH] tikvrpc: avoid data race on `XxxRequest.Context` Signed-off-by: zyguan --- tikvrpc/cmds_generated.go | 384 ++++++++++++++++++++++++++++++++++++++ tikvrpc/gen.sh | 85 +++++++++ tikvrpc/tikvrpc.go | 111 +++-------- tikvrpc/tikvrpc_test.go | 90 +++++++++ 4 files changed, 584 insertions(+), 86 deletions(-) create mode 100644 tikvrpc/cmds_generated.go create mode 100644 tikvrpc/gen.sh diff --git a/tikvrpc/cmds_generated.go b/tikvrpc/cmds_generated.go new file mode 100644 index 0000000000..147f9cc9c5 --- /dev/null +++ b/tikvrpc/cmds_generated.go @@ -0,0 +1,384 @@ +// Code generated gen.sh. DO NOT EDIT. + +package tikvrpc + +import ( + "github.com/pingcap/kvproto/pkg/kvrpcpb" +) + +func patchCmdCtx(req *Request, cmd CmdType, ctx *kvrpcpb.Context) bool { + switch cmd { + case CmdGet: + if req.rev == 0 { + req.Get().Context = ctx + } else { + cmd := *req.Get() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdScan: + if req.rev == 0 { + req.Scan().Context = ctx + } else { + cmd := *req.Scan() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPrewrite: + if req.rev == 0 { + req.Prewrite().Context = ctx + } else { + cmd := *req.Prewrite() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPessimisticLock: + if req.rev == 0 { + req.PessimisticLock().Context = ctx + } else { + cmd := *req.PessimisticLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPessimisticRollback: + if req.rev == 0 { + req.PessimisticRollback().Context = ctx + } else { + cmd := *req.PessimisticRollback() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCommit: + if req.rev == 0 { + req.Commit().Context = ctx + } else { + cmd := *req.Commit() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCleanup: + if req.rev == 0 { + req.Cleanup().Context = ctx + } else { + cmd := *req.Cleanup() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchGet: + if req.rev == 0 { + req.BatchGet().Context = ctx + } else { + cmd := *req.BatchGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchRollback: + if req.rev == 0 { + req.BatchRollback().Context = ctx + } else { + cmd := *req.BatchRollback() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdScanLock: + if req.rev == 0 { + req.ScanLock().Context = ctx + } else { + cmd := *req.ScanLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdResolveLock: + if req.rev == 0 { + req.ResolveLock().Context = ctx + } else { + cmd := *req.ResolveLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdGC: + if req.rev == 0 { + req.GC().Context = ctx + } else { + cmd := *req.GC() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdDeleteRange: + if req.rev == 0 { + req.DeleteRange().Context = ctx + } else { + cmd := *req.DeleteRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawGet: + if req.rev == 0 { + req.RawGet().Context = ctx + } else { + cmd := *req.RawGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchGet: + if req.rev == 0 { + req.RawBatchGet().Context = ctx + } else { + cmd := *req.RawBatchGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawPut: + if req.rev == 0 { + req.RawPut().Context = ctx + } else { + cmd := *req.RawPut() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchPut: + if req.rev == 0 { + req.RawBatchPut().Context = ctx + } else { + cmd := *req.RawBatchPut() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawDelete: + if req.rev == 0 { + req.RawDelete().Context = ctx + } else { + cmd := *req.RawDelete() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchDelete: + if req.rev == 0 { + req.RawBatchDelete().Context = ctx + } else { + cmd := *req.RawBatchDelete() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawDeleteRange: + if req.rev == 0 { + req.RawDeleteRange().Context = ctx + } else { + cmd := *req.RawDeleteRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawScan: + if req.rev == 0 { + req.RawScan().Context = ctx + } else { + cmd := *req.RawScan() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawGetKeyTTL: + if req.rev == 0 { + req.RawGetKeyTTL().Context = ctx + } else { + cmd := *req.RawGetKeyTTL() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawCompareAndSwap: + if req.rev == 0 { + req.RawCompareAndSwap().Context = ctx + } else { + cmd := *req.RawCompareAndSwap() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawChecksum: + if req.rev == 0 { + req.RawChecksum().Context = ctx + } else { + cmd := *req.RawChecksum() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdUnsafeDestroyRange: + if req.rev == 0 { + req.UnsafeDestroyRange().Context = ctx + } else { + cmd := *req.UnsafeDestroyRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRegisterLockObserver: + if req.rev == 0 { + req.RegisterLockObserver().Context = ctx + } else { + cmd := *req.RegisterLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckLockObserver: + if req.rev == 0 { + req.CheckLockObserver().Context = ctx + } else { + cmd := *req.CheckLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRemoveLockObserver: + if req.rev == 0 { + req.RemoveLockObserver().Context = ctx + } else { + cmd := *req.RemoveLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPhysicalScanLock: + if req.rev == 0 { + req.PhysicalScanLock().Context = ctx + } else { + cmd := *req.PhysicalScanLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCop: + if req.rev == 0 { + req.Cop().Context = ctx + } else { + cmd := *req.Cop() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchCop: + if req.rev == 0 { + req.BatchCop().Context = ctx + } else { + cmd := *req.BatchCop() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdMvccGetByKey: + if req.rev == 0 { + req.MvccGetByKey().Context = ctx + } else { + cmd := *req.MvccGetByKey() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdMvccGetByStartTs: + if req.rev == 0 { + req.MvccGetByStartTs().Context = ctx + } else { + cmd := *req.MvccGetByStartTs() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdSplitRegion: + if req.rev == 0 { + req.SplitRegion().Context = ctx + } else { + cmd := *req.SplitRegion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdTxnHeartBeat: + if req.rev == 0 { + req.TxnHeartBeat().Context = ctx + } else { + cmd := *req.TxnHeartBeat() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckTxnStatus: + if req.rev == 0 { + req.CheckTxnStatus().Context = ctx + } else { + cmd := *req.CheckTxnStatus() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckSecondaryLocks: + if req.rev == 0 { + req.CheckSecondaryLocks().Context = ctx + } else { + cmd := *req.CheckSecondaryLocks() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdFlashbackToVersion: + if req.rev == 0 { + req.FlashbackToVersion().Context = ctx + } else { + cmd := *req.FlashbackToVersion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPrepareFlashbackToVersion: + if req.rev == 0 { + req.PrepareFlashbackToVersion().Context = ctx + } else { + cmd := *req.PrepareFlashbackToVersion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdFlush: + if req.rev == 0 { + req.Flush().Context = ctx + } else { + cmd := *req.Flush() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBufferBatchGet: + if req.rev == 0 { + req.BufferBatchGet().Context = ctx + } else { + cmd := *req.BufferBatchGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + default: + return false + } + return true +} diff --git a/tikvrpc/gen.sh b/tikvrpc/gen.sh new file mode 100644 index 0000000000..ce5da51d5d --- /dev/null +++ b/tikvrpc/gen.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +output="cmds_generated.go" + +cat < $output +// Code generated gen.sh. DO NOT EDIT. + +package tikvrpc + +import ( + "github.com/pingcap/kvproto/pkg/kvrpcpb" +) +EOF + +cmds=( + Get + Scan + Prewrite + PessimisticLock + PessimisticRollback + Commit + Cleanup + BatchGet + BatchRollback + ScanLock + ResolveLock + GC + DeleteRange + RawGet + RawBatchGet + RawPut + RawBatchPut + RawDelete + RawBatchDelete + RawDeleteRange + RawScan + RawGetKeyTTL + RawCompareAndSwap + RawChecksum + UnsafeDestroyRange + RegisterLockObserver + CheckLockObserver + RemoveLockObserver + PhysicalScanLock + Cop + BatchCop + MvccGetByKey + MvccGetByStartTs + SplitRegion + TxnHeartBeat + CheckTxnStatus + CheckSecondaryLocks + FlashbackToVersion + PrepareFlashbackToVersion + Flush + BufferBatchGet +) + +cat <> $output + +func patchCmdCtx(req *Request, cmd CmdType, ctx *kvrpcpb.Context) bool { + switch cmd { +EOF + +for cmd in "${cmds[@]}"; do +cat <> $output + case Cmd${cmd}: + if req.rev == 0 { + req.${cmd}().Context = ctx + } else { + cmd := *req.${cmd}() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ +EOF +done + +cat <> $output + default: + return false + } + return true +} +EOF diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go index ddc97d153e..173e1b077e 100644 --- a/tikvrpc/tikvrpc.go +++ b/tikvrpc/tikvrpc.go @@ -85,7 +85,7 @@ const ( CmdRawBatchDelete CmdRawDeleteRange CmdRawScan - CmdGetKeyTTL + CmdRawGetKeyTTL CmdRawCompareAndSwap CmdRawChecksum @@ -118,6 +118,11 @@ const ( CmdEmpty CmdType = 3072 + iota ) +// CmdType aliases. +const ( + CmdGetKeyTTL = CmdRawGetKeyTTL +) + func (t CmdType) String() string { switch t { case CmdGet: @@ -164,6 +169,10 @@ func (t CmdType) String() string { return "RawScan" case CmdRawChecksum: return "RawChecksum" + case CmdRawGetKeyTTL: + return "RawGetKeyTTL" + case CmdRawCompareAndSwap: + return "RawCompareAndSwap" case CmdUnsafeDestroyRange: return "UnsafeDestroyRange" case CmdRegisterLockObserver: @@ -244,6 +253,9 @@ type Request struct { ReadType string // InputRequestSource is the input source of the request, if it's not empty, the final RequestSource sent to store will be attached with the retry info. InputRequestSource string + + // rev represents the revision of the request, it's increased when `Req.Context` gets patched. + rev uint32 } // NewRequest returns new kv rpc request. @@ -731,104 +743,31 @@ type MPPStreamResponse struct { Lease } +//go:generate bash gen.sh + // AttachContext sets the request context to the request, // return false if encounter unknown request type. // Parameter `rpcCtx` use `kvrpcpb.Context` instead of `*kvrpcpb.Context` to avoid concurrent modification by shallow copy. func AttachContext(req *Request, rpcCtx kvrpcpb.Context) bool { ctx := &rpcCtx + cmd := req.Type + // CmdCopStream and CmdCop share the same request type. + if cmd == CmdCopStream { + cmd = CmdCop + } + if patchCmdCtx(req, cmd, ctx) { + return true + } switch req.Type { - case CmdGet: - req.Get().Context = ctx - case CmdScan: - req.Scan().Context = ctx - case CmdPrewrite: - req.Prewrite().Context = ctx - case CmdPessimisticLock: - req.PessimisticLock().Context = ctx - case CmdPessimisticRollback: - req.PessimisticRollback().Context = ctx - case CmdCommit: - req.Commit().Context = ctx - case CmdCleanup: - req.Cleanup().Context = ctx - case CmdBatchGet: - req.BatchGet().Context = ctx - case CmdBatchRollback: - req.BatchRollback().Context = ctx - case CmdScanLock: - req.ScanLock().Context = ctx - case CmdResolveLock: - req.ResolveLock().Context = ctx - case CmdGC: - req.GC().Context = ctx - case CmdDeleteRange: - req.DeleteRange().Context = ctx - case CmdRawGet: - req.RawGet().Context = ctx - case CmdRawBatchGet: - req.RawBatchGet().Context = ctx - case CmdRawPut: - req.RawPut().Context = ctx - case CmdRawBatchPut: - req.RawBatchPut().Context = ctx - case CmdRawDelete: - req.RawDelete().Context = ctx - case CmdRawBatchDelete: - req.RawBatchDelete().Context = ctx - case CmdRawDeleteRange: - req.RawDeleteRange().Context = ctx - case CmdRawScan: - req.RawScan().Context = ctx - case CmdGetKeyTTL: - req.RawGetKeyTTL().Context = ctx - case CmdRawCompareAndSwap: - req.RawCompareAndSwap().Context = ctx - case CmdRawChecksum: - req.RawChecksum().Context = ctx - case CmdUnsafeDestroyRange: - req.UnsafeDestroyRange().Context = ctx - case CmdRegisterLockObserver: - req.RegisterLockObserver().Context = ctx - case CmdCheckLockObserver: - req.CheckLockObserver().Context = ctx - case CmdRemoveLockObserver: - req.RemoveLockObserver().Context = ctx - case CmdPhysicalScanLock: - req.PhysicalScanLock().Context = ctx - case CmdCop: - req.Cop().Context = ctx - case CmdCopStream: - req.Cop().Context = ctx - case CmdBatchCop: - req.BatchCop().Context = ctx // Dispatching MPP tasks don't need a region context, because it's a request for store but not region. case CmdMPPTask: case CmdMPPConn: case CmdMPPCancel: case CmdMPPAlive: - case CmdMvccGetByKey: - req.MvccGetByKey().Context = ctx - case CmdMvccGetByStartTs: - req.MvccGetByStartTs().Context = ctx - case CmdSplitRegion: - req.SplitRegion().Context = ctx + // Empty command doesn't need a region context. case CmdEmpty: - req.SplitRegion().Context = ctx - case CmdTxnHeartBeat: - req.TxnHeartBeat().Context = ctx - case CmdCheckTxnStatus: - req.CheckTxnStatus().Context = ctx - case CmdCheckSecondaryLocks: - req.CheckSecondaryLocks().Context = ctx - case CmdFlashbackToVersion: - req.FlashbackToVersion().Context = ctx - case CmdPrepareFlashbackToVersion: - req.PrepareFlashbackToVersion().Context = ctx - case CmdFlush: - req.Flush().Context = ctx - case CmdBufferBatchGet: - req.BufferBatchGet().Context = ctx + default: return false } diff --git a/tikvrpc/tikvrpc_test.go b/tikvrpc/tikvrpc_test.go index e3d5e25fb3..5a301e09c9 100644 --- a/tikvrpc/tikvrpc_test.go +++ b/tikvrpc/tikvrpc_test.go @@ -35,8 +35,14 @@ package tikvrpc import ( + "fmt" + "math/rand" + "sync" "testing" + "time" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/tikvpb" "github.com/stretchr/testify/assert" ) @@ -47,3 +53,87 @@ func TestBatchResponse(t *testing.T) { assert.Nil(t, batchResp) assert.NotNil(t, err) } + +// https://github.com/pingcap/tidb/issues/51921 +func TestTiDB51921(t *testing.T) { + for _, r := range []*Request{ + NewRequest(CmdGet, &kvrpcpb.GetRequest{}), + NewRequest(CmdScan, &kvrpcpb.ScanRequest{}), + NewRequest(CmdPrewrite, &kvrpcpb.PrewriteRequest{}), + NewRequest(CmdPessimisticLock, &kvrpcpb.PessimisticLockRequest{}), + NewRequest(CmdPessimisticRollback, &kvrpcpb.PessimisticRollbackRequest{}), + NewRequest(CmdCommit, &kvrpcpb.CommitRequest{}), + NewRequest(CmdCleanup, &kvrpcpb.CleanupRequest{}), + NewRequest(CmdBatchGet, &kvrpcpb.BatchGetRequest{}), + NewRequest(CmdBatchRollback, &kvrpcpb.BatchRollbackRequest{}), + NewRequest(CmdScanLock, &kvrpcpb.ScanLockRequest{}), + NewRequest(CmdResolveLock, &kvrpcpb.ResolveLockRequest{}), + NewRequest(CmdGC, &kvrpcpb.GCRequest{}), + NewRequest(CmdDeleteRange, &kvrpcpb.DeleteRangeRequest{}), + NewRequest(CmdRawGet, &kvrpcpb.RawGetRequest{}), + NewRequest(CmdRawBatchGet, &kvrpcpb.RawBatchGetRequest{}), + NewRequest(CmdRawPut, &kvrpcpb.RawPutRequest{}), + NewRequest(CmdRawBatchPut, &kvrpcpb.RawBatchPutRequest{}), + NewRequest(CmdRawDelete, &kvrpcpb.RawDeleteRequest{}), + NewRequest(CmdRawBatchDelete, &kvrpcpb.RawBatchDeleteRequest{}), + NewRequest(CmdRawDeleteRange, &kvrpcpb.RawDeleteRangeRequest{}), + NewRequest(CmdRawScan, &kvrpcpb.RawScanRequest{}), + NewRequest(CmdRawGetKeyTTL, &kvrpcpb.RawGetKeyTTLRequest{}), + NewRequest(CmdRawCompareAndSwap, &kvrpcpb.RawCASRequest{}), + NewRequest(CmdRawChecksum, &kvrpcpb.RawChecksumRequest{}), + NewRequest(CmdUnsafeDestroyRange, &kvrpcpb.UnsafeDestroyRangeRequest{}), + NewRequest(CmdRegisterLockObserver, &kvrpcpb.RegisterLockObserverRequest{}), + NewRequest(CmdCheckLockObserver, &kvrpcpb.CheckLockObserverRequest{}), + NewRequest(CmdRemoveLockObserver, &kvrpcpb.RemoveLockObserverRequest{}), + NewRequest(CmdPhysicalScanLock, &kvrpcpb.PhysicalScanLockRequest{}), + NewRequest(CmdCop, &coprocessor.Request{}), + NewRequest(CmdCopStream, &coprocessor.Request{}), + NewRequest(CmdBatchCop, &coprocessor.BatchRequest{}), + NewRequest(CmdMvccGetByKey, &kvrpcpb.MvccGetByKeyRequest{}), + NewRequest(CmdMvccGetByStartTs, &kvrpcpb.MvccGetByStartTsRequest{}), + NewRequest(CmdSplitRegion, &kvrpcpb.SplitRegionRequest{}), + NewRequest(CmdTxnHeartBeat, &kvrpcpb.TxnHeartBeatRequest{}), + NewRequest(CmdCheckTxnStatus, &kvrpcpb.CheckTxnStatusRequest{}), + NewRequest(CmdCheckSecondaryLocks, &kvrpcpb.CheckSecondaryLocksRequest{}), + NewRequest(CmdFlashbackToVersion, &kvrpcpb.FlashbackToVersionRequest{}), + NewRequest(CmdPrepareFlashbackToVersion, &kvrpcpb.PrepareFlashbackToVersionRequest{}), + NewRequest(CmdFlush, &kvrpcpb.FlushRequest{}), + NewRequest(CmdBufferBatchGet, &kvrpcpb.BufferBatchGetRequest{}), + } { + req := r + t.Run(fmt.Sprintf("%s#%d", req.Type.String(), req.Type), func(t *testing.T) { + if req.ToBatchCommandsRequest() == nil { + t.Skipf("%s doesn't support batch commands", req.Type.String()) + } + done := make(chan struct{}) + cmds := make(chan *tikvpb.BatchCommandsRequest_Request, 8) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-done: + close(cmds) + return + default: + // mock relocate and retry + AttachContext(req, kvrpcpb.Context{RegionId: rand.Uint64()}) + cmds <- req.ToBatchCommandsRequest() + } + } + }() + go func() { + defer wg.Done() + for cmd := range cmds { + // mock send and marshal in batch-send-loop + cmd.Marshal() + } + }() + + time.Sleep(time.Second / 4) + close(done) + wg.Wait() + }) + } +}