diff --git a/tidb/completion/completion_test.go b/tidb/completion/completion_test.go index 342b2bce..5fa9296a 100644 --- a/tidb/completion/completion_test.go +++ b/tidb/completion/completion_test.go @@ -43,6 +43,60 @@ func hasDuplicates(candidates []Candidate) bool { return false } +func TestComplete_QualifiedColumnScopedToQualifier(t *testing.T) { + cat := catalog.New() + if _, err := cat.Exec("CREATE DATABASE `db`; USE `db`; CREATE TABLE `t1` (`a1` int, `a2` int); CREATE TABLE `t2` (`b1` int, `b2` int);", &catalog.ExecOptions{ContinueOnError: true}); err != nil { + t.Fatal(err) + } + + // `a.` (alias for t1) must offer only t1's columns. + got := Complete("SELECT a. FROM t1 AS a JOIN t2 AS b", len("SELECT a."), cat) + if !containsCandidate(got, "a1", CandidateColumn) || !containsCandidate(got, "a2", CandidateColumn) { + t.Errorf("a. should offer t1 columns a1,a2; got %v", got) + } + if containsCandidate(got, "b1", CandidateColumn) || containsCandidate(got, "b2", CandidateColumn) { + t.Errorf("a. must NOT offer t2 columns b1,b2; got %v", got) + } + + // `b.` (alias for t2) must offer only t2's columns. + got = Complete("SELECT b. FROM t1 AS a JOIN t2 AS b", len("SELECT b."), cat) + if !containsCandidate(got, "b1", CandidateColumn) || !containsCandidate(got, "b2", CandidateColumn) { + t.Errorf("b. should offer t2 columns b1,b2; got %v", got) + } + if containsCandidate(got, "a1", CandidateColumn) || containsCandidate(got, "a2", CandidateColumn) { + t.Errorf("b. must NOT offer t1 columns a1,a2; got %v", got) + } + + // Qualifying by table name in a JOIN restricts to that table. + got = Complete("SELECT t1. FROM t1 JOIN t2", len("SELECT t1."), cat) + if !containsCandidate(got, "a1", CandidateColumn) { + t.Errorf("t1. should offer t1 columns; got %v", got) + } + if containsCandidate(got, "b1", CandidateColumn) { + t.Errorf("t1. must NOT offer t2 columns; got %v", got) + } + + // An unqualified column reference still offers all in-scope columns. + got = Complete("SELECT FROM t1 JOIN t2", len("SELECT "), cat) + if !containsCandidate(got, "a1", CandidateColumn) || !containsCandidate(got, "b1", CandidateColumn) { + t.Errorf("unqualified column ref should offer all in-scope columns; got %v", got) + } + + // A fully-qualified db.table. must scope by database too, when the same table + // name exists in more than one database in scope. + xdb := catalog.New() + if _, err := xdb.Exec("CREATE DATABASE `db1`; CREATE TABLE `db1`.`t` (`a1` int); CREATE DATABASE `db2`; CREATE TABLE `db2`.`t` (`b1` int); USE `db1`;", &catalog.ExecOptions{ContinueOnError: true}); err != nil { + t.Fatal(err) + } + got = Complete("SELECT db1.t. FROM db1.t JOIN db2.t", len("SELECT db1.t."), xdb) + if !containsCandidate(got, "a1", CandidateColumn) { + t.Errorf("db1.t. should offer db1.t column a1; got %v", got) + } + if containsCandidate(got, "b1", CandidateColumn) { + t.Errorf("db1.t. must NOT offer db2.t column b1; got %v", got) + } +} + func TestComplete_2_1_CompleteReturnsSlice(t *testing.T) { // Scenario: Complete(sql, cursorOffset, catalog) returns []Candidate cat := catalog.New() diff --git a/tidb/completion/resolve.go b/tidb/completion/resolve.go index 17181ddf..b063395c 100644 --- a/tidb/completion/resolve.go +++ b/tidb/completion/resolve.go @@ -1,6 +1,8 @@ package completion import ( + "strings" + "github.com/bytebase/omni/tidb/catalog" "github.com/bytebase/omni/tidb/parser" ) @@ -186,9 +188,30 @@ func resolveColumnRefScoped(cat *catalog.Catalog, sql string, cursorOffset int) return resolveColumnRef(cat) } + // A qualified column reference (e.g. `a.col`, `t1.col`, or `db1.t1.col`) is + // restricted to the table whose alias or name matches the qualifier before the + // dot — and, when the qualifier names a database, whose database matches too + // (so `db1.t.` is not satisfied by `db2.t`). An unqualified reference uses + // every in-scope table. + qDB, qName := columnQualifier(sql, cursorOffset) + seen := make(map[string]bool) var result []Candidate for _, ref := range refs { + if qName != "" { + if !strings.EqualFold(ref.Alias, qName) && !strings.EqualFold(ref.Table, qName) { + continue + } + if qDB != "" { + refDB := ref.Database + if refDB == "" { + refDB = cat.CurrentDatabase() + } + if !strings.EqualFold(refDB, qDB) { + continue + } + } + } // Resolve table in the appropriate database. targetDB := db if ref.Database != "" { @@ -451,3 +474,71 @@ func currentDB(cat *catalog.Catalog) *catalog.Database { } return cat.GetDatabase(name) } + +// identBeforeDot scans left from pos expecting (optional whitespace) '.' +// (optional whitespace) , returning the identifier text, the index +// where it starts, and whether the pattern matched. The identifier may be bare +// or backtick-quoted. +func identBeforeDot(sql string, pos int) (ident string, start int, ok bool) { + if pos > len(sql) { + pos = len(sql) + } + i := pos + for i > 0 && isSpaceByte(sql[i-1]) { + i-- + } + if i == 0 || sql[i-1] != '.' { + return "", pos, false + } + i-- // consume the dot + for i > 0 && isSpaceByte(sql[i-1]) { + i-- + } + if i == 0 { + return "", pos, false + } + if sql[i-1] == '`' { + end := i - 1 + j := end - 1 + for j >= 0 && sql[j] != '`' { + j-- + } + if j < 0 { + return "", pos, false + } + return sql[j+1 : end], j, true + } + end := i + for i > 0 && isIdentByte(sql[i-1]) { + i-- + } + if i == end { + return "", pos, false + } + return sql[i:end], i, true +} + +// columnQualifier returns the optional database part and the table/alias part of +// the qualifier preceding the dot at the cursor: `db1.t.` -> ("db1","t"), +// `t.`/`a.` -> ("","t")/("","a"), and ("","") when the reference is unqualified. +// The caller passes the collect offset, which sits at the start of the partial +// column name. +func columnQualifier(sql string, offset int) (db, name string) { + n, start, ok := identBeforeDot(sql, offset) + if !ok { + return "", "" + } + if d, _, ok2 := identBeforeDot(sql, start); ok2 { + db = d + } + return db, n +} + +func isSpaceByte(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isIdentByte(b byte) bool { + return b == '_' || b == '$' || + (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +}