Fixes Postgres money and xml arrays in sync (#2965)
alishakawaguchi authored Nov 20, 2024
1 parent 71c21a4 commit f686279
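
The diff below registers text codecs for money, uuid, and xml (plus their array types) on the pgtype.Map used by the sync scanner, and falls back to looking the column type up by OID when the reported type name is numeric. A minimal, standalone sketch of that registration pattern for the money-array case follows; the array literal and variable names are illustrative, not taken from the repository.

package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	m := pgtype.NewMap()

	// money (OID 790) has no codec in the default map; treat it as text.
	money := &pgtype.Type{Name: "money", OID: 790, Codec: pgtype.TextCodec{}}
	m.RegisterType(money)

	// _money (OID 791) is the array type; wrap the element type in an ArrayCodec.
	m.RegisterType(&pgtype.Type{
		Name:  "_money",
		OID:   791,
		Codec: &pgtype.ArrayCodec{ElementType: money},
	})

	// Scan a text-format array value, the same shape the driver hands to a sql.Scanner.
	var out []string
	if err := m.Scan(791, pgtype.TextFormatCode, []byte(`{$1.00,"$2,500.00"}`), &out); err != nil {
		panic(err)
	}
	fmt.Println(out) // [$1.00 $2,500.00]
}
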
Showing 3 changed files with 244 additions and 145 deletions.
245 changes: 140 additions & 105 deletions internal/postgres/utils.go
@@ -5,9 +5,9 @@ import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"

"github.com/gofrs/uuid"
"github.com/jackc/pgx/v5/pgtype"
)

@@ -20,7 +20,71 @@ type PgxArray[T any] struct {
// properly handles scanning postgres arrays
func (a *PgxArray[T]) Scan(src any) error {
m := pgtype.NewMap()
pgt, ok := m.TypeForName(strings.ToLower(a.colDataType))
// Register money types
m.RegisterType(&pgtype.Type{
Name: "money",
OID: 790,
Codec: pgtype.TextCodec{},
})
m.RegisterType(&pgtype.Type{
Name: "_money",
OID: 791,
Codec: &pgtype.ArrayCodec{
ElementType: &pgtype.Type{
Name: "money",
OID: 790,
Codec: pgtype.TextCodec{},
},
},
})

// Register UUID types
m.RegisterType(&pgtype.Type{
Name: "uuid",
OID: 2950, // UUID type OID
Codec: pgtype.TextCodec{},
})

m.RegisterType(&pgtype.Type{
Name: "_uuid",
OID: 2951,
Codec: &pgtype.ArrayCodec{
ElementType: &pgtype.Type{
Name: "uuid",
OID: 2950,
Codec: pgtype.TextCodec{},
},
},
})

// Register XML type
m.RegisterType(&pgtype.Type{
Name: "xml",
OID: 142,
Codec: pgtype.TextCodec{},
})

m.RegisterType(&pgtype.Type{
Name: "_xml",
OID: 143,
Codec: &pgtype.ArrayCodec{
ElementType: &pgtype.Type{
Name: "xml",
OID: 142,
Codec: pgtype.TextCodec{},
},
},
})

// Try to get the type by OID first if colDataType is numeric
var pgt *pgtype.Type
var ok bool

if oid, err := strconv.Atoi(a.colDataType); err == nil {
pgt, ok = m.TypeForOID(uint32(oid)) //nolint:gosec
} else {
pgt, ok = m.TypeForName(strings.ToLower(a.colDataType))
}
if !ok {
return fmt.Errorf("cannot convert to sql.Scanner: cannot find registered type for %s", a.colDataType)
}
@@ -34,13 +98,49 @@ func (a *PgxArray[T]) Scan(src any) error {
case []byte:
bufSrc = src
default:
bufSrc = []byte(fmt.Sprint(bufSrc))
bufSrc = []byte(fmt.Sprint(src))
}
}

return m.Scan(pgt.OID, pgtype.TextFormatCode, bufSrc, v)
}

type NullableJSON struct {
json.RawMessage
Valid bool
}

// Nullable JSON scanner
func (n *NullableJSON) Scan(value any) error {
if value == nil {
n.RawMessage, n.Valid = nil, false
return nil
}

n.Valid = true
switch v := value.(type) {
case []byte:
n.RawMessage = json.RawMessage(v)
return nil
case string:
n.RawMessage = json.RawMessage(v)
return nil
default:
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", value, n.RawMessage)
}
}

func (n *NullableJSON) Unmarshal() (any, error) {
if !n.Valid {
return nil, nil
}
var js any
if err := json.Unmarshal(n.RawMessage, &js); err != nil {
return nil, err
}
return js, nil
}

func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) {
columnNames, err := rows.Columns()
if err != nil {
@@ -52,51 +152,51 @@ func SqlRowToPgTypesMap(rows *sql.Rows) (map[string]any, error) {
return nil, err
}

columnDbTypes := []string{}
for _, c := range cTypes {
columnDbTypes = append(columnDbTypes, c.DatabaseTypeName())
}

values := make([]any, len(columnNames))
valuesWrapped := make([]any, 0, len(columnNames))
scanTargets := make([]any, 0, len(columnNames))
for i := range values {
ctype := cTypes[i]
if IsPgArrayType(ctype.DatabaseTypeName()) {
// use custom array type scanner
values[i] = &PgxArray[any]{
colDataType: ctype.DatabaseTypeName(),
}
valuesWrapped = append(valuesWrapped, values[i])
} else {
valuesWrapped = append(valuesWrapped, &values[i])
dbTypeName := cTypes[i].DatabaseTypeName()
switch {
case isXmlDataType(dbTypeName):
values[i] = &sql.NullString{}
scanTargets = append(scanTargets, values[i])
case IsJsonPgDataType(dbTypeName):
values[i] = &NullableJSON{}
scanTargets = append(scanTargets, values[i])
case isPgxPgArrayType(dbTypeName):
values[i] = &PgxArray[any]{colDataType: dbTypeName}
scanTargets = append(scanTargets, values[i])
default:
scanTargets = append(scanTargets, &values[i])
}
}
if err := rows.Scan(valuesWrapped...); err != nil {
if err := rows.Scan(scanTargets...); err != nil {
return nil, err
}

jObj := parsePgRowValues(values, columnNames, columnDbTypes)
jObj := parsePgRowValues(values, columnNames)
return jObj, nil
}

func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[string]any {
func parsePgRowValues(values []any, columnNames []string) map[string]any {
jObj := map[string]any{}
for i, v := range values {
col := columnNames[i]
ctype := columnDbTypes[i]
switch t := v.(type) {
case []byte:
if IsJsonPgDataType(ctype) {
var js any
if err := json.Unmarshal(t, &js); err == nil {
jObj[col] = js
continue
}
} else if isBinaryDataType(ctype) {
jObj[col] = t
continue
case nil:
jObj[col] = t
case *sql.NullString:
var val any = nil
if t.Valid {
val = t.String
}
jObj[col] = val
case *NullableJSON:
js, err := t.Unmarshal()
if err != nil {
js = t
}
jObj[col] = string(t)
jObj[col] = js
case *PgxArray[any]:
jObj[col] = pgArrayToGoSlice(t)
default:
@@ -106,100 +206,35 @@ func parsePgRowValues(values []any, columnNames, columnDbTypes []string) map[string]any {
return jObj
}

func isBinaryDataType(colDataType string) bool {
return strings.EqualFold(colDataType, "bytea")
func isXmlDataType(colDataType string) bool {
return strings.EqualFold(colDataType, "xml")
}

func IsJsonPgDataType(dataType string) bool {
return strings.EqualFold(dataType, "json") || strings.EqualFold(dataType, "jsonb")
}

func isJsonArrayPgDataType(dataType string) bool {
return strings.EqualFold(dataType, "_json") || strings.EqualFold(dataType, "_jsonb")
}

func isPgUuidArray(colDataType string) bool {
return strings.EqualFold(colDataType, "_uuid")
}

func isPgXmlArray(colDataType string) bool {
return strings.EqualFold(colDataType, "_xml")
}

func IsPgArrayType(dbTypeName string) bool {
return strings.HasPrefix(dbTypeName, "_")
func isPgxPgArrayType(dbTypeName string) bool {
return strings.HasPrefix(dbTypeName, "_") || dbTypeName == "791"
}

func IsPgArrayColumnDataType(colDataType string) bool {
return strings.Contains(colDataType, "[]")
return strings.HasSuffix(colDataType, "[]")
}

func pgArrayToGoSlice(array *PgxArray[any]) any {
if array.Elements == nil {
return nil
}
goSlice := convertArrayToGoType(array)

dim := array.Dimensions()
if len(dim) > 1 {
dims := []int{}
for _, d := range dim {
dims = append(dims, int(d.Length))
}
return CreateMultiDimSlice(dims, goSlice)
}
return goSlice
}

func convertArrayToGoType(array *PgxArray[any]) []any {
if !isJsonArrayPgDataType(array.colDataType) {
if isPgUuidArray(array.colDataType) {
return convertBytesToUuidSlice(array.Elements)
}
if isPgXmlArray(array.colDataType) {
return convertBytesToStringSlice(array.Elements)
}
return array.Elements
}

var newArray []any
for _, e := range array.Elements {
jsonBits, ok := e.([]byte)
if !ok {
newArray = append(newArray, e)
continue
}

var js any
err := json.Unmarshal(jsonBits, &js)
if err != nil {
newArray = append(newArray, e)
} else {
newArray = append(newArray, js)
}
}

return newArray
}

func convertBytesToStringSlice(bytes []any) []any {
stringSlice := []any{}
for _, el := range bytes {
if bits, ok := el.([]byte); ok {
stringSlice = append(stringSlice, string(bits))
}
}
return stringSlice
}

func convertBytesToUuidSlice(uuids []any) []any {
uuidSlice := []any{}
for _, el := range uuids {
if id, ok := el.([16]uint8); ok {
uuidSlice = append(uuidSlice, uuid.UUID(id).String())
}
return CreateMultiDimSlice(dims, array.Elements)
}
return uuidSlice
return array.Elements
}

// converts flat slice to multi-dimensional slice
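
For context on the NullableJSON helper added in this diff: it implements sql.Scanner, keeps SQL NULL distinct from an actual JSON document, and defers decoding to Unmarshal. Below is a hedged sketch of how it behaves, written as a test alongside internal/postgres/utils.go; the package name postgres and the sample payload are assumptions, not taken from the commit.

package postgres

import (
	"fmt"
	"testing"
)

func TestNullableJSONSketch(t *testing.T) {
	// A JSONB value arrives from the driver as []byte; Scan stores it raw.
	var n NullableJSON
	if err := n.Scan([]byte(`{"amount": "12.30", "tags": ["a", "b"]}`)); err != nil {
		t.Fatal(err)
	}
	v, err := n.Unmarshal()
	if err != nil {
		t.Fatal(err)
	}
	fmt.Printf("%#v\n", v) // decoded document, e.g. map[string]interface {}{...}

	// A SQL NULL leaves Valid false, and Unmarshal returns nil rather than an error.
	var null NullableJSON
	if err := null.Scan(nil); err != nil {
		t.Fatal(err)
	}
	if got, _ := null.Unmarshal(); got != nil {
		t.Fatalf("expected nil for SQL NULL, got %#v", got)
	}
}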