Skip to content

Commit

Permalink
importinto: use same type context flag setting as insert (#58606)
Browse files Browse the repository at this point in the history
close #58443
  • Loading branch information
D3Hunter authored Dec 31, 2024
1 parent 42d4fae commit 284a3ee
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 8 deletions.
13 changes: 6 additions & 7 deletions pkg/executor/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {

errLevels := sc.ErrLevels()
errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn
inImportInto := false
switch stmt := s.(type) {
// `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them.
case *ast.UpdateStmt:
Expand Down Expand Up @@ -1077,12 +1078,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
!strictSQLMode || stmt.IgnoreErr,
)
sc.Priority = stmt.Priority
sc.SetTypeFlags(sc.TypeFlags().
WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr).
WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() ||
!vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr ||
vars.SQLMode.HasAllowInvalidDatesMode()))
sc.SetTypeFlags(util.GetTypeFlagsForInsert(sc.TypeFlags(), vars.SQLMode, stmt.IgnoreErr))
case *ast.CreateTableStmt, *ast.AlterTableStmt:
sc.InCreateOrAlterStmt = true
sc.SetTypeFlags(sc.TypeFlags().
Expand All @@ -1096,6 +1092,9 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
sc.InLoadDataStmt = true
// return warning instead of error when load data meet no partition for value
errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn
case *ast.ImportIntoStmt:
inImportInto = true
sc.SetTypeFlags(util.GetTypeFlagsForImportInto(sc.TypeFlags(), vars.SQLMode))
case *ast.SelectStmt:
sc.InSelectStmt = true

Expand Down Expand Up @@ -1153,7 +1152,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) {
// WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types.
// This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode.
// see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html
WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt),
WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !inImportInto && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt),
)

vars.PlanCacheParams.Reset()
Expand Down
37 changes: 37 additions & 0 deletions pkg/executor/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
package executor_test

import (
"fmt"
"testing"

"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/executor"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)

func BenchmarkResetContextOfStmt(b *testing.B) {
Expand All @@ -31,3 +34,37 @@ func BenchmarkResetContextOfStmt(b *testing.B) {
executor.ResetContextOfStmt(ctx, stmt)
}
}

func TestImportIntoShouldHaveSameFlagsAsInsert(t *testing.T) {
insertStmt := &ast.InsertStmt{}
importStmt := &ast.ImportIntoStmt{}
insertCtx := mock.NewContext()
importCtx := mock.NewContext()
insertCtx.BindDomain(&domain.Domain{})
importCtx.BindDomain(&domain.Domain{})
for _, modeStr := range []string{
"",
"IGNORE_SPACE",
"STRICT_TRANS_TABLES",
"STRICT_ALL_TABLES",
"ALLOW_INVALID_DATES",
"NO_ZERO_IN_DATE",
"NO_ZERO_DATE",
"NO_ZERO_IN_DATE,STRICT_ALL_TABLES",
"NO_ZERO_DATE,STRICT_ALL_TABLES",
"NO_ZERO_IN_DATE,NO_ZERO_DATE,STRICT_ALL_TABLES",
} {
t.Run(fmt.Sprintf("mode %s", modeStr), func(t *testing.T) {
mode, err := mysql.GetSQLMode(modeStr)
require.NoError(t, err)
insertCtx.GetSessionVars().SQLMode = mode
require.NoError(t, executor.ResetContextOfStmt(insertCtx, insertStmt))
importCtx.GetSessionVars().SQLMode = mode
require.NoError(t, executor.ResetContextOfStmt(importCtx, importStmt))

insertTypeCtx := insertCtx.GetSessionVars().StmtCtx.TypeCtx()
importTypeCtx := importCtx.GetSessionVars().StmtCtx.TypeCtx()
require.EqualValues(t, insertTypeCtx.Flags(), importTypeCtx.Flags())
})
}
}
4 changes: 3 additions & 1 deletion pkg/types/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ const (
// FlagTruncateAsWarning indicates to append the truncate error to warnings instead of returning it to user.
FlagTruncateAsWarning
// FlagAllowNegativeToUnsigned indicates to allow the casting from negative to unsigned int.
// When this flag is not set by default, casting a negative value to unsigned results an overflow error.
// When this flag is not set by default, casting a negative value to unsigned
// results an overflow error, but if SQL mode is not strict, it's converted
// to 0 with a warning.
// Otherwise, a negative value will be cast to the corresponding unsigned value without any error.
// For example, when casting -1 to an unsigned bigint with `FlagAllowNegativeToUnsigned` set,
// we will get `18446744073709551615` which is the biggest unsigned value.
Expand Down
1 change: 1 addition & 0 deletions pkg/util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ go_library(
"//pkg/session/cursor",
"//pkg/session/txninfo",
"//pkg/sessionctx/stmtctx",
"//pkg/types",
"//pkg/util/collate",
"//pkg/util/disk",
"//pkg/util/execdetails",
Expand Down
18 changes: 18 additions & 0 deletions pkg/util/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import (
pmodel "github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/logutil"
tlsutil "github.com/pingcap/tidb/pkg/util/tls"
Expand Down Expand Up @@ -693,3 +694,20 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro
// use RSA and unspecified signature algorithm
return CreateCertificates(certpath, keypath, rsaKeySize, x509.RSA, x509.UnknownSignatureAlgorithm)
}

// GetTypeFlagsForInsert gets the type flags for insert statement.
func GetTypeFlagsForInsert(baseFlags types.Flags, sqlMode mysql.SQLMode, ignoreErr bool) types.Flags {
strictSQLMode := sqlMode.HasStrictMode()
return baseFlags.
WithTruncateAsWarning(!strictSQLMode || ignoreErr).
WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()).
WithIgnoreZeroInDate(!sqlMode.HasNoZeroInDateMode() ||
!sqlMode.HasNoZeroDateMode() || !strictSQLMode || ignoreErr ||
sqlMode.HasAllowInvalidDatesMode())
}

// GetTypeFlagsForImportInto gets the type flags for import into statement which
// has the same flags as normal `INSERT INTO xxx`.
func GetTypeFlagsForImportInto(baseFlags types.Flags, sqlMode mysql.SQLMode) types.Flags {
return GetTypeFlagsForInsert(baseFlags, sqlMode, false)
}

0 comments on commit 284a3ee

Please sign in to comment.