-
Notifications
You must be signed in to change notification settings - Fork 867
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ability to use ':' in named args #2178
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,6 +107,19 @@ func rawState(l *sqlLexer) stateFn { | |
return singleQuoteState | ||
case '"': | ||
return doubleQuoteState | ||
case ':': | ||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) | ||
prevRune := rune(0) | ||
if l.pos > 1 { | ||
prevRune, _ = utf8.DecodeRuneInString(l.src[l.pos-2:]) | ||
} | ||
if nextRune != ':' && prevRune != ':' && (isLetter(nextRune) || nextRune == '_') { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems to me that we can omit the first check nextRune != ':' as there will be a more specific check next, but in addition I would like to say that this check is much easier than the next ones, and will cut off cast types a little faster.... But type casts are not done so often to leave this prevenient check. What do you think, should I remove |
||
if l.pos-l.start > 0 { | ||
l.parts = append(l.parts, l.src[l.start:l.pos-width]) | ||
} | ||
l.start = l.pos | ||
return namedArgState | ||
Comment on lines
+117
to
+121
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copy from '@' case |
||
} | ||
case '@': | ||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) | ||
if isLetter(nextRune) || nextRune == '_' { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -160,3 +160,155 @@ func TestStrictNamedArgsRewriteQuery(t *testing.T) { | |
} | ||
} | ||
} | ||
|
||
func TestNamedArgsRewriteQuery2(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also added a tests that should confirm the clarity of my implementation |
||
t.Parallel() | ||
|
||
for i, tt := range []struct { | ||
sql string | ||
args []any | ||
namedArgs pgx.NamedArgs | ||
expectedSQL string | ||
expectedArgs []any | ||
}{ | ||
{ | ||
sql: "select * from users where id = :id", | ||
namedArgs: pgx.NamedArgs{"id": int32(42)}, | ||
expectedSQL: "select * from users where id = $1", | ||
expectedArgs: []any{int32(42)}, | ||
}, | ||
{ | ||
sql: "select * from t where foo < :abc and baz = :def and bar < :abc", | ||
namedArgs: pgx.NamedArgs{"abc": int32(42), "def": int32(1)}, | ||
expectedSQL: "select * from t where foo < $1 and baz = $2 and bar < $1", | ||
expectedArgs: []any{int32(42), int32(1)}, | ||
}, | ||
{ | ||
sql: "select :a::int, :b::text", | ||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, | ||
expectedSQL: "select $1::int, $2::text", | ||
expectedArgs: []any{int32(42), "foo"}, | ||
}, | ||
{ | ||
sql: "select :Abc::int, :b_4::text, :_c::int", | ||
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)}, | ||
expectedSQL: "select $1::int, $2::text, $3::int", | ||
expectedArgs: []any{int32(42), "foo", int32(1)}, | ||
}, | ||
{ | ||
sql: "at end :", | ||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, | ||
expectedSQL: "at end :", | ||
expectedArgs: []any{}, | ||
}, | ||
{ | ||
sql: "ignores without valid character after : foo bar", | ||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, | ||
expectedSQL: "ignores without valid character after : foo bar", | ||
expectedArgs: []any{}, | ||
}, | ||
{ | ||
sql: "name cannot start with number :1 foo bar", | ||
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, | ||
expectedSQL: "name cannot start with number :1 foo bar", | ||
expectedArgs: []any{}, | ||
}, | ||
{ | ||
sql: `select *, ':foo' as ":bar" from users where id = :id`, | ||
namedArgs: pgx.NamedArgs{"id": int32(42)}, | ||
expectedSQL: `select *, ':foo' as ":bar" from users where id = $1`, | ||
expectedArgs: []any{int32(42)}, | ||
}, | ||
{ | ||
sql: `select * -- :foo | ||
from users -- :single line comments | ||
where id = :id;`, | ||
namedArgs: pgx.NamedArgs{"id": int32(42)}, | ||
expectedSQL: `select * -- :foo | ||
from users -- :single line comments | ||
where id = $1;`, | ||
expectedArgs: []any{int32(42)}, | ||
}, | ||
{ | ||
sql: `select * /* :multi line | ||
:comment | ||
*/ | ||
/* /* with :nesting */ */ | ||
from users | ||
where id = :id;`, | ||
namedArgs: pgx.NamedArgs{"id": int32(42)}, | ||
expectedSQL: `select * /* :multi line | ||
:comment | ||
*/ | ||
/* /* with :nesting */ */ | ||
from users | ||
where id = $1;`, | ||
expectedArgs: []any{int32(42)}, | ||
}, | ||
{ | ||
sql: "extra provided argument", | ||
namedArgs: pgx.NamedArgs{"extra": int32(1)}, | ||
expectedSQL: "extra provided argument", | ||
expectedArgs: []any{}, | ||
}, | ||
{ | ||
sql: ":missing argument", | ||
namedArgs: pgx.NamedArgs{}, | ||
expectedSQL: "$1 argument", | ||
expectedArgs: []any{nil}, | ||
}, | ||
|
||
// test comments and quotes | ||
} { | ||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) | ||
require.NoError(t, err) | ||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i) | ||
assert.Equalf(t, tt.expectedArgs, args, "%d", i) | ||
} | ||
} | ||
|
||
func TestStrictNamedArgsRewriteQuery2(t *testing.T) { | ||
t.Parallel() | ||
|
||
for i, tt := range []struct { | ||
sql string | ||
namedArgs pgx.StrictNamedArgs | ||
expectedSQL string | ||
expectedArgs []any | ||
isExpectedError bool | ||
}{ | ||
{ | ||
sql: "no arguments", | ||
namedArgs: pgx.StrictNamedArgs{}, | ||
expectedSQL: "no arguments", | ||
expectedArgs: []any{}, | ||
isExpectedError: false, | ||
}, | ||
{ | ||
sql: ":all :matches", | ||
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, | ||
expectedSQL: "$1 $2", | ||
expectedArgs: []any{int32(1), int32(2)}, | ||
isExpectedError: false, | ||
}, | ||
{ | ||
sql: "extra provided argument", | ||
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, | ||
isExpectedError: true, | ||
}, | ||
{ | ||
sql: ":missing argument", | ||
namedArgs: pgx.StrictNamedArgs{}, | ||
isExpectedError: true, | ||
}, | ||
} { | ||
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil) | ||
if tt.isExpectedError { | ||
assert.Errorf(t, err, "%d", i) | ||
} else { | ||
require.NoErrorf(t, err, "%d", i) | ||
assert.Equalf(t, tt.expectedSQL, sql, "%d", i) | ||
assert.Equalf(t, tt.expectedArgs, args, "%d", i) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this check to avoid panic when : is at the beginning of a line.