From 5747f37d9cd5ba79cfe6209f79b0c1b074f11fe6 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 15 Aug 2024 18:20:07 -0500 Subject: [PATCH] Fix: Scan and encode types with underlying types of arrays Rather than special case the reported issue with UUID and [16]byte, this commit allows the system to find the underlying type of any type that is an array. fixes https://github.com/jackc/pgx/issues/2107 --- pgtype/pgtype.go | 26 +++++++++++++++++++++----- pgtype/uuid_test.go | 12 ++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 30f6bdef5..bdd9f05ca 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -573,17 +573,24 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex elemValue = dstValue.Elem() } nextDstType := elemKindToPointerTypes[elemValue.Kind()] - if nextDstType == nil && elemValue.Kind() == reflect.Slice { - if elemValue.Type().Elem().Kind() == reflect.Uint8 { - var v *[]byte - nextDstType = reflect.TypeOf(v) + if nextDstType == nil { + if elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if elemValue.Kind() == reflect.Array { + nextDstType = reflect.PointerTo(reflect.ArrayOf(elemValue.Len(), elemValue.Type().Elem())) } } if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } - } return nil, nil, false @@ -1423,6 +1430,15 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true } + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if refValue.Kind() == reflect.Array { + underlyingArrayType := reflect.ArrayOf(refValue.Len(), refValue.Type().Elem()) + if refValue.Type() != underlyingArrayType { + return &underlyingTypeEncodePlan{nextValueType: underlyingArrayType}, refValue.Convert(underlyingArrayType).Interface(), true + } + } + return nil, nil, false } diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 2dc258b1b..1c701e747 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +type renamedUUIDByteArray [16]byte + func TestUUIDCodec(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ { @@ -43,6 +45,16 @@ func TestUUIDCodec(t *testing.T) { new(pgtype.UUID), isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, + { + renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(renamedUUIDByteArray), + isExpectedEq(renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + }, { []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, new(pgtype.UUID),