diff --git a/pkg/ddl/column_change_test.go b/pkg/ddl/column_change_test.go index 36462cb063c9a..728cd0f32d5d8 100644 --- a/pkg/ddl/column_change_test.go +++ b/pkg/ddl/column_change_test.go @@ -35,7 +35,6 @@ import ( "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/testkit/external" "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/mock" "github.com/stretchr/testify/require" ) @@ -51,7 +50,7 @@ func TestColumnAdd(t *testing.T) { d := dom.DDL() tc := &callback.TestDDLCallback{Do: dom} - ct := testNewContext(store) + ct := testNewContext(t, store) // set up hook var ( deleteOnlyTable table.Table @@ -127,7 +126,7 @@ func TestColumnAdd(t *testing.T) { return } first = false - sess := testNewContext(store) + sess := testNewContext(t, store) err := sessiontxn.NewTxn(context.Background(), sess) require.NoError(t, err) _, err = writeOnlyTable.AddRecord(sess, types.MakeDatums(10, 10)) @@ -431,10 +430,8 @@ func testCheckJobDone(t *testing.T, store kv.Storage, jobID int64, isAdd bool) { } } -func testNewContext(store kv.Storage) sessionctx.Context { - ctx := mock.NewContext() - ctx.Store = store - return ctx +func testNewContext(t *testing.T, store kv.Storage) sessionctx.Context { + return testkit.NewSession(t, store) } func TestIssue40135(t *testing.T) { diff --git a/pkg/ddl/column_test.go b/pkg/ddl/column_test.go index 0f83e803a2995..275c9a0035284 100644 --- a/pkg/ddl/column_test.go +++ b/pkg/ddl/column_test.go @@ -167,7 +167,7 @@ func TestColumnBasic(t *testing.T) { tk.MustExec(fmt.Sprintf("insert into t1 values(%d, %d, %d)", i, 10*i, 100*i)) } - ctx := testNewContext(store) + ctx := testNewContext(t, store) err := sessiontxn.NewTxn(context.Background(), ctx) require.NoError(t, err) @@ -611,7 +611,7 @@ func checkPublicColumn(t *testing.T, ctx sessionctx.Context, tableID int64, newC } func checkAddColumn(t *testing.T, state model.SchemaState, tableID int64, handle kv.Handle, newCol *table.Column, oldRow []types.Datum, columnValue interface{}, dom *domain.Domain, store kv.Storage, columnCnt int) { - ctx := testNewContext(store) + ctx := testNewContext(t, store) switch state { case model.StateNone: checkNoneColumn(t, ctx, tableID, handle, newCol, columnValue, dom) @@ -655,7 +655,7 @@ func TestAddColumn(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) - ctx := testNewContext(store) + ctx := testNewContext(t, store) err := sessiontxn.NewTxn(context.Background(), ctx) require.NoError(t, err) oldRow := types.MakeDatums(int64(1), int64(2), int64(3)) @@ -728,7 +728,7 @@ func TestAddColumns(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) - ctx := testNewContext(store) + ctx := testNewContext(t, store) err := sessiontxn.NewTxn(context.Background(), ctx) require.NoError(t, err) oldRow := types.MakeDatums(int64(1), int64(2), int64(3)) @@ -791,7 +791,7 @@ func TestDropColumnInColumnTest(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) - ctx := testNewContext(store) + ctx := testNewContext(t, store) colName := "c4" defaultColValue := int64(4) row := types.MakeDatums(int64(1), int64(2), int64(3)) @@ -852,7 +852,7 @@ func TestDropColumns(t *testing.T) { tableID = int64(tableIDi) tbl := testGetTable(t, dom, tableID) - ctx := testNewContext(store) + ctx := testNewContext(t, store) err := sessiontxn.NewTxn(context.Background(), ctx) require.NoError(t, err) diff --git a/pkg/ddl/db_integration_test.go b/pkg/ddl/db_integration_test.go index 1900b0b17208e..40133641406b2 100644 --- a/pkg/ddl/db_integration_test.go +++ b/pkg/ddl/db_integration_test.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "fmt" + "github.com/pingcap/tidb/pkg/planner/core" "math" "strconv" "strings" @@ -41,7 +42,6 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/planner/core" "github.com/pingcap/tidb/pkg/session" "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" @@ -1610,6 +1610,62 @@ func TestDefaultColumnWithRand(t *testing.T) { tk.MustGetErrCode("CREATE TABLE t3 (c int, c1 int default a_function_not_supported_yet());", errno.ErrDefValGeneratedNamedFunctionIsNotAllowed) } +// TestDefaultValueAsExpressions is used for tests that are inconvenient to place in the pkg/tests directory. +func TestDefaultValueAsExpressions(t *testing.T) { + store := testkit.CreateMockStoreWithSchemaLease(t, testLease) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t, t1, t2") + + // date_format + tk.MustExec("create table t6 (c int(10), c1 int default (date_format(now(),'%Y-%m-%d %H:%i:%s')))") + tk.MustExec("create table t7 (c int(10), c1 date default (date_format(now(),'%Y-%m')))") + // Error message like: Error 1292 (22007): Truncated incorrect DOUBLE value: '2024-03-05 16:37:25'. + tk.MustGetErrCode("insert into t6(c) values (1)", errno.ErrTruncatedWrongValue) + tk.MustGetErrCode("insert into t7(c) values (1)", errno.ErrTruncatedWrongValue) + + // user + tk.MustExec("create table t (c int(10), c1 varchar(256) default (upper(substring_index(user(),'@',1))));") + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "root", Hostname: "localhost"} + tk.MustExec("insert into t(c) values (1),(2),(3)") + tk.Session().GetSessionVars().User = &auth.UserIdentity{Username: "xyz", Hostname: "localhost"} + tk.MustExec("insert into t(c) values (4),(5),(6)") + tk.MustExec("insert into t values (7, default)") + rows := tk.MustQuery("SELECT c1 from t order by c").Rows() + for i, row := range rows { + d, ok := row[0].(string) + require.True(t, ok) + if i < 3 { + require.Equal(t, "ROOT", d) + } else { + require.Equal(t, "XYZ", d) + } + } + + // replace + tk.MustExec("create table t1 (c int(10), c1 int default (REPLACE(UPPER(UUID()), '-', '')))") + // Different UUID values will result in different error code. + _, err := tk.Exec("insert into t1(c) values (1)") + originErr := errors.Cause(err) + tErr, ok := originErr.(*terror.Error) + require.Truef(t, ok, "expect type 'terror.Error', but obtain '%T': %v", originErr, originErr) + sqlErr := terror.ToSQLError(tErr) + if int(sqlErr.Code) != errno.ErrTruncatedWrongValue { + require.Equal(t, errno.ErrDataOutOfRange, int(sqlErr.Code)) + } + // test modify column + // The error message has UUID, so put this test here. + tk.MustExec("create table t2(c int(10), c1 varchar(256) default (REPLACE(UPPER(UUID()), '-', '')), index idx(c1));") + tk.MustExec("insert into t2(c) values (1),(2),(3);") + tk.MustGetErrCode("alter table t2 modify column c1 varchar(30) default 'xx';", errno.WarnDataTruncated) + // test add column for enum + nowStr := time.Now().Format("2006-01") + sql := fmt.Sprintf("alter table t2 add column c3 enum('%v','n')", nowStr) + " default (date_format(now(),'%Y-%m'))" + tk.MustExec(sql) + tk.MustExec("insert into t2(c) values (4);") + tk.MustQuery("select c3 from t2").Check(testkit.Rows(nowStr, nowStr, nowStr, nowStr)) +} + func TestChangingDBCharset(t *testing.T) { store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) @@ -1821,8 +1877,6 @@ func TestParserIssue284(t *testing.T) { func TestAddExpressionIndex(t *testing.T) { config.UpdateGlobal(func(conf *config.Config) { - // Test for table lock. - conf.EnableTableLock = true conf.Instance.SlowThreshold = 10000 conf.TiKVClient.AsyncCommit.SafeWindow = 0 conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 @@ -1901,60 +1955,6 @@ func TestAddExpressionIndex(t *testing.T) { }) } -func TestCreateExpressionIndexError(t *testing.T) { - config.UpdateGlobal(func(conf *config.Config) { - // Test for table lock. - conf.EnableTableLock = true - conf.Instance.SlowThreshold = 10000 - conf.TiKVClient.AsyncCommit.SafeWindow = 0 - conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 - conf.Experimental.AllowsExpressionIndex = true - }) - store := testkit.CreateMockStore(t) - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("drop table if exists t;") - tk.MustExec("create table t (a int, b real);") - tk.MustGetErrCode("alter table t add primary key ((a+b)) nonclustered;", errno.ErrFunctionalIndexPrimaryKey) - - tk.MustGetErrCode("create table t(a int, index((cast(a as JSON))))", errno.ErrFunctionalIndexOnJSONOrGeometryFunction) - - // Test for error - tk.MustExec("drop table if exists t;") - tk.MustExec("create table t (a int, b real);") - tk.MustGetErrCode("alter table t add primary key ((a+b)) nonclustered;", errno.ErrFunctionalIndexPrimaryKey) - tk.MustGetErrCode("alter table t add index ((rand()));", errno.ErrFunctionalIndexFunctionIsNotAllowed) - tk.MustGetErrCode("alter table t add index ((now()+1));", errno.ErrFunctionalIndexFunctionIsNotAllowed) - - tk.MustExec("alter table t add column (_V$_idx_0 int);") - tk.MustGetErrCode("alter table t add index idx((a+1));", errno.ErrDupFieldName) - tk.MustExec("alter table t drop column _V$_idx_0;") - tk.MustExec("alter table t add index idx((a+1));") - tk.MustGetErrCode("alter table t add column (_V$_idx_0 int);", errno.ErrDupFieldName) - tk.MustExec("alter table t drop index idx;") - tk.MustExec("alter table t add column (_V$_idx_0 int);") - - tk.MustExec("alter table t add column (_V$_expression_index_0 int);") - tk.MustGetErrCode("alter table t add index ((a+1));", errno.ErrDupFieldName) - tk.MustExec("alter table t drop column _V$_expression_index_0;") - tk.MustExec("alter table t add index ((a+1));") - tk.MustGetErrCode("alter table t drop column _V$_expression_index_0;", errno.ErrCantDropFieldOrKey) - tk.MustGetErrCode("alter table t add column e int as (_V$_expression_index_0 + 1);", errno.ErrBadField) - - // NOTE (#18150): In creating expression index, row value is not allowed. - tk.MustExec("drop table if exists t;") - tk.MustGetErrCode("create table t (j json, key k (((j,j))))", errno.ErrFunctionalIndexRowValueIsNotAllowed) - tk.MustExec("create table t (j json, key k ((j+1),(j+1)))") - - tk.MustGetErrCode("create table t1 (col1 int, index ((concat(''))));", errno.ErrWrongKeyColumnFunctionalIndex) - tk.MustGetErrCode("CREATE TABLE t1 (col1 INT, PRIMARY KEY ((ABS(col1))) NONCLUSTERED);", errno.ErrFunctionalIndexPrimaryKey) - - // For issue 26349 - tk.MustExec("drop table if exists t;") - tk.MustExec("create table t(id char(10) primary key, short_name char(10), name char(10), key n((upper(`name`))));") - tk.MustExec("update t t1 set t1.short_name='a' where t1.id='1';") -} - func queryIndexOnTable(dbName, tableName string) string { return fmt.Sprintf("select distinct index_name, is_visible from information_schema.statistics where table_schema = '%s' and table_name = '%s' order by index_name", dbName, tableName) } @@ -2353,20 +2353,6 @@ func TestEnumAndSetDefaultValue(t *testing.T) { require.Equal(t, "a", tbl.Meta().Columns[1].DefaultValue) } -func TestStrictDoubleTypeCheck(t *testing.T) { - store := testkit.CreateMockStore(t) - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("set @@tidb_enable_strict_double_type_check = 'ON'") - sql := "create table double_type_check(id int, c double(10));" - _, err := tk.Exec(sql) - require.Error(t, err) - require.Equal(t, "[parser:1149]You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use", err.Error()) - tk.MustExec("set @@tidb_enable_strict_double_type_check = 'OFF'") - defer tk.MustExec("set @@tidb_enable_strict_double_type_check = 'ON'") - tk.MustExec(sql) -} - func TestDuplicateErrorMessage(t *testing.T) { defer collate.SetNewCollationEnabledForTest(true) store := testkit.CreateMockStore(t) @@ -2388,10 +2374,7 @@ func TestDuplicateErrorMessage(t *testing.T) { for _, newCollate := range []bool{false, true} { collate.SetNewCollationEnabledForTest(newCollate) for _, globalIndex := range []bool{false, true} { - restoreConfig := config.RestoreFunc() - config.UpdateGlobal(func(conf *config.Config) { - conf.EnableGlobalIndex = globalIndex - }) + tk.MustExec(fmt.Sprintf("set tidb_enable_global_index=%t", globalIndex)) for _, clusteredIndex := range []variable.ClusteredIndexDefMode{variable.ClusteredIndexDefModeOn, variable.ClusteredIndexDefModeOff, variable.ClusteredIndexDefModeIntOnly} { tk.Session().GetSessionVars().EnableClusteredIndex = clusteredIndex for _, t := range tests { @@ -2418,7 +2401,7 @@ func TestDuplicateErrorMessage(t *testing.T) { fmt.Sprintf("[kv:1062]Duplicate entry '1-%s' for key 't.t_idx'", strings.Join(fields, "-"))) } } - restoreConfig() + tk.MustExec("set tidb_enable_global_index=default") } } } @@ -2673,8 +2656,6 @@ func TestAvoidCreateViewOnLocalTemporaryTable(t *testing.T) { func TestDropTemporaryTable(t *testing.T) { config.UpdateGlobal(func(conf *config.Config) { - // Test for table lock. - conf.EnableTableLock = true conf.Instance.SlowThreshold = 10000 conf.TiKVClient.AsyncCommit.SafeWindow = 0 conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0 @@ -2940,42 +2921,6 @@ func TestIssue29282(t *testing.T) { } } -// See https://github.com/pingcap/tidb/issues/35644 -func TestCreateTempTableInTxn(t *testing.T) { - store := testkit.CreateMockStore(t) - tk := testkit.NewTestKit(t, store) - tk.MustExec("use test") - tk.MustExec("begin") - // new created temporary table should be visible - tk.MustExec("create temporary table t1(id int primary key, v int)") - tk.MustQuery("select * from t1").Check(testkit.Rows()) - // new inserted data should be visible - tk.MustExec("insert into t1 values(123, 456)") - tk.MustQuery("select * from t1 where id=123").Check(testkit.Rows("123 456")) - // truncate table will clear data but table still visible - tk.MustExec("truncate table t1") - tk.MustQuery("select * from t1 where id=123").Check(testkit.Rows()) - tk.MustExec("commit") - - tk1 := testkit.NewTestKit(t, store) - tk1.MustExec("use test") - tk1.MustExec("create table tt(id int)") - tk1.MustExec("begin") - tk1.MustExec("create temporary table t1(id int)") - tk1.MustExec("insert into tt select * from t1") - tk1.MustExec("drop table tt") - - tk2 := testkit.NewTestKit(t, store) - tk2.MustExec("use test") - tk2.MustExec("create table t2(id int primary key, v int)") - tk2.MustExec("insert into t2 values(234, 567)") - tk2.MustExec("begin") - // create a new temporary table with the same name will override physical table - tk2.MustExec("create temporary table t2(id int primary key, v int)") - tk2.MustQuery("select * from t2 where id=234").Check(testkit.Rows()) - tk2.MustExec("commit") -} - // See https://github.com/pingcap/tidb/issues/29327 func TestEnumDefaultValue(t *testing.T) { store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) @@ -3071,6 +3016,12 @@ func TestDefaultCollationForUTF8MB4(t *testing.T) { "dby CREATE DATABASE `dby` /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci */")) } +func TestOptimizeTable(t *testing.T) { + store := testkit.CreateMockStore(t, mockstore.WithDDLChecker()) + tk := testkit.NewTestKit(t, store) + tk.MustGetErrMsg("optimize table t", "[ddl:8200]OPTIMIZE TABLE is not supported") +} + func TestIssue52680(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) tk := testkit.NewTestKit(t, store) @@ -3135,3 +3086,31 @@ func TestIssue52680(t *testing.T) { tk.MustExec("insert into issue52680 values(default);") tk.MustQuery("select * from issue52680").Check(testkit.Rows("1", "2", "3")) } + +func TestCreateIndexWithChangeMaxIndexLength(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + originCfg := config.GetGlobalConfig() + defer func() { + config.StoreGlobalConfig(originCfg) + }() + + originHook := dom.DDL().GetHook() + defer dom.DDL().SetHook(originHook) + hook := &callback.TestDDLCallback{Do: dom} + hook.OnJobRunBeforeExported = func(job *model.Job) { + if job.Type != model.ActionAddIndex { + return + } + if job.SchemaState == model.StateNone { + newCfg := *originCfg + newCfg.MaxIndexLength = 1000 + config.StoreGlobalConfig(&newCfg) + } + } + dom.DDL().SetHook(hook) + + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("create table t(id int, a json DEFAULT NULL, b varchar(2) DEFAULT NULL);") + tk.MustGetErrMsg("CREATE INDEX idx_test on t ((cast(a as char(2000) array)),b);", "[ddl:1071]Specified key was too long (2000 bytes); max key length is 1000 bytes") +} diff --git a/pkg/ddl/ddl_worker_test.go b/pkg/ddl/ddl_worker_test.go index 9d6abc20dece7..41d3bc8465d65 100644 --- a/pkg/ddl/ddl_worker_test.go +++ b/pkg/ddl/ddl_worker_test.go @@ -51,7 +51,7 @@ func TestInvalidDDLJob(t *testing.T) { BinlogInfo: &model.HistoryInfo{}, Args: []interface{}{}, } - ctx := testNewContext(store) + ctx := testNewContext(t, store) ctx.SetValue(sessionctx.QueryString, "skip") err := dom.DDL().DoDDLJob(ctx, job) require.Equal(t, err.Error(), "[ddl:8204]invalid ddl job type: none") @@ -59,7 +59,7 @@ func TestInvalidDDLJob(t *testing.T) { func TestAddBatchJobError(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomainWithSchemaLease(t, testLease) - ctx := testNewContext(store) + ctx := testNewContext(t, store) require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/mockAddBatchDDLJobsErr", `return(true)`)) // Test the job runner should not hang forever. diff --git a/pkg/ddl/index_change_test.go b/pkg/ddl/index_change_test.go index 9c65c56b918a1..f5e8151bba127 100644 --- a/pkg/ddl/index_change_test.go +++ b/pkg/ddl/index_change_test.go @@ -59,7 +59,7 @@ func TestIndexChange(t *testing.T) { return } jobID.Store(job.ID) - ctx1 := testNewContext(store) + ctx1 := testNewContext(t, store) prevState = job.SchemaState require.NoError(t, dom.Reload()) tbl, exist := dom.InfoSchema().TableByID(job.TableID) @@ -108,7 +108,7 @@ func TestIndexChange(t *testing.T) { require.NoError(t, dom.Reload()) tbl, exist := dom.InfoSchema().TableByID(job.TableID) require.True(t, exist) - ctx1 := testNewContext(store) + ctx1 := testNewContext(t, store) switch job.SchemaState { case model.StateWriteOnly: writeOnlyTable = tbl diff --git a/pkg/executor/set.go b/pkg/executor/set.go index 1fcce5863c6f6..0bc3d936fce58 100644 --- a/pkg/executor/set.go +++ b/pkg/executor/set.go @@ -213,7 +213,8 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres newSnapshotTS := getSnapshotTSByName() newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS if newSnapshotIsSet { - err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx().GetStore(), newSnapshotTS) + isStaleRead := name == variable.TiDBTxnReadTS + err = sessionctx.ValidateSnapshotReadTS(ctx, e.Ctx().GetStore(), newSnapshotTS, isStaleRead) if name != variable.TiDBTxnReadTS { // Also check gc safe point for snapshot read. // We don't check snapshot with gc safe point for read_ts diff --git a/pkg/executor/test/executor/executor_test.go b/pkg/executor/test/executor/executor_test.go index 251218e90207f..049f8c434ef59 100644 --- a/pkg/executor/test/executor/executor_test.go +++ b/pkg/executor/test/executor/executor_test.go @@ -907,8 +907,7 @@ func TestExecutorBit(t *testing.T) { func TestCheckIndex(t *testing.T) { store, dom := testkit.CreateMockStoreAndDomain(t) - ctx := mock.NewContext() - ctx.Store = store + ctx := testkit.NewSession(t, store) se, err := session.CreateSession4Test(store) require.NoError(t, err) defer se.Close() diff --git a/pkg/executor/test/writetest/BUILD.bazel b/pkg/executor/test/writetest/BUILD.bazel index a1ba7d1035e1d..14fc0c2860b74 100644 --- a/pkg/executor/test/writetest/BUILD.bazel +++ b/pkg/executor/test/writetest/BUILD.bazel @@ -26,7 +26,6 @@ go_test( "//pkg/testkit", "//pkg/types", "//pkg/util", - "//pkg/util/mock", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", "@com_github_tikv_client_go_v2//tikv", diff --git a/pkg/planner/core/planbuilder.go b/pkg/planner/core/planbuilder.go index 83bc087ad90f6..7fc58a36301c3 100644 --- a/pkg/planner/core/planbuilder.go +++ b/pkg/planner/core/planbuilder.go @@ -3696,7 +3696,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } - if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil { + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS, true); err != nil { return nil, err } p.StaleTxnStartTS = startTS @@ -3710,7 +3710,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan, if err != nil { return nil, err } - if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil { + if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS, true); err != nil { return nil, err } p.StaleTxnStartTS = startTS diff --git a/pkg/sessionctx/context.go b/pkg/sessionctx/context.go index 5ebb1052cd384..7b9aa6abf3fb4 100644 --- a/pkg/sessionctx/context.go +++ b/pkg/sessionctx/context.go @@ -213,9 +213,12 @@ const ( LastExecuteDDL basicCtxType = 3 ) -// ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp -func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64) error { - return store.GetOracle().ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) +// ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp. +// For read requests to the storage, the check can be implicitly performed when sending the RPC request. So this +// function is only needed when it's not proper to delay the check to when RPC requests are being sent (e.g., `BEGIN` +// statements that don't make reading operation immediately). +func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64, isStaleRead bool) error { + return store.GetOracle().ValidateReadTS(ctx, readTS, isStaleRead, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) } // SysProcTracker is used to track background sys processes diff --git a/pkg/sessiontxn/staleread/processor.go b/pkg/sessiontxn/staleread/processor.go index 393c3e7c378bb..e1a3f547fd11b 100644 --- a/pkg/sessiontxn/staleread/processor.go +++ b/pkg/sessiontxn/staleread/processor.go @@ -285,7 +285,7 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as return 0, err } - if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts); err != nil { + if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts, true); err != nil { return 0, err } diff --git a/pkg/sessiontxn/staleread/util.go b/pkg/sessiontxn/staleread/util.go index 09f3edc2dbe0b..01791c6437900 100644 --- a/pkg/sessiontxn/staleread/util.go +++ b/pkg/sessiontxn/staleread/util.go @@ -77,7 +77,7 @@ func CalculateTsWithReadStaleness(ctx context.Context, sctx sessionctx.Context, // If the final calculated exceeds the min safe ts, we are not sure whether the ts is safe to read (note that // reading with a ts larger than PD's max allocated ts + 1 is unsafe and may break linearizability). // So in this case, do an extra check on it. - err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), readTS) + err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), readTS, true) if err != nil { return 0, err } diff --git a/pkg/store/copr/BUILD.bazel b/pkg/store/copr/BUILD.bazel index 9bebc01364fbb..e8c17b8094c8d 100644 --- a/pkg/store/copr/BUILD.bazel +++ b/pkg/store/copr/BUILD.bazel @@ -55,6 +55,7 @@ go_library( "@com_github_tikv_client_go_v2//config", "@com_github_tikv_client_go_v2//error", "@com_github_tikv_client_go_v2//metrics", + "@com_github_tikv_client_go_v2//oracle", "@com_github_tikv_client_go_v2//tikv", "@com_github_tikv_client_go_v2//tikvrpc", "@com_github_tikv_client_go_v2//tikvrpc/interceptor", diff --git a/pkg/store/copr/batch_coprocessor.go b/pkg/store/copr/batch_coprocessor.go index c33e8f9f6e112..19c43a567e9c8 100644 --- a/pkg/store/copr/batch_coprocessor.go +++ b/pkg/store/copr/batch_coprocessor.go @@ -1292,7 +1292,7 @@ func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Ba const TiFlashReadTimeoutUltraLong = 3600 * time.Second func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *backoff.Backoffer, task *batchCopTask) ([]*batchCopTask, error) { - sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient(), b.enableCollectExecutionInfo) + sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient(), b.store.store.GetOracle(), b.enableCollectExecutionInfo) var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.regionInfos)) for _, ri := range task.regionInfos { regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) diff --git a/pkg/store/copr/batch_request_sender.go b/pkg/store/copr/batch_request_sender.go index ccb138f7753c3..5c6d9a6cbe192 100644 --- a/pkg/store/copr/batch_request_sender.go +++ b/pkg/store/copr/batch_request_sender.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/tidb/pkg/config" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/oracle" "github.com/tikv/client-go/v2/tikv" "github.com/tikv/client-go/v2/tikvrpc" "google.golang.org/grpc/codes" @@ -56,9 +57,9 @@ type RegionBatchRequestSender struct { } // NewRegionBatchRequestSender creates a RegionBatchRequestSender object. -func NewRegionBatchRequestSender(cache *RegionCache, client tikv.Client, enableCollectExecutionInfo bool) *RegionBatchRequestSender { +func NewRegionBatchRequestSender(cache *RegionCache, client tikv.Client, oracle oracle.Oracle, enableCollectExecutionInfo bool) *RegionBatchRequestSender { return &RegionBatchRequestSender{ - RegionRequestSender: tikv.NewRegionRequestSender(cache.RegionCache, client), + RegionRequestSender: tikv.NewRegionRequestSender(cache.RegionCache, client, oracle), enableCollectExecutionInfo: enableCollectExecutionInfo, } } diff --git a/pkg/store/copr/mpp.go b/pkg/store/copr/mpp.go index cd0695a3e0d9d..cc5d361c6d73e 100644 --- a/pkg/store/copr/mpp.go +++ b/pkg/store/copr/mpp.go @@ -138,7 +138,7 @@ func (c *MPPClient) DispatchMPPTask(param kv.DispatchMPPTaskParam) (resp *mpp.Di // Or else it's the task without region, which always happens in high layer task without table. // In that case if originalTask != nil { - sender := NewRegionBatchRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), param.EnableCollectExecutionInfo) + sender := NewRegionBatchRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.store.GetOracle(), param.EnableCollectExecutionInfo) rpcResp, retry, _, err = sender.SendReqToAddr(bo, originalTask.ctx, originalTask.regionInfos, wrappedReq, tikv.ReadTimeoutMedium) // No matter what the rpc error is, we won't retry the mpp dispatch tasks. // TODO: If we want to retry, we must redo the plan fragment cutting and task scheduling. diff --git a/pkg/util/mock/BUILD.bazel b/pkg/util/mock/BUILD.bazel index 75ac693889df1..88c26bfd03067 100644 --- a/pkg/util/mock/BUILD.bazel +++ b/pkg/util/mock/BUILD.bazel @@ -22,6 +22,7 @@ go_library( "//pkg/sessionctx/variable", "//pkg/util", "//pkg/util/disk", + "//pkg/util/logutil", "//pkg/util/memory", "//pkg/util/sli", "//pkg/util/sqlexec", diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go index 7c6b52c96bce5..18e02c07d40e4 100644 --- a/pkg/util/mock/context.go +++ b/pkg/util/mock/context.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/memory" "github.com/pingcap/tidb/pkg/util/sli" "github.com/pingcap/tidb/pkg/util/sqlexec" @@ -67,7 +68,7 @@ type wrapTxn struct { } func (txn *wrapTxn) validOrPending() bool { - return txn.tsFuture != nil || txn.Transaction.Valid() + return txn.tsFuture != nil || (txn.Transaction != nil && txn.Transaction.Valid()) } func (txn *wrapTxn) pending() bool { @@ -173,7 +174,15 @@ func (c *Context) GetSessionVars() *variable.SessionVars { } // Txn implements sessionctx.Context Txn interface. -func (c *Context) Txn(bool) (kv.Transaction, error) { +func (c *Context) Txn(active bool) (kv.Transaction, error) { + if active { + if !c.txn.validOrPending() { + err := c.newTxn(context.Background()) + if err != nil { + return nil, err + } + } + } return &c.txn, nil } @@ -253,10 +262,12 @@ func (c *Context) GetSessionPlanCache() sessionctx.PlanCache { return c.pcache } -// NewTxn implements the sessionctx.Context interface. -func (c *Context) NewTxn(context.Context) error { +// newTxn Creates new transaction on the session context. +func (c *Context) newTxn(ctx context.Context) error { if c.Store == nil { - return errors.New("store is not set") + logutil.Logger(ctx).Warn("mock.Context: No store is specified when trying to create new transaction. A fake transaction will be created. Note that this is unrecommended usage.") + c.fakeTxn() + return nil } if c.txn.Valid() { err := c.txn.Commit(c.ctx) @@ -273,14 +284,41 @@ func (c *Context) NewTxn(context.Context) error { return nil } -// NewStaleTxnWithStartTS implements the sessionctx.Context interface. -func (c *Context) NewStaleTxnWithStartTS(ctx context.Context, _ uint64) error { - return c.NewTxn(ctx) +// fakeTxn is used to let some tests pass in the context without an available kv.Storage. Once usages to access +// transactions without a kv.Storage are removed, this type should also be removed. +// New code should never use this. +type fakeTxn struct { + // The inner should always be nil. + kv.Transaction + startTS uint64 +} + +func (t *fakeTxn) StartTS() uint64 { + return t.startTS +} + +func (*fakeTxn) SetDiskFullOpt(_ kvrpcpb.DiskFullOpt) {} + +func (*fakeTxn) SetOption(_ int, _ any) {} + +func (*fakeTxn) Get(ctx context.Context, _ kv.Key) ([]byte, error) { + // Check your implementation if you meet this error. It's dangerous if some calculation relies on the data but the + // read result is faked. + logutil.Logger(ctx).Warn("mock.Context: No store is specified but trying to access data from a transaction.") + return nil, nil +} + +func (*fakeTxn) Valid() bool { return true } + +func (c *Context) fakeTxn() { + c.txn.Transaction = &fakeTxn{ + startTS: 1, + } } // RefreshTxnCtx implements the sessionctx.Context interface. func (c *Context) RefreshTxnCtx(ctx context.Context) error { - return errors.Trace(c.NewTxn(ctx)) + return errors.Trace(c.newTxn(ctx)) } // RollbackTxn indicates an expected call of RollbackTxn.