From 284a3ee23c7ae200e2d9579129361997d6465acb Mon Sep 17 00:00:00 2001 From: D3Hunter Date: Tue, 31 Dec 2024 15:02:33 +0800 Subject: [PATCH] importinto: use same type context flag setting as insert (#58606) close pingcap/tidb#58443 --- pkg/executor/select.go | 13 ++++++------- pkg/executor/select_test.go | 37 +++++++++++++++++++++++++++++++++++++ pkg/types/context.go | 4 +++- pkg/util/BUILD.bazel | 1 + pkg/util/misc.go | 18 ++++++++++++++++++ 5 files changed, 65 insertions(+), 8 deletions(-) diff --git a/pkg/executor/select.go b/pkg/executor/select.go index 2bd5dfd5beb8c..843183abad8ea 100644 --- a/pkg/executor/select.go +++ b/pkg/executor/select.go @@ -1049,6 +1049,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { errLevels := sc.ErrLevels() errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + inImportInto := false switch stmt := s.(type) { // `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them. case *ast.UpdateStmt: @@ -1077,12 +1078,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { !strictSQLMode || stmt.IgnoreErr, ) sc.Priority = stmt.Priority - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || - !vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr || - vars.SQLMode.HasAllowInvalidDatesMode())) + sc.SetTypeFlags(util.GetTypeFlagsForInsert(sc.TypeFlags(), vars.SQLMode, stmt.IgnoreErr)) case *ast.CreateTableStmt, *ast.AlterTableStmt: sc.InCreateOrAlterStmt = true sc.SetTypeFlags(sc.TypeFlags(). @@ -1096,6 +1092,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.InLoadDataStmt = true // return warning instead of error when load data meet no partition for value errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn + case *ast.ImportIntoStmt: + inImportInto = true + sc.SetTypeFlags(util.GetTypeFlagsForImportInto(sc.TypeFlags(), vars.SQLMode)) case *ast.SelectStmt: sc.InSelectStmt = true @@ -1153,7 +1152,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { // WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types. // This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode. // see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html - WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt), + WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !inImportInto && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt), ) vars.PlanCacheParams.Reset() diff --git a/pkg/executor/select_test.go b/pkg/executor/select_test.go index c6dd909bebfe6..d4a3ca9a2d8c0 100644 --- a/pkg/executor/select_test.go +++ b/pkg/executor/select_test.go @@ -15,12 +15,15 @@ package executor_test import ( + "fmt" "testing" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/executor" "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/util/mock" + "github.com/stretchr/testify/require" ) func BenchmarkResetContextOfStmt(b *testing.B) { @@ -31,3 +34,37 @@ func BenchmarkResetContextOfStmt(b *testing.B) { executor.ResetContextOfStmt(ctx, stmt) } } + +func TestImportIntoShouldHaveSameFlagsAsInsert(t *testing.T) { + insertStmt := &ast.InsertStmt{} + importStmt := &ast.ImportIntoStmt{} + insertCtx := mock.NewContext() + importCtx := mock.NewContext() + insertCtx.BindDomain(&domain.Domain{}) + importCtx.BindDomain(&domain.Domain{}) + for _, modeStr := range []string{ + "", + "IGNORE_SPACE", + "STRICT_TRANS_TABLES", + "STRICT_ALL_TABLES", + "ALLOW_INVALID_DATES", + "NO_ZERO_IN_DATE", + "NO_ZERO_DATE", + "NO_ZERO_IN_DATE,STRICT_ALL_TABLES", + "NO_ZERO_DATE,STRICT_ALL_TABLES", + "NO_ZERO_IN_DATE,NO_ZERO_DATE,STRICT_ALL_TABLES", + } { + t.Run(fmt.Sprintf("mode %s", modeStr), func(t *testing.T) { + mode, err := mysql.GetSQLMode(modeStr) + require.NoError(t, err) + insertCtx.GetSessionVars().SQLMode = mode + require.NoError(t, executor.ResetContextOfStmt(insertCtx, insertStmt)) + importCtx.GetSessionVars().SQLMode = mode + require.NoError(t, executor.ResetContextOfStmt(importCtx, importStmt)) + + insertTypeCtx := insertCtx.GetSessionVars().StmtCtx.TypeCtx() + importTypeCtx := importCtx.GetSessionVars().StmtCtx.TypeCtx() + require.EqualValues(t, insertTypeCtx.Flags(), importTypeCtx.Flags()) + }) + } +} diff --git a/pkg/types/context.go b/pkg/types/context.go index 07ffcf9266b12..c94bf212dc23a 100644 --- a/pkg/types/context.go +++ b/pkg/types/context.go @@ -34,7 +34,9 @@ const ( // FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user. FlagTruncateAsWarning // FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int. - // When this flag is not set by default, casting a negative value to unsigned results an overflow error. + // When this flag is not set by default, casting a negative value to unsigned + // results an overflow error, but if SQL mode is not strict, it's converted + // to 0 with a warning. // Otherwise, a negative value will be cast to the corresponding unsigned value without any error. // For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set, // we will get `18446744073709551615` which is the biggest unsigned value. diff --git a/pkg/util/BUILD.bazel b/pkg/util/BUILD.bazel index a4f856dc962f0..23cd407573828 100644 --- a/pkg/util/BUILD.bazel +++ b/pkg/util/BUILD.bazel @@ -41,6 +41,7 @@ go_library( "//pkg/session/cursor", "//pkg/session/txninfo", "//pkg/sessionctx/stmtctx", + "//pkg/types", "//pkg/util/collate", "//pkg/util/disk", "//pkg/util/execdetails", diff --git a/pkg/util/misc.go b/pkg/util/misc.go index a907600010a87..59358876c4df2 100644 --- a/pkg/util/misc.go +++ b/pkg/util/misc.go @@ -46,6 +46,7 @@ import ( pmodel "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/collate" "github.com/pingcap/tidb/pkg/util/logutil" tlsutil "github.com/pingcap/tidb/pkg/util/tls" @@ -693,3 +694,20 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro // use RSA and unspecified signature algorithm return CreateCertificates(certpath, keypath, rsaKeySize, x509.RSA, x509.UnknownSignatureAlgorithm) } + +// GetTypeFlagsForInsert gets the type flags for insert statement. +func GetTypeFlagsForInsert(baseFlags types.Flags, sqlMode mysql.SQLMode, ignoreErr bool) types.Flags { + strictSQLMode := sqlMode.HasStrictMode() + return baseFlags. + WithTruncateAsWarning(!strictSQLMode || ignoreErr). + WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!sqlMode.HasNoZeroInDateMode() || + !sqlMode.HasNoZeroDateMode() || !strictSQLMode || ignoreErr || + sqlMode.HasAllowInvalidDatesMode()) +} + +// GetTypeFlagsForImportInto gets the type flags for import into statement which +// has the same flags as normal `INSERT INTO xxx`. +func GetTypeFlagsForImportInto(baseFlags types.Flags, sqlMode mysql.SQLMode) types.Flags { + return GetTypeFlagsForInsert(baseFlags, sqlMode, false) +}