Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pgdialect): postgres syntax errors for pointers and slices #877 #1111

Merged
merged 3 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions dialect/pgdialect/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,31 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
if typ == timeType {
return appendTimeElemValue
}

switch typ.Kind() {
case reflect.String:
return appendStringElemValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesElemValue
}
case reflect.Ptr:
return schema.PtrAppender(d.arrayElemAppender(typ.Elem()))
}
return schema.Appender(d, typ)
}

func appendTimeElemValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ts := v.Convert(timeType).Interface().(time.Time)

b = append(b, '"')
b = appendTime(b, ts)
return append(b, '"')
}

func appendStringElemValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return appendStringElem(b, v.String())
}
Expand Down
50 changes: 38 additions & 12 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,23 @@ type arrayParser struct {

elem []byte
err error

isJson bool
}

func newArrayParser(b []byte) *arrayParser {
p := new(arrayParser)

if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
if b[0] == 'n' {
p.p.Reset(nil)
return p
}

if len(b) < 2 || (b[0] != '{' && b[0] != '[') || (b[len(b)-1] != '}' && b[len(b)-1] != ']') {
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b)
return p
}
p.isJson = b[0] == '['

p.p.Reset(b[1 : len(b)-1])
return p
Expand Down Expand Up @@ -51,7 +59,7 @@ func (p *arrayParser) readNext() error {
}

switch ch {
case '}':
case '}', ']':
return io.EOF
case '"':
b, err := p.p.ReadSubstring(ch)
Expand All @@ -78,16 +86,34 @@ func (p *arrayParser) readNext() error {
p.elem = rng
return nil
default:
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
if ch == '{' && p.isJson {
json, err := p.p.ReadJSON()
if err != nil {
return err
}

for {
if p.p.Peek() == ',' || p.p.Peek() == ' ' {
p.p.Advance()
} else {
break
}
}

p.elem = json
return nil
} else {
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}

p.elem = lit
return nil
}
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/array_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestArrayParser(t *testing.T) {
{`{"1","2"}`, []string{"1", "2"}},
{`{"{1}","{2}"}`, []string{"{1}", "{2}"}},
{`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}},

{`[]`, []string{}},
{`[{"'\"[]"}]`, []string{`{"'\"[]"}`}},
{`[{"id": 1}, {"id":2, "name":"bob"}]`, []string{"{\"id\": 1}", "{\"id\":2, \"name\":\"bob\"}"}},
}

for i, test := range tests {
Expand Down
54 changes: 54 additions & 0 deletions dialect/pgdialect/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package pgdialect

import (
"testing"

"github.com/uptrace/bun/schema"
)

func ptr[T any](v T) *T {
return &v
}

func TestArrayAppend(t *testing.T) {
tcases := []struct {
input interface{}
out string
}{
{
input: []byte{1, 2},
out: `'{1,2}'`,
},
{
input: []*byte{ptr(byte(1)), ptr(byte(2))},
out: `'{1,2}'`,
},
{
input: []int{1, 2},
out: `'{1,2}'`,
},
{
input: []*int{ptr(1), ptr(2)},
out: `'{1,2}'`,
},
{
input: []string{"foo", "bar"},
out: `'{"foo","bar"}'`,
},
{
input: []*string{ptr("foo"), ptr("bar")},
out: `'{"foo","bar"}'`,
},
}

for _, tcase := range tcases {
out, err := Array(tcase.input).AppendQuery(schema.NewFormatter(New()), []byte{})
if err != nil {
t.Fatal(err)
}

if string(out) != tcase.out {
t.Errorf("expected output to be %s, was %s", tcase.out, string(out))
}
}
}
36 changes: 36 additions & 0 deletions dialect/pgdialect/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,39 @@ func (p *pgparser) ReadRange(ch byte) ([]byte, error) {

return p.buf, nil
}

func (p *pgparser) ReadJSON() ([]byte, error) {
p.Unread()

c, err := p.ReadByte()
if err != nil {
return nil, err
}

p.buf = p.buf[:0]

depth := 0
for {
switch c {
case '{':
depth++
case '}':
depth--
}

p.buf = append(p.buf, c)

if depth == 0 {
break
}

next, err := p.ReadByte()
if err != nil {
return nil, err
}

c = next
}

return p.buf, nil
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func fieldSQLType(field *schema.Field) string {
}

func sqlType(typ reflect.Type) string {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}

switch typ {
case nullStringType: // typ.Kind() == reflect.Struct, test for exact match
return sqltype.VarChar
Expand Down
101 changes: 101 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/schema"
)

func TestPostgresArray(t *testing.T) {
Expand All @@ -25,16 +26,20 @@ func TestPostgresArray(t *testing.T) {
Array1 []string `bun:",array"`
Array2 *[]string `bun:",array"`
Array3 *[]string `bun:",array"`
Array4 []*string `bun:",array"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

str1 := "hello"
str2 := "world"
model1 := &Model{
ID: 123,
Array1: []string{"one", "two", "three"},
Array2: &[]string{"hello", "world"},
Array4: []*string{&str1, &str2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)
Expand All @@ -56,6 +61,12 @@ func TestPostgresArray(t *testing.T) {
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Nil(t, strs)

err = db.NewSelect().Model((*Model)(nil)).
Column("array4").
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Equal(t, []string{"hello", "world"}, strs)
}

func TestPostgresArrayQuote(t *testing.T) {
Expand Down Expand Up @@ -456,6 +467,7 @@ func TestPostgresTimeArray(t *testing.T) {
Array1 []time.Time `bun:",array"`
Array2 *[]time.Time `bun:",array"`
Array3 *[]time.Time `bun:",array"`
Array4 []*time.Time `bun:",array"`
}

db := pg(t)
Expand All @@ -471,6 +483,7 @@ func TestPostgresTimeArray(t *testing.T) {
ID: 123,
Array1: []time.Time{time1, time2, time3},
Array2: &[]time.Time{time1, time2, time3},
Array4: []*time.Time{&time1, &time2, &time3},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -498,6 +511,12 @@ func TestPostgresTimeArray(t *testing.T) {
Scan(ctx, pgdialect.Array(&times))
require.NoError(t, err)
require.Nil(t, times)

err = db.NewSelect().Model((*Model)(nil)).
Column("array4").
Scan(ctx, pgdialect.Array(&times))
require.NoError(t, err)
require.Equal(t, 3, len(model1.Array4))
}

func TestPostgresOnConflictDoUpdate(t *testing.T) {
Expand Down Expand Up @@ -877,3 +896,85 @@ func TestPostgresMultiRange(t *testing.T) {
err = db.NewSelect().Model(out).Scan(ctx)
require.NoError(t, err)
}

type UserID struct {
ID string
}

func (u UserID) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
v := []byte(`"` + u.ID + `"`)
return append(b, v...), nil
}

var _ schema.QueryAppender = (*UserID)(nil)

func (r *UserID) Scan(anySrc any) (err error) {
src, ok := anySrc.([]byte)
if !ok {
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
}

r.ID = string(src)
return nil
}

var _ sql.Scanner = (*UserID)(nil)

func TestPostgresJSONB(t *testing.T) {
type Item struct {
Name string `json:"name"`
}
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Item Item `bun:",type:jsonb"`
ItemPtr *Item `bun:",type:jsonb"`
Items []Item `bun:",type:jsonb"`
ItemsP []*Item `bun:",type:jsonb"`
ItemsNull []*Item `bun:",type:jsonb"`
TextItemA []UserID `bun:"type:text[]"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

item1 := Item{Name: "one"}
item2 := Item{Name: "two"}
uid1 := UserID{ID: "1"}
uid2 := UserID{ID: "2"}
model1 := &Model{
ID: 123,
Item: item1,
ItemPtr: &item2,
Items: []Item{item1, item2},
ItemsP: []*Item{&item1, &item2},
ItemsNull: nil,
TextItemA: []UserID{uid1, uid2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)

model2 := new(Model)
err = db.NewSelect().Model(model2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, model1, model2)

var items []Item
err = db.NewSelect().Model((*Model)(nil)).
Column("items").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

err = db.NewSelect().Model((*Model)(nil)).
Column("itemsp").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

err = db.NewSelect().Model((*Model)(nil)).
Column("items_null").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{}, items)
}
Loading