Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tidb/ast/outfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ func writeNode(sb *strings.Builder, node Node) {
writeUpdateStmt(sb, n)
case *DeleteStmt:
writeDeleteStmt(sb, n)
case *BatchStmt:
writeBatchStmt(sb, n)
case *CreateTableStmt:
writeCreateTableStmt(sb, n)
case *AlterTableStmt:
Expand Down Expand Up @@ -787,6 +789,28 @@ func writeDeleteStmt(sb *strings.Builder, n *DeleteStmt) {
sb.WriteString("}")
}

func writeBatchStmt(sb *strings.Builder, n *BatchStmt) {
sb.WriteString("{BATCH")
fmt.Fprintf(sb, " :loc %d", n.Loc.Start)
if n.ShardColumn != nil {
sb.WriteString(" :shard_column ")
writeNode(sb, n.ShardColumn)
}
fmt.Fprintf(sb, " :limit %d", n.Limit)
switch n.DryRun {
case BatchDryRunQuery:
sb.WriteString(" :dry_run query")
case BatchDryRunSplitDML:
sb.WriteString(" :dry_run split_dml")
case BatchDryRunNone:
}
if n.DML != nil {
sb.WriteString(" :dml ")
writeNode(sb, n.DML)
}
sb.WriteString("}")
}

func writeCreateTableStmt(sb *strings.Builder, n *CreateTableStmt) {
sb.WriteString("{CREATE_TABLE")
fmt.Fprintf(sb, " :loc %d", n.Loc.Start)
Expand Down
29 changes: 29 additions & 0 deletions tidb/ast/parsenodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,35 @@ type DeleteStmt struct {
func (s *DeleteStmt) nodeTag() {}
func (s *DeleteStmt) stmtNode() {}

// BatchDryRunMode enumerates TiDB BATCH non-transactional DML dry-run modes.
// Numeric values mirror pingcap ast (dml.go: NoDryRun=0, DryRunQuery=1,
// DryRunSplitDml=2). DRY RUN → SplitDML; DRY RUN QUERY → Query.
type BatchDryRunMode int

const (
BatchDryRunNone BatchDryRunMode = iota // 0: execute the split DMLs
BatchDryRunQuery // 1: DRY RUN QUERY — show the SELECT that splits
BatchDryRunSplitDML // 2: DRY RUN — show the split DML jobs
)

// BatchStmt represents a TiDB non-transactional DML statement:
//
// BATCH [ON <col>] LIMIT <n> [DRY RUN [QUERY]] {DELETE | UPDATE | INSERT | REPLACE}
//
// Ref: pingcap parser.y NonTransactionalDMLStmt (production "BATCH"
// OptionalShardColumn "LIMIT" NUM DryRunOptions ShardableStmt). REPLACE is
// surfaced as *InsertStmt with IsReplace=true (omni unifies INSERT/REPLACE).
type BatchStmt struct {
Loc Loc
ShardColumn *ColumnRef // ON <col>; nil → TiDB auto-selects the handle column
Limit uint64 // LIMIT <n>; grammar is NUM (integer literal), not an expression
DryRun BatchDryRunMode
DML StmtNode // *DeleteStmt | *UpdateStmt | *InsertStmt
}

func (s *BatchStmt) nodeTag() {}
func (s *BatchStmt) stmtNode() {}

// CreateTableStmt represents a CREATE TABLE statement.
type CreateTableStmt struct {
Loc Loc
Expand Down
5 changes: 5 additions & 0 deletions tidb/ast/walk_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tidb/completion/completion.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package completion provides parser-native C3-style SQL completion for MySQL.
// Package completion provides parser-native C3-style SQL completion for TiDB.
package completion

import (
Expand Down
101 changes: 101 additions & 0 deletions tidb/parser/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package parser

import (
"strconv"

nodes "github.com/bytebase/omni/tidb/ast"
)

// parseBatchStmt parses a TiDB non-transactional DML statement:
//
// BATCH [ON <col>] LIMIT <n> [DRY RUN [QUERY]] {DELETE | UPDATE | INSERT | REPLACE}
//
// Grammar ref: pingcap parser.y NonTransactionalDMLStmt —
// "BATCH" OptionalShardColumn "LIMIT" NUM DryRunOptions ShardableStmt.
// LIMIT takes a bare integer (NUM), not an expression.
func (p *Parser) parseBatchStmt() (*nodes.BatchStmt, error) {
start := p.pos()
p.advance() // consume BATCH

stmt := &nodes.BatchStmt{Loc: nodes.Loc{Start: start}}

// Completion: after BATCH, offer ON / LIMIT.
p.checkCursor()
if p.collectMode() {
p.addTokenCandidate(kwON)
p.addTokenCandidate(kwLIMIT)
return nil, &ParseError{Message: "collecting"}
}

// OptionalShardColumn: [ON ColumnName]. Upstream is a plain (optionally
// qualified) column name — parseColumnRef also accepts wildcard forms
// (t.* / db.t.*), which TiDB rejects here, so disallow them.
if _, ok := p.match(kwON); ok {
col, err := p.parseColumnRef()
if err != nil {
return nil, err
}
if col.Star {
return nil, &ParseError{Message: "BATCH shard column must be a column name, not a wildcard", Position: col.Loc.Start}
}
stmt.ShardColumn = col
}

// LIMIT NUM (mandatory). Grammar is NUM, so reject expressions, placeholders,
// and signed/parenthesized forms by requiring an integer-literal token.
if _, err := p.expect(kwLIMIT); err != nil {
return nil, err
}
if p.cur.Type != tokICONST {
return nil, p.syntaxErrorAtCur()
}
numTok := p.advance()
limit, err := strconv.ParseUint(numTok.Str, 10, 64)
if err != nil {
return nil, &ParseError{Message: "invalid BATCH LIMIT value", Position: numTok.Loc}
}
stmt.Limit = limit

// DryRunOptions: [] | DRY RUN | DRY RUN QUERY
if _, ok := p.match(kwDRY); ok {
if _, err := p.expect(kwRUN); err != nil {
return nil, err
}
if _, ok := p.match(kwQUERY); ok {
stmt.DryRun = nodes.BatchDryRunQuery
} else {
stmt.DryRun = nodes.BatchDryRunSplitDML
}
}

// Completion: before the DML, offer DRY plus the shardable statement starters.
p.checkCursor()
if p.collectMode() {
for _, tok := range []int{kwDRY, kwDELETE, kwUPDATE, kwINSERT, kwREPLACE} {
p.addTokenCandidate(tok)
}
return nil, &ParseError{Message: "collecting"}
}

// ShardableStmt: DELETE | UPDATE | INSERT | REPLACE.
var dml nodes.StmtNode
switch p.cur.Type {
case kwDELETE:
dml, err = p.parseDeleteStmt()
case kwUPDATE:
dml, err = p.parseUpdateStmt()
case kwINSERT:
dml, err = p.parseInsertStmt()
case kwREPLACE:
dml, err = p.parseReplaceStmt()
default:
return nil, p.syntaxErrorAtCur()
}
if err != nil {
return nil, err
}
stmt.DML = dml

stmt.Loc.End = p.pos()
return stmt, nil
}
Loading
Loading