diff --git a/pkg/lightning/backend/kv/BUILD.bazel b/pkg/lightning/backend/kv/BUILD.bazel index 34d4dc24ef32a..6b66a7289c486 100644 --- a/pkg/lightning/backend/kv/BUILD.bazel +++ b/pkg/lightning/backend/kv/BUILD.bazel @@ -35,6 +35,7 @@ go_library( "//pkg/table/tblctx", "//pkg/tablecodec", "//pkg/types", + "//pkg/util", "//pkg/util/chunk", "//pkg/util/codec", "//pkg/util/context", diff --git a/pkg/lightning/backend/kv/context.go b/pkg/lightning/backend/kv/context.go index 14c1963723c21..b90675e1330f9 100644 --- a/pkg/lightning/backend/kv/context.go +++ b/pkg/lightning/backend/kv/context.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/table/tblctx" "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" contextutil "github.com/pingcap/tidb/pkg/util/context" "github.com/pingcap/tidb/pkg/util/intest" "github.com/pingcap/tidb/pkg/util/timeutil" @@ -48,11 +49,7 @@ type litExprContext struct { // NewExpressionContext creates a new `*ExprContext` for lightning import. func newLitExprContext(sqlMode mysql.SQLMode, sysVars map[string]string, timestamp int64) (*litExprContext, error) { - flags := types.DefaultStmtFlags. - WithTruncateAsWarning(!sqlMode.HasStrictMode()). - WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!sqlMode.HasStrictMode() || sqlMode.HasAllowInvalidDatesMode() || - !sqlMode.HasNoZeroInDateMode() || !sqlMode.HasNoZeroDateMode()) + flags := util.GetTypeFlagsForImportInto(types.DefaultStmtFlags, sqlMode) errLevels := stmtctx.DefaultStmtErrLevels errLevels[errctx.ErrGroupTruncate] = errctx.ResolveErrLevel(flags.IgnoreTruncateErr(), flags.TruncateAsWarning()) diff --git a/pkg/lightning/backend/kv/context_test.go b/pkg/lightning/backend/kv/context_test.go index bb04a4f3d5bbe..2cf6d64f85e03 100644 --- a/pkg/lightning/backend/kv/context_test.go +++ b/pkg/lightning/backend/kv/context_test.go @@ -36,6 +36,7 @@ import ( ) func TestLitExprContext(t *testing.T) { + baseFlags := types.DefaultStmtFlags &^ types.FlagAllowNegativeToUnsigned cases := []struct { sqlMode mysql.SQLMode sysVars map[string]string @@ -47,7 +48,7 @@ func TestLitExprContext(t *testing.T) { { sqlMode: mysql.ModeNone, timestamp: 1234567, - checkFlags: types.DefaultStmtFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr, + checkFlags: baseFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelWarn @@ -68,7 +69,7 @@ func TestLitExprContext(t *testing.T) { { sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroDate | mysql.ModeNoZeroInDate | mysql.ModeErrorForDivisionByZero, - checkFlags: types.DefaultStmtFlags, + checkFlags: baseFlags, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelError @@ -80,7 +81,7 @@ func TestLitExprContext(t *testing.T) { }, { sqlMode: mysql.ModeNoZeroDate | mysql.ModeNoZeroInDate | mysql.ModeErrorForDivisionByZero, - checkFlags: types.DefaultStmtFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr, + checkFlags: baseFlags | types.FlagTruncateAsWarning | types.FlagIgnoreZeroInDateErr, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelWarn @@ -92,7 +93,7 @@ func TestLitExprContext(t *testing.T) { }, { sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroInDate, - checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr, + checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelError @@ -104,7 +105,7 @@ func TestLitExprContext(t *testing.T) { }, { sqlMode: mysql.ModeStrictTransTables | mysql.ModeNoZeroDate, - checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr, + checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelError @@ -116,7 +117,7 @@ func TestLitExprContext(t *testing.T) { }, { sqlMode: mysql.ModeStrictTransTables | mysql.ModeAllowInvalidDates, - checkFlags: types.DefaultStmtFlags | types.FlagIgnoreZeroInDateErr | types.FlagIgnoreInvalidDateErr, + checkFlags: baseFlags | types.FlagIgnoreZeroInDateErr | types.FlagIgnoreInvalidDateErr, checkErrLevel: func() errctx.LevelMap { m := stmtctx.DefaultStmtErrLevels m[errctx.ErrGroupTruncate] = errctx.LevelError diff --git a/pkg/util/misc.go b/pkg/util/misc.go index 59358876c4df2..3bb68aae20db4 100644 --- a/pkg/util/misc.go +++ b/pkg/util/misc.go @@ -698,12 +698,14 @@ func createTLSCertificates(certpath string, keypath string, rsaKeySize int) erro // GetTypeFlagsForInsert gets the type flags for insert statement. func GetTypeFlagsForInsert(baseFlags types.Flags, sqlMode mysql.SQLMode, ignoreErr bool) types.Flags { strictSQLMode := sqlMode.HasStrictMode() + // see comments in ResetContextOfStmt for WithAllowNegativeToUnsigned part. return baseFlags. WithTruncateAsWarning(!strictSQLMode || ignoreErr). WithIgnoreInvalidDateErr(sqlMode.HasAllowInvalidDatesMode()). WithIgnoreZeroInDate(!sqlMode.HasNoZeroInDateMode() || !sqlMode.HasNoZeroDateMode() || !strictSQLMode || ignoreErr || - sqlMode.HasAllowInvalidDatesMode()) + sqlMode.HasAllowInvalidDatesMode()). + WithAllowNegativeToUnsigned(false) } // GetTypeFlagsForImportInto gets the type flags for import into statement which diff --git a/tests/realtikvtest/importintotest2/from_select_test.go b/tests/realtikvtest/importintotest2/from_select_test.go index 60687a1a7cff6..43185281b077f 100644 --- a/tests/realtikvtest/importintotest2/from_select_test.go +++ b/tests/realtikvtest/importintotest2/from_select_test.go @@ -155,3 +155,12 @@ func (s *mockGCSSuite) TestImportFromSelectStaleRead() { s.tk.MustExec("import into dst from " + staleReadSQL) s.tk.MustQuery("select * from dst").Check(testkit.Rows("1 a", "2 b")) } + +func (s *mockGCSSuite) TestCastNegativeToUnsigned() { + s.prepareAndUseDB("from_select") + s.tk.MustExec("create table dt(id int unsigned)") + s.ErrorContains(s.tk.ExecToErr("import into dt from select -1"), "constant -1 overflows int") + s.tk.MustExec("set sql_mode=''") + s.tk.MustExec("import into dt from select -1") + s.tk.MustQuery("select * from dt").Check(testkit.Rows("0")) +}