Skip to content

Commit a22564d

Browse files
committed
Add ScanLocation to pgtype.TimestamptzCodec
If ScanLocation is set, it will be used to convert the time to the given location when scanning from the database. The Codec interface is now implemented by *pgtype.TimestamptzCodec instead of pgtype.TimestamptzCodec. This is technically a breaking change, but it is extremely unlikely that anyone is depending on this, and if there is downstream breakage it is trivial to fix. #1195 #1945
1 parent 1b6227a commit a22564d

File tree

3 files changed

+60
-15
lines changed

3 files changed

+60
-15
lines changed

pgtype/pgtype_default.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func initDefaultMap() {
8282
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
8383
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
8484
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}})
85-
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}})
85+
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
8686
defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}})
8787
defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}})
8888
defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}})

pgtype/timestamptz.go

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (tstz *Timestamptz) Scan(src any) error {
5454

5555
switch src := src.(type) {
5656
case string:
57-
return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz)
57+
return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz)
5858
case time.Time:
5959
*tstz = Timestamptz{Time: src, Valid: true}
6060
return nil
@@ -124,17 +124,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
124124
return nil
125125
}
126126

127-
type TimestamptzCodec struct{}
127+
type TimestamptzCodec struct {
128+
// ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that
129+
// the timestamptz represents.
130+
ScanLocation *time.Location
131+
}
128132

129-
func (TimestamptzCodec) FormatSupported(format int16) bool {
133+
func (*TimestamptzCodec) FormatSupported(format int16) bool {
130134
return format == TextFormatCode || format == BinaryFormatCode
131135
}
132136

133-
func (TimestamptzCodec) PreferredFormat() int16 {
137+
func (*TimestamptzCodec) PreferredFormat() int16 {
134138
return BinaryFormatCode
135139
}
136140

137-
func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
141+
func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
138142
if _, ok := value.(TimestamptzValuer); !ok {
139143
return nil
140144
}
@@ -220,27 +224,27 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
220224
return buf, nil
221225
}
222226

223-
func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
227+
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
224228

225229
switch format {
226230
case BinaryFormatCode:
227231
switch target.(type) {
228232
case TimestamptzScanner:
229-
return scanPlanBinaryTimestamptzToTimestamptzScanner{}
233+
return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation}
230234
}
231235
case TextFormatCode:
232236
switch target.(type) {
233237
case TimestamptzScanner:
234-
return scanPlanTextTimestamptzToTimestamptzScanner{}
238+
return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation}
235239
}
236240
}
237241

238242
return nil
239243
}
240244

241-
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{}
245+
type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location }
242246

243-
func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
247+
func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
244248
scanner := (dst).(TimestamptzScanner)
245249

246250
if src == nil {
@@ -264,15 +268,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e
264268
microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000,
265269
(microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000),
266270
)
271+
if plan.location != nil {
272+
tim = tim.In(plan.location)
273+
}
267274
tstz = Timestamptz{Time: tim, Valid: true}
268275
}
269276

270277
return scanner.ScanTimestamptz(tstz)
271278
}
272279

273-
type scanPlanTextTimestamptzToTimestamptzScanner struct{}
280+
type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location }
274281

275-
func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
282+
func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error {
276283
scanner := (dst).(TimestamptzScanner)
277284

278285
if src == nil {
@@ -312,13 +319,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err
312319
tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location())
313320
}
314321

322+
if plan.location != nil {
323+
tim = tim.In(plan.location)
324+
}
325+
315326
tstz = Timestamptz{Time: tim, Valid: true}
316327
}
317328

318329
return scanner.ScanTimestamptz(tstz)
319330
}
320331

321-
func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
332+
func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
322333
if src == nil {
323334
return nil, nil
324335
}
@@ -336,7 +347,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1
336347
return tstz.Time, nil
337348
}
338349

339-
func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
350+
func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
340351
if src == nil {
341352
return nil, nil
342353
}

pgtype/timestamptz_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,40 @@ func TestTimestamptzCodec(t *testing.T) {
3838
})
3939
}
4040

41+
func TestTimestamptzCodecWithLocationUTC(t *testing.T) {
42+
skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)")
43+
44+
connTestRunner := defaultConnTestRunner
45+
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
46+
conn.TypeMap().RegisterType(&pgtype.Type{
47+
Name: "timestamptz",
48+
OID: pgtype.TimestamptzOID,
49+
Codec: &pgtype.TimestamptzCodec{ScanLocation: time.UTC},
50+
})
51+
}
52+
53+
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{
54+
{time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))},
55+
})
56+
}
57+
58+
func TestTimestamptzCodecWithLocationLocal(t *testing.T) {
59+
skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)")
60+
61+
connTestRunner := defaultConnTestRunner
62+
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
63+
conn.TypeMap().RegisterType(&pgtype.Type{
64+
Name: "timestamptz",
65+
OID: pgtype.TimestamptzOID,
66+
Codec: &pgtype.TimestamptzCodec{ScanLocation: time.Local},
67+
})
68+
}
69+
70+
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{
71+
{time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))},
72+
})
73+
}
74+
4175
// https://github.com/jackc/pgx/v4/pgtype/pull/128
4276
func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) {
4377
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {

0 commit comments

Comments
 (0)