Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pgx pool acquire tracer #2008

Merged
merged 7 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pgxpool/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func (c *Conn) Release() {
res := c.res
c.res = nil

if c.p.releaseTracer != nil {
c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn})
}

if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' {
res.Destroy()
// Signal to the health check to run since we just destroyed a connections
Expand Down
24 changes: 23 additions & 1 deletion pgxpool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ type Pool struct {

healthCheckChan chan struct{}

acquireTracer AcquireTracer
releaseTracer ReleaseTracer

closeOnce sync.Once
closeChan chan struct{}
}
Expand Down Expand Up @@ -195,6 +198,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
closeChan: make(chan struct{}),
}

if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok {
p.acquireTracer = t
}

if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok {
p.releaseTracer = t
}

var err error
p.p, err = puddle.NewPool(
&puddle.Config[*connResource]{
Expand Down Expand Up @@ -498,7 +509,18 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in
}

// Acquire returns a connection (*Conn) from the Pool
func (p *Pool) Acquire(ctx context.Context) (*Conn, error) {
func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
if p.acquireTracer != nil {
ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{})
defer func() {
var conn *pgx.Conn
if c != nil {
conn = c.Conn()
}
p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err})
}()
}

for {
res, err := p.p.Acquire(ctx)
if err != nil {
Expand Down
33 changes: 33 additions & 0 deletions pgxpool/tracer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package pgxpool

import (
"context"

"github.com/jackc/pgx/v5"
)

// AcquireTracer traces Acquire.
type AcquireTracer interface {
// TraceAcquireStart is called at the beginning of Acquire.
// The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd.
TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context
// TraceAcquireEnd is called when a connection has been acquired.
TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData)
}

type TraceAcquireStartData struct{}

type TraceAcquireEndData struct {
Conn *pgx.Conn
Err error
}

// ReleaseTracer traces Release.
type ReleaseTracer interface {
// TraceRelease is called at the beginning of Release.
TraceRelease(pool *Pool, data TraceReleaseData)
}

type TraceReleaseData struct {
Conn *pgx.Conn
}
130 changes: 130 additions & 0 deletions pgxpool/tracer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package pgxpool_test

import (
"context"
"os"
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
)

type testTracer struct {
traceAcquireStart func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context
traceAcquireEnd func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData)
traceRelease func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData)
}

type ctxKey string

func (tt *testTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
if tt.traceAcquireStart != nil {
return tt.traceAcquireStart(ctx, pool, data)
}
return ctx
}

func (tt *testTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
if tt.traceAcquireEnd != nil {
tt.traceAcquireEnd(ctx, pool, data)
}
}

func (tt *testTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
if tt.traceRelease != nil {
tt.traceRelease(pool, data)
}
}

func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}

func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}

func TestTraceAcquire(t *testing.T) {
t.Parallel()

tracer := &testTracer{}

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.ConnConfig.Tracer = tracer

pool, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
defer pool.Close()

traceAcquireStartCalled := false
tracer.traceAcquireStart = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
traceAcquireStartCalled = true
require.NotNil(t, pool)
return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo")
}

traceAcquireEndCalled := false
tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceAcquireStart")))
require.NotNil(t, pool)
require.NotNil(t, data.Conn)
require.NoError(t, data.Err)
}

c, err := pool.Acquire(ctx)
require.NoError(t, err)
defer c.Release()
require.True(t, traceAcquireStartCalled)
require.True(t, traceAcquireEndCalled)

traceAcquireStartCalled = false
traceAcquireEndCalled = false
tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
traceAcquireEndCalled = true
require.NotNil(t, pool)
require.Nil(t, data.Conn)
require.Error(t, data.Err)
}

ctx, cancel = context.WithCancel(ctx)
cancel()
_, err = pool.Acquire(ctx)
require.ErrorIs(t, err, context.Canceled)
require.True(t, traceAcquireStartCalled)
require.True(t, traceAcquireEndCalled)
}

func TestTraceRelease(t *testing.T) {
t.Parallel()

tracer := &testTracer{}

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.ConnConfig.Tracer = tracer

pool, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
defer pool.Close()

traceReleaseCalled := false
tracer.traceRelease = func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
traceReleaseCalled = true
require.NotNil(t, pool)
require.NotNil(t, data.Conn)
}

c, err := pool.Acquire(ctx)
require.NoError(t, err)
c.Release()
require.True(t, traceReleaseCalled)
}