Skip to content

Commit

Permalink
feat: pgx pool acquire tracer
Browse files Browse the repository at this point in the history
  • Loading branch information
ngavinsir committed May 10, 2024
1 parent 579a320 commit 5c4fedf
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 1 deletion.
15 changes: 14 additions & 1 deletion pgxpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ type Pool struct {

healthCheckChan chan struct{}

acquireTracer AcquireTracer

closeOnce sync.Once
closeChan chan struct{}
}
Expand Down Expand Up @@ -195,6 +197,10 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
closeChan: make(chan struct{}),
}

if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
p.acquireTracer = t
}

var err error
p.p, err = puddle.NewPool(
&puddle.Config[*connResource]{
Expand Down Expand Up @@ -498,7 +504,14 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
}

// Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, TraceAcquireStartData{ConnConfig: p.config.ConnConfig})
defer func() {
p.acquireTracer.TraceAcquireEnd(ctx, TraceAcquireEndData{Err: err})
}()
}

for {
res, err := p.p.Acquire(ctx)
if err != nil {
Expand Down
23 changes: 23 additions & 0 deletions pgxpool/tracer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package pgxpool

import (
"context"

"github.com/jackc/pgx/v5"
)

// AcquireTracer traces Acquire.
type AcquireTracer interface {
// TraceAcquireStart is called at the beginning of Acquire.
// The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd.
TraceAcquireStart(ctx context.Context, data TraceAcquireStartData) context.Context
TraceAcquireEnd(ctx context.Context, data TraceAcquireEndData)
}

type TraceAcquireStartData struct {
ConnConfig *pgx.ConnConfig
}

type TraceAcquireEndData struct {
Err error
}
90 changes: 90 additions & 0 deletions pgxpool/tracer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package pgxpool_test

import (
"context"
"os"
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
)

type testTracer struct {
traceAcquireStart func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context
traceAcquireEnd func(ctx context.Context, data pgxpool.TraceAcquireEndData)
}

type ctxKey string

func (tt *testTracer) TraceAcquireStart(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context {
if tt.traceAcquireStart != nil {
return tt.traceAcquireStart(ctx, data)
}
return ctx
}

func (tt *testTracer) TraceAcquireEnd(ctx context.Context, data pgxpool.TraceAcquireEndData) {
if tt.traceAcquireEnd != nil {
tt.traceAcquireEnd(ctx, data)
}
}

func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}

func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}

func TestTraceAcquire(t *testing.T) {
t.Parallel()

tracer := &testTracer{}

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.ConnConfig.Tracer = tracer

pool, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
defer pool.Close()

traceAcquireStartCalled := false
tracer.traceAcquireStart = func(ctx context.Context, data pgxpool.TraceAcquireStartData) context.Context {
traceAcquireStartCalled = true
require.NotNil(t, data.ConnConfig)
return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo")
}

traceAcquireEndCalled := false
tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.Equal(t, "foo", ctx.Value(ctxKey(ctxKey("fromTraceAcquireStart"))))
require.NoError(t, data.Err)
}

c, err := pool.Acquire(ctx)
require.NoError(t, err)
defer c.Release()
require.True(t, traceAcquireStartCalled)
require.True(t, traceAcquireEndCalled)

traceAcquireStartCalled = false
traceAcquireEndCalled = false
tracer.traceAcquireEnd = func(ctx context.Context, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.Error(t, data.Err)
}

ctx, cancel = context.WithCancel(ctx)
cancel()
c, err = pool.Acquire(ctx)
require.ErrorIs(t, err, context.Canceled)
require.True(t, traceAcquireStartCalled)
require.True(t, traceAcquireEndCalled)
}

0 comments on commit 5c4fedf

Please sign in to comment.