diff --git a/tidb/ast/outfuncs.go b/tidb/ast/outfuncs.go index 9f909421..cce6e5f8 100644 --- a/tidb/ast/outfuncs.go +++ b/tidb/ast/outfuncs.go @@ -39,6 +39,8 @@ func writeNode(sb *strings.Builder, node Node) { writeUpdateStmt(sb, n) case *DeleteStmt: writeDeleteStmt(sb, n) + case *BatchStmt: + writeBatchStmt(sb, n) case *CreateTableStmt: writeCreateTableStmt(sb, n) case *AlterTableStmt: @@ -787,6 +789,28 @@ func writeDeleteStmt(sb *strings.Builder, n *DeleteStmt) { sb.WriteString("}") } +func writeBatchStmt(sb *strings.Builder, n *BatchStmt) { + sb.WriteString("{BATCH") + fmt.Fprintf(sb, " :loc %d", n.Loc.Start) + if n.ShardColumn != nil { + sb.WriteString(" :shard_column ") + writeNode(sb, n.ShardColumn) + } + fmt.Fprintf(sb, " :limit %d", n.Limit) + switch n.DryRun { + case BatchDryRunQuery: + sb.WriteString(" :dry_run query") + case BatchDryRunSplitDML: + sb.WriteString(" :dry_run split_dml") + case BatchDryRunNone: + } + if n.DML != nil { + sb.WriteString(" :dml ") + writeNode(sb, n.DML) + } + sb.WriteString("}") +} + func writeCreateTableStmt(sb *strings.Builder, n *CreateTableStmt) { sb.WriteString("{CREATE_TABLE") fmt.Fprintf(sb, " :loc %d", n.Loc.Start) diff --git a/tidb/ast/parsenodes.go b/tidb/ast/parsenodes.go index 6ad99f9a..80b82c03 100644 --- a/tidb/ast/parsenodes.go +++ b/tidb/ast/parsenodes.go @@ -130,6 +130,35 @@ type DeleteStmt struct { func (s *DeleteStmt) nodeTag() {} func (s *DeleteStmt) stmtNode() {} +// BatchDryRunMode enumerates TiDB BATCH non-transactional DML dry-run modes. +// Numeric values mirror pingcap ast (dml.go: NoDryRun=0, DryRunQuery=1, +// DryRunSplitDml=2). DRY RUN → SplitDML; DRY RUN QUERY → Query. +type BatchDryRunMode int + +const ( + BatchDryRunNone BatchDryRunMode = iota // 0: execute the split DMLs + BatchDryRunQuery // 1: DRY RUN QUERY — show the SELECT that splits + BatchDryRunSplitDML // 2: DRY RUN — show the split DML jobs +) + +// BatchStmt represents a TiDB non-transactional DML statement: +// +// BATCH [ON ] LIMIT [DRY RUN [QUERY]] {DELETE | UPDATE | INSERT | REPLACE} +// +// Ref: pingcap parser.y NonTransactionalDMLStmt (production "BATCH" +// OptionalShardColumn "LIMIT" NUM DryRunOptions ShardableStmt). REPLACE is +// surfaced as *InsertStmt with IsReplace=true (omni unifies INSERT/REPLACE). +type BatchStmt struct { + Loc Loc + ShardColumn *ColumnRef // ON ; nil → TiDB auto-selects the handle column + Limit uint64 // LIMIT ; grammar is NUM (integer literal), not an expression + DryRun BatchDryRunMode + DML StmtNode // *DeleteStmt | *UpdateStmt | *InsertStmt +} + +func (s *BatchStmt) nodeTag() {} +func (s *BatchStmt) stmtNode() {} + // CreateTableStmt represents a CREATE TABLE statement. type CreateTableStmt struct { Loc Loc diff --git a/tidb/ast/walk_generated.go b/tidb/ast/walk_generated.go index 6b74c379..c67f69bb 100644 --- a/tidb/ast/walk_generated.go +++ b/tidb/ast/walk_generated.go @@ -110,6 +110,11 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Column) } Walk(v, n.Value) + case *BatchStmt: + if n.ShardColumn != nil { + Walk(v, n.ShardColumn) + } + Walk(v, n.DML) case *BeginEndBlock: for _, item := range n.Stmts { Walk(v, item) diff --git a/tidb/completion/completion.go b/tidb/completion/completion.go index 7b15c9ac..4c8408d0 100644 --- a/tidb/completion/completion.go +++ b/tidb/completion/completion.go @@ -1,4 +1,4 @@ -// Package completion provides parser-native C3-style SQL completion for MySQL. +// Package completion provides parser-native C3-style SQL completion for TiDB. package completion import ( diff --git a/tidb/parser/batch.go b/tidb/parser/batch.go new file mode 100644 index 00000000..ef1c89ad --- /dev/null +++ b/tidb/parser/batch.go @@ -0,0 +1,101 @@ +package parser + +import ( + "strconv" + + nodes "github.com/bytebase/omni/tidb/ast" +) + +// parseBatchStmt parses a TiDB non-transactional DML statement: +// +// BATCH [ON ] LIMIT [DRY RUN [QUERY]] {DELETE | UPDATE | INSERT | REPLACE} +// +// Grammar ref: pingcap parser.y NonTransactionalDMLStmt — +// "BATCH" OptionalShardColumn "LIMIT" NUM DryRunOptions ShardableStmt. +// LIMIT takes a bare integer (NUM), not an expression. +func (p *Parser) parseBatchStmt() (*nodes.BatchStmt, error) { + start := p.pos() + p.advance() // consume BATCH + + stmt := &nodes.BatchStmt{Loc: nodes.Loc{Start: start}} + + // Completion: after BATCH, offer ON / LIMIT. + p.checkCursor() + if p.collectMode() { + p.addTokenCandidate(kwON) + p.addTokenCandidate(kwLIMIT) + return nil, &ParseError{Message: "collecting"} + } + + // OptionalShardColumn: [ON ColumnName]. Upstream is a plain (optionally + // qualified) column name — parseColumnRef also accepts wildcard forms + // (t.* / db.t.*), which TiDB rejects here, so disallow them. + if _, ok := p.match(kwON); ok { + col, err := p.parseColumnRef() + if err != nil { + return nil, err + } + if col.Star { + return nil, &ParseError{Message: "BATCH shard column must be a column name, not a wildcard", Position: col.Loc.Start} + } + stmt.ShardColumn = col + } + + // LIMIT NUM (mandatory). Grammar is NUM, so reject expressions, placeholders, + // and signed/parenthesized forms by requiring an integer-literal token. + if _, err := p.expect(kwLIMIT); err != nil { + return nil, err + } + if p.cur.Type != tokICONST { + return nil, p.syntaxErrorAtCur() + } + numTok := p.advance() + limit, err := strconv.ParseUint(numTok.Str, 10, 64) + if err != nil { + return nil, &ParseError{Message: "invalid BATCH LIMIT value", Position: numTok.Loc} + } + stmt.Limit = limit + + // DryRunOptions: [] | DRY RUN | DRY RUN QUERY + if _, ok := p.match(kwDRY); ok { + if _, err := p.expect(kwRUN); err != nil { + return nil, err + } + if _, ok := p.match(kwQUERY); ok { + stmt.DryRun = nodes.BatchDryRunQuery + } else { + stmt.DryRun = nodes.BatchDryRunSplitDML + } + } + + // Completion: before the DML, offer DRY plus the shardable statement starters. + p.checkCursor() + if p.collectMode() { + for _, tok := range []int{kwDRY, kwDELETE, kwUPDATE, kwINSERT, kwREPLACE} { + p.addTokenCandidate(tok) + } + return nil, &ParseError{Message: "collecting"} + } + + // ShardableStmt: DELETE | UPDATE | INSERT | REPLACE. + var dml nodes.StmtNode + switch p.cur.Type { + case kwDELETE: + dml, err = p.parseDeleteStmt() + case kwUPDATE: + dml, err = p.parseUpdateStmt() + case kwINSERT: + dml, err = p.parseInsertStmt() + case kwREPLACE: + dml, err = p.parseReplaceStmt() + default: + return nil, p.syntaxErrorAtCur() + } + if err != nil { + return nil, err + } + stmt.DML = dml + + stmt.Loc.End = p.pos() + return stmt, nil +} diff --git a/tidb/parser/batch_test.go b/tidb/parser/batch_test.go new file mode 100644 index 00000000..a0c624c9 --- /dev/null +++ b/tidb/parser/batch_test.go @@ -0,0 +1,245 @@ +package parser + +import ( + "errors" + "strings" + "testing" + + gomysql "github.com/go-sql-driver/mysql" + + "github.com/bytebase/omni/tidb/ast" +) + +// TestBatchKeywordsStillIdentifiers verifies that introducing the BATCH/DRY/RUN +// keyword tokens does not regress their use as identifiers. Upstream classifies +// all three as non-reserved (BATCH/DRY/RUN are TiDBKeyword), so they must remain +// usable as table, column, and qualifier names. +func TestBatchKeywordsStillIdentifiers(t *testing.T) { + cases := []string{ + "SELECT batch FROM t", + "SELECT dry, run FROM t", + "SELECT * FROM t WHERE run = 1", + "SELECT * FROM t WHERE dry > 0 AND batch < 10", + "CREATE TABLE batch (id INT)", + "CREATE TABLE run (dry INT, batch VARCHAR(10))", + "INSERT INTO run (dry) VALUES (1)", + "SELECT batch.id FROM batch", + "UPDATE batch SET run = 1 WHERE dry = 0", + } + for _, sql := range cases { + t.Run(sql, func(t *testing.T) { + ParseAndCheck(t, sql) + }) + } +} + +// batchCase is a BATCH grammar acceptance case, shared by the pure-parser test +// and the TiDB container lockstep so both assert against the same corpus. +type batchCase struct { + name string + sql string + wantAccept bool +} + +// batchCases enumerates BATCH grammar positives and sibling-arm negatives. +// Negatives are derived from the upstream production +// "BATCH" OptionalShardColumn "LIMIT" NUM DryRunOptions ShardableStmt: +// every decorator a sibling arm accepts that BATCH's arm rejects gets a case. +var batchCases = []batchCase{ + // Positives. + {"shard col + delete", "BATCH ON id LIMIT 5000 DELETE FROM t WHERE 1=1", true}, + {"no shard col + delete", "BATCH LIMIT 5000 DELETE FROM t WHERE 1=1", true}, + {"dry run (split dml)", "BATCH ON id LIMIT 1000 DRY RUN DELETE FROM t WHERE 1=1", true}, + {"dry run query", "BATCH ON id LIMIT 1000 DRY RUN QUERY DELETE FROM t WHERE 1=1", true}, + {"update", "BATCH ON id LIMIT 100 UPDATE t SET x=1 WHERE 1=1", true}, + {"insert select", "BATCH ON id LIMIT 100 INSERT INTO t2 SELECT * FROM t WHERE 1=1", true}, + {"insert ignore", "BATCH ON id LIMIT 100 INSERT IGNORE INTO t2 SELECT * FROM t WHERE 1=1", true}, + {"replace select", "BATCH ON id LIMIT 100 REPLACE INTO t2 SELECT * FROM t WHERE 1=1", true}, + {"qualified shard col", "BATCH ON t.id LIMIT 100 DELETE FROM t WHERE 1=1", true}, + + // Negatives — sibling-arm decorators BATCH rejects. + {"missing LIMIT", "BATCH DELETE FROM t WHERE 1=1", false}, + {"ON without column", "BATCH ON LIMIT 100 DELETE FROM t WHERE 1=1", false}, + {"LIMIT without NUM", "BATCH ON id LIMIT DELETE FROM t WHERE 1=1", false}, + {"LIMIT expression not NUM", "BATCH ON id LIMIT (1+1) DELETE FROM t WHERE 1=1", false}, + {"LIMIT negative", "BATCH ON id LIMIT -100 DELETE FROM t WHERE 1=1", false}, + {"SELECT not shardable", "BATCH ON id LIMIT 100 SELECT * FROM t", false}, + {"DDL not shardable", "BATCH ON id LIMIT 100 CREATE TABLE x (id INT)", false}, + {"DRY without RUN", "BATCH ON id LIMIT 100 DRY DELETE FROM t WHERE 1=1", false}, + {"DRY RUN junk modifier", "BATCH ON id LIMIT 100 DRY RUN FOO DELETE FROM t WHERE 1=1", false}, + {"no DML statement", "BATCH ON id LIMIT 100", false}, + {"bare BATCH", "BATCH", false}, + + // Shard column must be a (qualified) column name, not a wildcard: + // pingcap OptionalShardColumn is "ON ColumnName", which has no star form. + {"bare star shard col", "BATCH ON * LIMIT 100 DELETE FROM t WHERE 1=1", false}, + {"wildcard shard col table.*", "BATCH ON t.* LIMIT 100 DELETE FROM t WHERE 1=1", false}, + {"wildcard shard col db.table.*", "BATCH ON db.t.* LIMIT 100 DELETE FROM t WHERE 1=1", false}, +} + +// TestBatchParse locks BATCH grammar acceptance in the pure parser (no Docker). +func TestBatchParse(t *testing.T) { + for _, c := range batchCases { + t.Run(c.name, func(t *testing.T) { + _, err := Parse(c.sql) + accepts := err == nil + if accepts != c.wantAccept { + t.Errorf("Parse(%q): accepts=%v, want %v (err=%v)", c.sql, accepts, c.wantAccept, err) + } + }) + } +} + +// TestBatchAST locks the semantic mapping of the parsed BatchStmt: dry-run mode, +// shard column shape, limit value, and DML node type (incl. REPLACE unification). +func TestBatchAST(t *testing.T) { + mustBatch := func(t *testing.T, sql string) *ast.BatchStmt { + t.Helper() + l := ParseAndCheck(t, sql) + b, ok := l.Items[0].(*ast.BatchStmt) + if !ok { + t.Fatalf("Parse(%q): got %T, want *ast.BatchStmt", sql, l.Items[0]) + } + return b + } + + t.Run("dry run query → Query", func(t *testing.T) { + b := mustBatch(t, "BATCH ON id LIMIT 1000 DRY RUN QUERY DELETE FROM t WHERE 1=1") + if b.DryRun != ast.BatchDryRunQuery { + t.Errorf("DryRun=%d, want BatchDryRunQuery(%d)", b.DryRun, ast.BatchDryRunQuery) + } + if _, ok := b.DML.(*ast.DeleteStmt); !ok { + t.Errorf("DML=%T, want *ast.DeleteStmt", b.DML) + } + }) + + t.Run("dry run → SplitDML", func(t *testing.T) { + b := mustBatch(t, "BATCH ON id LIMIT 1000 DRY RUN DELETE FROM t WHERE 1=1") + if b.DryRun != ast.BatchDryRunSplitDML { + t.Errorf("DryRun=%d, want BatchDryRunSplitDML(%d)", b.DryRun, ast.BatchDryRunSplitDML) + } + }) + + t.Run("no dry run → None", func(t *testing.T) { + b := mustBatch(t, "BATCH ON id LIMIT 5000 DELETE FROM t WHERE 1=1") + if b.DryRun != ast.BatchDryRunNone { + t.Errorf("DryRun=%d, want BatchDryRunNone(%d)", b.DryRun, ast.BatchDryRunNone) + } + if b.Limit != 5000 { + t.Errorf("Limit=%d, want 5000", b.Limit) + } + if b.ShardColumn == nil || b.ShardColumn.Column != "id" { + t.Errorf("ShardColumn=%+v, want {Column:id}", b.ShardColumn) + } + }) + + t.Run("no shard column → nil", func(t *testing.T) { + b := mustBatch(t, "BATCH LIMIT 5000 DELETE FROM t WHERE 1=1") + if b.ShardColumn != nil { + t.Errorf("ShardColumn=%+v, want nil", b.ShardColumn) + } + }) + + t.Run("qualified shard column", func(t *testing.T) { + b := mustBatch(t, "BATCH ON t.id LIMIT 100 DELETE FROM t WHERE 1=1") + if b.ShardColumn == nil || b.ShardColumn.Table != "t" || b.ShardColumn.Column != "id" { + t.Errorf("ShardColumn=%+v, want {Table:t Column:id}", b.ShardColumn) + } + }) + + t.Run("replace → InsertStmt IsReplace", func(t *testing.T) { + b := mustBatch(t, "BATCH ON id LIMIT 100 REPLACE INTO t2 SELECT * FROM t") + ins, ok := b.DML.(*ast.InsertStmt) + if !ok { + t.Fatalf("DML=%T, want *ast.InsertStmt", b.DML) + } + if !ins.IsReplace { + t.Errorf("InsertStmt.IsReplace=false, want true for REPLACE") + } + }) +} + +// TestBatchDryRunEnumValues locks the numeric enum values to mirror pingcap's +// ast (NoDryRun=0, DryRunQuery=1, DryRunSplitDml=2). The omni parser and any +// downstream consumer rely on these exact values. +func TestBatchDryRunEnumValues(t *testing.T) { + if ast.BatchDryRunNone != 0 { + t.Errorf("BatchDryRunNone=%d, want 0", ast.BatchDryRunNone) + } + if ast.BatchDryRunQuery != 1 { + t.Errorf("BatchDryRunQuery=%d, want 1", ast.BatchDryRunQuery) + } + if ast.BatchDryRunSplitDML != 2 { + t.Errorf("BatchDryRunSplitDML=%d, want 2", ast.BatchDryRunSplitDML) + } +} + +// TestBatchSerialize verifies the AST serialization (outfuncs writeBatchStmt) +// for both DRY RUN variants and the default mode. omni has no statement-level +// SQL deparse (the deparse package handles only expressions and SELECT), so +// parse→NodeToString→inspect is the round-trip analog for BatchStmt. +func TestBatchSerialize(t *testing.T) { + cases := []struct { + sql string + wantContains string + }{ + {"BATCH ON id LIMIT 1000 DRY RUN DELETE FROM t WHERE 1=1", ":dry_run split_dml"}, + {"BATCH ON id LIMIT 1000 DRY RUN QUERY UPDATE t SET x=1 WHERE 1=1", ":dry_run query"}, + {"BATCH ON id LIMIT 5000 DELETE FROM t WHERE 1=1", ":limit 5000"}, + {"BATCH ON id LIMIT 100 REPLACE INTO t2 SELECT * FROM t", ":replace true"}, + } + for _, c := range cases { + got := ast.NodeToString(ParseAndCheck(t, c.sql).Items[0]) + if !strings.Contains(got, c.wantContains) { + t.Errorf("NodeToString(%q) = %q, want substring %q", c.sql, got, c.wantContains) + } + } + + // Default (no DRY RUN) mode must omit the :dry_run field entirely. + noDryRun := ast.NodeToString(ParseAndCheck(t, "BATCH ON id LIMIT 5000 DELETE FROM t WHERE 1=1").Items[0]) + if strings.Contains(noDryRun, ":dry_run") { + t.Errorf("default mode should omit :dry_run, got %q", noDryRun) + } +} + +// tidbRejectedSyntax reports whether a TiDB execution error is a parse error +// (ER_PARSE_ERROR, 1064). Any other error (e.g. 1146 table-not-found) means +// TiDB parsed the statement, so the syntax is accepted. +func tidbRejectedSyntax(err error) bool { + if err == nil { + return false + } + var myErr *gomysql.MySQLError + if errors.As(err, &myErr) { + return myErr.Number == 1064 + } + return false +} + +// TestBatchTiDBOracle lockstep-verifies every batchCase against real TiDB +// v8.5.5: our parser's accept/reject must match TiDB's syntax acceptance. +// Skips under -short (CI) and when the container is unavailable. +func TestBatchTiDBOracle(t *testing.T) { + tc := startTiDB(t) + + tc.db.ExecContext(tc.ctx, "CREATE DATABASE IF NOT EXISTS omni_batch_test") + tc.db.ExecContext(tc.ctx, "USE omni_batch_test") + defer tc.db.ExecContext(tc.ctx, "DROP DATABASE IF EXISTS omni_batch_test") + + for _, c := range batchCases { + t.Run(c.name, func(t *testing.T) { + _, omniErr := Parse(c.sql) + omniAccepts := omniErr == nil + + tidbAccepts := !tidbRejectedSyntax(tc.canExecute(c.sql)) + + if omniAccepts != tidbAccepts { + t.Errorf("MISMATCH %q: omni accepts=%v (err=%v), TiDB accepts=%v", + c.sql, omniAccepts, omniErr, tidbAccepts) + } + if omniAccepts != c.wantAccept { + t.Errorf("omni %q: accepts=%v, want %v (err=%v)", c.sql, omniAccepts, c.wantAccept, omniErr) + } + }) + } +} diff --git a/tidb/parser/lexer.go b/tidb/parser/lexer.go index 445bbedd..879f2596 100644 --- a/tidb/parser/lexer.go +++ b/tidb/parser/lexer.go @@ -911,6 +911,13 @@ const ( // replication.go check the token directly instead of eqFold on // the identifier string. kwMASTER_LOG_FILE + + // TiDB non-transactional DML (BATCH ... LIMIT n DELETE/UPDATE/INSERT). + // BATCH/DRY/RUN are TiDBKeyword upstream — all unreserved + // (parser.y:7355, 7397-7398). QUERY (kwQUERY) already exists. + kwBATCH + kwDRY + kwRUN ) // keywords maps lowercase keyword strings to their token types. @@ -1751,6 +1758,11 @@ var keywords = map[string]int{ // Legacy MySQL replication alias (superseded by SOURCE_LOG_FILE). "master_log_file": kwMASTER_LOG_FILE, + + // TiDB non-transactional DML (BATCH ... LIMIT n DELETE/UPDATE/INSERT). + "batch": kwBATCH, + "dry": kwDRY, + "run": kwRUN, } // Token represents a lexical token. diff --git a/tidb/parser/name.go b/tidb/parser/name.go index c56054e2..536b7dad 100644 --- a/tidb/parser/name.go +++ b/tidb/parser/name.go @@ -507,6 +507,13 @@ var keywordCategories = map[int]keywordCategory{ // MASTER_LOG_FILE — legacy MySQL replication alias, unreserved. kwMASTER_LOG_FILE: kwCatUnambiguous, + + // TiDB non-transactional DML keywords (BATCH = parser.y:7355; + // DRY/RUN = parser.y:7397-7398) — all unreserved upstream, usable as + // identifiers/labels/roles/lvalues. + kwBATCH: kwCatUnambiguous, + kwDRY: kwCatUnambiguous, + kwRUN: kwCatUnambiguous, } // isReserved returns true if the token type is a reserved keyword that cannot diff --git a/tidb/parser/stmt.go b/tidb/parser/stmt.go index 26155ca6..c06063a9 100644 --- a/tidb/parser/stmt.go +++ b/tidb/parser/stmt.go @@ -12,7 +12,7 @@ func (p *Parser) parseStmt() (nodes.Node, error) { if p.collectMode() { // Add all top-level statement-starting keywords as candidates. stmtTokens := []int{ - kwSELECT, kwINSERT, kwUPDATE, kwDELETE, + kwSELECT, kwINSERT, kwUPDATE, kwDELETE, kwBATCH, kwCREATE, kwALTER, kwDROP, kwTRUNCATE, kwRENAME, kwWITH, kwTABLE, kwVALUES, kwREPLACE, kwSET, kwSHOW, kwUSE, @@ -56,6 +56,9 @@ func (p *Parser) parseStmt() (nodes.Node, error) { case kwDELETE: return p.parseDeleteStmt() + case kwBATCH: + return p.parseBatchStmt() + case kwCREATE: return p.parseCreateDispatch()