Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tidb/completion/completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,60 @@ func hasDuplicates(candidates []Candidate) bool {
return false
}

func TestComplete_QualifiedColumnScopedToQualifier(t *testing.T) {
cat := catalog.New()
if _, err := cat.Exec("CREATE DATABASE `db`; USE `db`; CREATE TABLE `t1` (`a1` int, `a2` int); CREATE TABLE `t2` (`b1` int, `b2` int);", &catalog.ExecOptions{ContinueOnError: true}); err != nil {
t.Fatal(err)
}

// `a.` (alias for t1) must offer only t1's columns.
got := Complete("SELECT a. FROM t1 AS a JOIN t2 AS b", len("SELECT a."), cat)
if !containsCandidate(got, "a1", CandidateColumn) || !containsCandidate(got, "a2", CandidateColumn) {
t.Errorf("a. should offer t1 columns a1,a2; got %v", got)
}
if containsCandidate(got, "b1", CandidateColumn) || containsCandidate(got, "b2", CandidateColumn) {
t.Errorf("a. must NOT offer t2 columns b1,b2; got %v", got)
}

// `b.` (alias for t2) must offer only t2's columns.
got = Complete("SELECT b. FROM t1 AS a JOIN t2 AS b", len("SELECT b."), cat)
if !containsCandidate(got, "b1", CandidateColumn) || !containsCandidate(got, "b2", CandidateColumn) {
t.Errorf("b. should offer t2 columns b1,b2; got %v", got)
}
if containsCandidate(got, "a1", CandidateColumn) || containsCandidate(got, "a2", CandidateColumn) {
t.Errorf("b. must NOT offer t1 columns a1,a2; got %v", got)
}

// Qualifying by table name in a JOIN restricts to that table.
got = Complete("SELECT t1. FROM t1 JOIN t2", len("SELECT t1."), cat)
if !containsCandidate(got, "a1", CandidateColumn) {
t.Errorf("t1. should offer t1 columns; got %v", got)
}
if containsCandidate(got, "b1", CandidateColumn) {
t.Errorf("t1. must NOT offer t2 columns; got %v", got)
}

// An unqualified column reference still offers all in-scope columns.
got = Complete("SELECT FROM t1 JOIN t2", len("SELECT "), cat)
if !containsCandidate(got, "a1", CandidateColumn) || !containsCandidate(got, "b1", CandidateColumn) {
t.Errorf("unqualified column ref should offer all in-scope columns; got %v", got)
}

// A fully-qualified db.table. must scope by database too, when the same table
// name exists in more than one database in scope.
xdb := catalog.New()
if _, err := xdb.Exec("CREATE DATABASE `db1`; CREATE TABLE `db1`.`t` (`a1` int); CREATE DATABASE `db2`; CREATE TABLE `db2`.`t` (`b1` int); USE `db1`;", &catalog.ExecOptions{ContinueOnError: true}); err != nil {
t.Fatal(err)
}
got = Complete("SELECT db1.t. FROM db1.t JOIN db2.t", len("SELECT db1.t."), xdb)
if !containsCandidate(got, "a1", CandidateColumn) {
t.Errorf("db1.t. should offer db1.t column a1; got %v", got)
}
if containsCandidate(got, "b1", CandidateColumn) {
t.Errorf("db1.t. must NOT offer db2.t column b1; got %v", got)
}
}

func TestComplete_2_1_CompleteReturnsSlice(t *testing.T) {
// Scenario: Complete(sql, cursorOffset, catalog) returns []Candidate
cat := catalog.New()
Expand Down
91 changes: 91 additions & 0 deletions tidb/completion/resolve.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package completion

import (
"strings"

"github.com/bytebase/omni/tidb/catalog"
"github.com/bytebase/omni/tidb/parser"
)
Expand Down Expand Up @@ -186,9 +188,30 @@ func resolveColumnRefScoped(cat *catalog.Catalog, sql string, cursorOffset int)
return resolveColumnRef(cat)
}

// A qualified column reference (e.g. `a.col`, `t1.col`, or `db1.t1.col`) is
// restricted to the table whose alias or name matches the qualifier before the
// dot — and, when the qualifier names a database, whose database matches too
// (so `db1.t.` is not satisfied by `db2.t`). An unqualified reference uses
// every in-scope table.
qDB, qName := columnQualifier(sql, cursorOffset)

seen := make(map[string]bool)
var result []Candidate
for _, ref := range refs {
if qName != "" {
if !strings.EqualFold(ref.Alias, qName) && !strings.EqualFold(ref.Table, qName) {
continue
}
if qDB != "" {
refDB := ref.Database
if refDB == "" {
refDB = cat.CurrentDatabase()
}
if !strings.EqualFold(refDB, qDB) {
continue
}
}
}
// Resolve table in the appropriate database.
targetDB := db
if ref.Database != "" {
Expand Down Expand Up @@ -451,3 +474,71 @@ func currentDB(cat *catalog.Catalog) *catalog.Database {
}
return cat.GetDatabase(name)
}

// identBeforeDot scans left from pos expecting (optional whitespace) '.'
// (optional whitespace) <identifier>, returning the identifier text, the index
// where it starts, and whether the pattern matched. The identifier may be bare
// or backtick-quoted.
func identBeforeDot(sql string, pos int) (ident string, start int, ok bool) {
if pos > len(sql) {
pos = len(sql)
}
i := pos
for i > 0 && isSpaceByte(sql[i-1]) {
i--
}
if i == 0 || sql[i-1] != '.' {
return "", pos, false
}
i-- // consume the dot
for i > 0 && isSpaceByte(sql[i-1]) {
i--
}
if i == 0 {
return "", pos, false
}
if sql[i-1] == '`' {
end := i - 1
j := end - 1
for j >= 0 && sql[j] != '`' {
j--
}
if j < 0 {
return "", pos, false
}
return sql[j+1 : end], j, true
}
end := i
for i > 0 && isIdentByte(sql[i-1]) {
i--
}
if i == end {
return "", pos, false
}
return sql[i:end], i, true
}

// columnQualifier returns the optional database part and the table/alias part of
// the qualifier preceding the dot at the cursor: `db1.t.` -> ("db1","t"),
// `t.`/`a.` -> ("","t")/("","a"), and ("","") when the reference is unqualified.
// The caller passes the collect offset, which sits at the start of the partial
// column name.
func columnQualifier(sql string, offset int) (db, name string) {
n, start, ok := identBeforeDot(sql, offset)
if !ok {
return "", ""
}
if d, _, ok2 := identBeforeDot(sql, start); ok2 {
db = d
}
return db, n
}

func isSpaceByte(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}

func isIdentByte(b byte) bool {
return b == '_' || b == '$' ||
(b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
Loading