@@ -104,10 +104,13 @@ type Pool struct {
104
104
closeOnce sync.Once
105
105
closeChan chan struct {}
106
106
107
- autoLoadTypeNames []string
108
- reuseTypeMap bool
109
- autoLoadMutex * sync.Mutex
110
- autoLoadTypes []* pgtype.Type
107
+ autoLoadTypeNames []string
108
+ reuseTypeMap bool
109
+ autoLoadMutex * sync.Mutex
110
+ autoLoadTypes []* pgtype.Type
111
+ customRegistrationMap map [string ]CustomRegistrationFunction
112
+ customRegistrationMutex * sync.Mutex
113
+ customRegistrationOidMap map [string ]uint32
111
114
}
112
115
113
116
// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
@@ -198,6 +201,10 @@ func New(ctx context.Context, connString string) (*Pool, error) {
198
201
return NewWithConfig (ctx , config )
199
202
}
200
203
204
+ // CustomRegistrationFunction is capable of registering whatever is necessary for
205
+ // a custom type. It is provided with the backend's OID for this type.
206
+ type CustomRegistrationFunction func (ctx context.Context , m * pgtype.Map , oid uint32 ) error
207
+
201
208
// NewWithConfig creates a new Pool. config must have been created by [ParseConfig].
202
209
func NewWithConfig (ctx context.Context , config * Config ) (* Pool , error ) {
203
210
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
@@ -207,23 +214,25 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
207
214
}
208
215
209
216
p := & Pool {
210
- config : config ,
211
- beforeConnect : config .BeforeConnect ,
212
- afterConnect : config .AfterConnect ,
213
- autoLoadTypeNames : config .AutoLoadTypes ,
214
- reuseTypeMap : config .ReuseTypeMaps ,
215
- beforeAcquire : config .BeforeAcquire ,
216
- afterRelease : config .AfterRelease ,
217
- beforeClose : config .BeforeClose ,
218
- minConns : config .MinConns ,
219
- maxConns : config .MaxConns ,
220
- maxConnLifetime : config .MaxConnLifetime ,
221
- maxConnLifetimeJitter : config .MaxConnLifetimeJitter ,
222
- maxConnIdleTime : config .MaxConnIdleTime ,
223
- healthCheckPeriod : config .HealthCheckPeriod ,
224
- healthCheckChan : make (chan struct {}, 1 ),
225
- closeChan : make (chan struct {}),
226
- autoLoadMutex : new (sync.Mutex ),
217
+ config : config ,
218
+ beforeConnect : config .BeforeConnect ,
219
+ afterConnect : config .AfterConnect ,
220
+ autoLoadTypeNames : config .AutoLoadTypes ,
221
+ reuseTypeMap : config .ReuseTypeMaps ,
222
+ beforeAcquire : config .BeforeAcquire ,
223
+ afterRelease : config .AfterRelease ,
224
+ beforeClose : config .BeforeClose ,
225
+ minConns : config .MinConns ,
226
+ maxConns : config .MaxConns ,
227
+ maxConnLifetime : config .MaxConnLifetime ,
228
+ maxConnLifetimeJitter : config .MaxConnLifetimeJitter ,
229
+ maxConnIdleTime : config .MaxConnIdleTime ,
230
+ healthCheckPeriod : config .HealthCheckPeriod ,
231
+ healthCheckChan : make (chan struct {}, 1 ),
232
+ closeChan : make (chan struct {}),
233
+ autoLoadMutex : new (sync.Mutex ),
234
+ customRegistrationMap : make (map [string ]CustomRegistrationFunction ),
235
+ customRegistrationMutex : new (sync.Mutex ),
227
236
}
228
237
229
238
if t , ok := config .ConnConfig .Tracer .(AcquireTracer ); ok {
@@ -265,6 +274,24 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
265
274
}
266
275
}
267
276
277
+ if len (p .customRegistrationMap ) > 0 {
278
+ oidMap , err := p .getOidMapForCustomRegistration (ctx , conn )
279
+ if err != nil {
280
+ conn .Close (ctx )
281
+ return nil , fmt .Errorf ("While retrieving OIDs for custom type registration: %w" , err )
282
+ }
283
+ for typeName , f := range p .customRegistrationMap {
284
+ if oid , exists := oidMap [typeName ]; exists {
285
+ if err := f (ctx , conn .TypeMap (), oid ); err != nil {
286
+ return nil , err
287
+ }
288
+ } else {
289
+ return nil , fmt .Errorf ("Type %q does not have an associated OID." , typeName )
290
+ }
291
+ }
292
+
293
+ }
294
+
268
295
if p .autoLoadTypeNames != nil && len (p .autoLoadTypeNames ) > 0 {
269
296
types , err := p .loadTypes (ctx , conn , p .autoLoadTypeNames )
270
297
if err != nil {
@@ -315,6 +342,51 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
315
342
return p , nil
316
343
}
317
344
345
+ func (p * Pool ) getOidMapForCustomRegistration (ctx context.Context , conn * pgx.Conn ) (map [string ]uint32 , error ) {
346
+ if p .reuseTypeMap {
347
+ p .customRegistrationMutex .Lock ()
348
+ defer p .customRegistrationMutex .Unlock ()
349
+ if p .customRegistrationOidMap != nil {
350
+ return p .customRegistrationOidMap , nil
351
+ }
352
+ oidMap , err := p .fetchOidMapForCustomRegistration (ctx , conn )
353
+ if err != nil {
354
+ return nil , err
355
+ }
356
+ p .customRegistrationOidMap = oidMap
357
+ return oidMap , nil
358
+ }
359
+ // Avoid needing to acquire the mutex and allow connections to initialise in parallel
360
+ // if we have chosen to not reuse the type mapping
361
+ return p .fetchOidMapForCustomRegistration (ctx , conn )
362
+ }
363
+
364
+ func (p * Pool ) fetchOidMapForCustomRegistration (ctx context.Context , conn * pgx.Conn ) (map [string ]uint32 , error ) {
365
+ sql := `
366
+ SELECT oid, typname
367
+ FROM pg_type
368
+ WHERE typname = ANY($1)`
369
+ result := make (map [string ]uint32 )
370
+ typeNames := make ([]string , 0 , len (p .customRegistrationMap ))
371
+ for typeName := range p .customRegistrationMap {
372
+ typeNames = append (typeNames , typeName )
373
+ }
374
+ rows , err := conn .Query (ctx , sql , typeNames )
375
+ if err != nil {
376
+ return nil , fmt .Errorf ("While collecting OIDs for custom registrations: %w" , err )
377
+ }
378
+ defer rows .Close ()
379
+ var typeName string
380
+ var oid uint32
381
+ for rows .Next () {
382
+ if err := rows .Scan (& typeName , & oid ); err != nil {
383
+ return nil , fmt .Errorf ("While scanning a row for custom registrations: %w" , err )
384
+ }
385
+ result [typeName ] = oid
386
+ }
387
+ return result , nil
388
+ }
389
+
318
390
// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the
319
391
// addition of the following variables:
320
392
//
@@ -425,6 +497,12 @@ func (p *Pool) Close() {
425
497
})
426
498
}
427
499
500
+ // RegisterCustomType is used to provide a function capable of performing
501
+ // type registration for situations where the autoloader is unable to do so on its own
502
+ func (p * Pool ) RegisterCustomType (typeName string , f CustomRegistrationFunction ) {
503
+ p .customRegistrationMap [typeName ] = f
504
+ }
505
+
428
506
// loadTypes is used internally to autoload the custom types for a connection,
429
507
// potentially reusing previously-loaded typemap information.
430
508
func (p * Pool ) loadTypes (ctx context.Context , conn * pgx.Conn , typeNames []string ) ([]* pgtype.Type , error ) {
0 commit comments