Skip to content

Commit

Permalink
Add support for custom JSON marshal and unmarshal.
Browse files Browse the repository at this point in the history
The Codec interface is now implemented by *pgtype.JSONCodec
and *pgtype.JSONBCodec instead of pgtype.JSONCodec and
pgtype.JSONBCodec, respectively. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

Fixes #2005.
  • Loading branch information
mitar authored and jackc committed May 18, 2024
1 parent e1b90cf commit 7328897
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 32 deletions.
47 changes: 30 additions & 17 deletions pgtype/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ import (
"reflect"
)

type JSONCodec struct{}
type JSONCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}

func (JSONCodec) FormatSupported(format int16) bool {
func (*JSONCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

func (JSONCodec) PreferredFormat() int16 {
func (*JSONCodec) PreferredFormat() int16 {
return TextFormatCode
}

func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch value.(type) {
case string:
return encodePlanJSONCodecEitherFormatString{}
Expand All @@ -44,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
//
// https://github.com/jackc/pgx/issues/1681
case json.Marshaler:
return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}

// Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the
Expand All @@ -61,7 +66,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
}
}

return encodePlanJSONCodecEitherFormatMarshal{}
return &encodePlanJSONCodecEitherFormatMarshal{
marshal: c.Marshal,
}
}

type encodePlanJSONCodecEitherFormatString struct{}
Expand Down Expand Up @@ -96,10 +103,12 @@ func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byt
return buf, nil
}

type encodePlanJSONCodecEitherFormatMarshal struct{}
type encodePlanJSONCodecEitherFormatMarshal struct {
marshal func(v any) ([]byte, error)
}

func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := json.Marshal(value)
func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes, err := e.marshal(value)
if err != nil {
return nil, err
}
Expand All @@ -108,7 +117,7 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new
return buf, nil
}

func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch target.(type) {
case *string:
return scanPlanAnyToString{}
Expand Down Expand Up @@ -141,7 +150,9 @@ func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan
return &scanPlanSQLScanner{formatCode: format}
}

return scanPlanJSONToJSONUnmarshal{}
return &scanPlanJSONToJSONUnmarshal{
unmarshal: c.Unmarshal,
}
}

type scanPlanAnyToString struct{}
Expand Down Expand Up @@ -173,9 +184,11 @@ func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error {
return scanner.ScanBytes(src)
}

type scanPlanJSONToJSONUnmarshal struct{}
type scanPlanJSONToJSONUnmarshal struct {
unmarshal func(data []byte, v any) error
}

func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
if src == nil {
dstValue := reflect.ValueOf(dst)
if dstValue.Kind() == reflect.Ptr {
Expand All @@ -193,10 +206,10 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
elem := reflect.ValueOf(dst).Elem()
elem.Set(reflect.Zero(elem.Type()))

return json.Unmarshal(src, dst)
return s.unmarshal(src, dst)
}

func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
Expand All @@ -206,12 +219,12 @@ func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
return dstBuf, nil
}

func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}

var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}
27 changes: 27 additions & 0 deletions pgtype/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"database/sql/driver"
"encoding/json"
"errors"
"reflect"
"testing"

pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -224,3 +226,28 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom":"thing"}`, jsonStr)
})
}

func TestJSONCodecCustomMarshal(t *testing.T) {
skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.")

connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
// There is no space between "custom" and "value" in json type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}
28 changes: 15 additions & 13 deletions pgtype/jsonb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,31 @@ package pgtype

import (
"database/sql/driver"
"encoding/json"
"fmt"
)

type JSONBCodec struct{}
type JSONBCodec struct {
Marshal func(v any) ([]byte, error)
Unmarshal func(data []byte, v any) error
}

func (JSONBCodec) FormatSupported(format int16) bool {
func (*JSONBCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

func (JSONBCodec) PreferredFormat() int16 {
func (*JSONBCodec) PreferredFormat() int16 {
return TextFormatCode
}

func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value)
if plan != nil {
return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanEncode(m, oid, format, value)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value)
}

return nil
Expand All @@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne
return plan.textPlan.Encode(value, buf)
}

func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case BinaryFormatCode:
plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target)
plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target)
if plan != nil {
return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan}
}
case TextFormatCode:
return JSONCodec{}.PlanScan(m, oid, format, target)
return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, format, target)
}

return nil
Expand All @@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error {
return plan.textPlan.Scan(src[1:], dst)
}

func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
if src == nil {
return nil, nil
}
Expand All @@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src
}
}

func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
if src == nil {
return nil, nil
}
Expand All @@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a
}

var dst any
err := json.Unmarshal(src, &dst)
err := c.Unmarshal(src, &dst)
return dst, err
}
26 changes: 26 additions & 0 deletions pgtype/jsonb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package pgtype_test

import (
"context"
"encoding/json"
"reflect"
"testing"

pgx "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -80,3 +83,26 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) {
require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string.
})
}

func TestJSONBCodecCustomMarshal(t *testing.T) {
connTestRunner := defaultConnTestRunner
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
conn.TypeMap().RegisterType(&pgtype.Type{
Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{
Marshal: func(v any) ([]byte, error) {
return []byte(`{"custom":"value"}`), nil
},
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
}

pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
// There is space between "custom" and "value" in jsonb type.
{map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)},
{[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool {
return reflect.DeepEqual(v, map[string]any{"custom": "value"})
}},
})
}
4 changes: 2 additions & 2 deletions pgtype/pgtype_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func initDefaultMap() {
defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}})
defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}})
defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}})
defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}})
defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}})
defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}})
defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}})
Expand Down

0 comments on commit 7328897

Please sign in to comment.