diff --git a/pgconn/config.go b/pgconn/config.go index 942a864ac..5c55fb244 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -65,6 +65,10 @@ type Config struct { // that you close on FATAL errors by returning false. OnPgError PgErrorHandler + // NoClosingConnMode enables mode when connections are not closed when timeout happens but cleaned up + // and returned back to the pool. + NoClosingConnMode bool + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 803b41d1f..46cc2ff69 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -29,6 +29,7 @@ const ( connStatusClosed connStatusIdle connStatusBusy + connStatusNeedCleanup ) // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from @@ -84,7 +85,8 @@ type PgConn struct { config *Config - status byte // One of connStatus* constants + cleanupWithoutReset bool + status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex @@ -529,7 +531,7 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) { var netErr net.Error isNetErr := errors.As(err, &netErr) if !(isNetErr && netErr.Timeout()) { - pgConn.asyncClose() + pgConn.asyncClose(err) } return nil, err @@ -642,12 +644,150 @@ func (pgConn *PgConn) Close(ctx context.Context) error { return pgConn.conn.Close() } +func (pgConn *PgConn) execRoundTrip(ctx context.Context, sql string) error { + rr := pgConn.Exec(ctx, sql) + if rr.err != nil { + return fmt.Errorf("exec: %w", rr.err) + } + + // reading all that is left in socket + if err := rr.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + + return nil +} + +func (pgConn *PgConn) reset() error { + for { + msg, err := pgConn.receiveMessage() + if err != nil { + return fmt.Errorf("receive message: %w", err) + } + + // every request is ended with ReadyForQuery message + // so we just read till it comes or error occurs + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return ErrorResponseToPgError(msg) + case *pgproto3.ReadyForQuery: + return nil + } + } +} + +func (pgConn *PgConn) cleanSocket() error { + // If this option is set, then there is nothing + // to read from the socket and there is no need + // to hang forever on reading. + if pgConn.cleanupWithoutReset { + return nil + } + + // Read all the data left from the previous request. + if err := pgConn.reset(); err != nil { + return fmt.Errorf("reset: %w", err) + } + + return nil +} + +func (pgConn *PgConn) cleanup(ctx context.Context, resetQuery string) error { + deadline, ok := ctx.Deadline() + if ok { + pgConn.conn.SetDeadline(deadline) + } else { + pgConn.conn.SetDeadline(time.Time{}) + } + + if err := pgConn.cleanSocket(); err != nil { + return fmt.Errorf("clean socket: %w", err) + } + + // Switch status to idle to not receive an error + // while locking connection on exec operation. + pgConn.status = connStatusIdle + + // Rollback if there is an active transaction, + // Checking TxStatus to prevent overhead + if pgConn.TxStatus() != 'I' { + if err := pgConn.execRoundTrip(ctx, "ROLLBACK;"); err != nil { + return fmt.Errorf("rollback: %w", err) + } + } + + // Full session reset + if resetQuery != "" { + if err := pgConn.execRoundTrip(ctx, resetQuery); err != nil { + return fmt.Errorf("discard all: %w", err) + } + } + + // Reset everything. + pgConn.conn.SetDeadline(time.Time{}) + pgConn.status = connStatusIdle + pgConn.cleanupWithoutReset = false + pgConn.contextWatcher.Unwatch() + + return nil +} + +func (pgConn *PgConn) LaunchCleanup(ctx context.Context, resetQuery string, onCleanupSucceeded func(), onCleanupFailed func(error)) (cleanupLaunched bool) { + if pgConn.status != connStatusNeedCleanup { + return false + } + + go func() { + if err := pgConn.cleanup(ctx, resetQuery); err != nil { + if onCleanupFailed != nil { + onCleanupFailed(err) + } + + return + } + + if onCleanupSucceeded != nil { + onCleanupSucceeded() + } + }() + + return true +} + +func (pgConn *PgConn) setCleanupNeeded(reason error) bool { + if pgConn.status != connStatusBusy { + return false + } + + // Set a connStatusNeedCleanup status to reset connection later. + var netErr net.Error + if isNetErr := errors.As(reason, &netErr); isNetErr && netErr.Timeout() { + pgConn.status = connStatusNeedCleanup + + // if there was no data send to the server, then there is nothing to read back + // so we dont need to reset + if SafeToRetry(reason) { + pgConn.cleanupWithoutReset = true + } + + return true + } + + return false +} + // asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying // connection. -func (pgConn *PgConn) asyncClose() { +func (pgConn *PgConn) asyncClose(reason error) { if pgConn.status == connStatusClosed { return } + + // Set a connStatusNeedCleanup status to reset connection later. + if pgConn.config.NoClosingConnMode && pgConn.setCleanupNeeded(reason) { + return + } + pgConn.status = connStatusClosed go func() { @@ -701,6 +841,8 @@ func (pgConn *PgConn) lock() error { return &connLockError{status: "conn closed"} case connStatusUninitialized: return &connLockError{status: "conn uninitialized"} + case connStatusNeedCleanup: + return &connLockError{status: "conn need to cleanup"} } pgConn.status = connStatusBusy return nil @@ -844,7 +986,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return nil, err } @@ -856,7 +998,7 @@ readloop: for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return nil, normalizeTimeoutError(ctx, err) } @@ -905,14 +1047,14 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { pgConn.frontend.SendSync(&pgproto3.Sync{}) err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return err } for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return normalizeTimeoutError(ctx, err) } @@ -1077,7 +1219,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) pgConn.contextWatcher.Unwatch() multiResult.closed = true multiResult.err = err @@ -1188,7 +1330,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) { err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) result.concludeCommand(CommandTag{}, err) pgConn.contextWatcher.Unwatch() result.closed = true @@ -1221,7 +1363,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) pgConn.unlock() return CommandTag{}, err } @@ -1232,7 +1374,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return CommandTag{}, normalizeTimeoutError(ctx, err) } @@ -1241,7 +1383,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return CommandTag{}, err } case *pgproto3.ReadyForQuery: @@ -1279,7 +1421,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co pgConn.frontend.SendQuery(&pgproto3.Query{String: sql}) err := pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return CommandTag{}, err } @@ -1361,7 +1503,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co } err = pgConn.flushWithPotentialWriteReadDeadlock() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return CommandTag{}, err } @@ -1370,7 +1512,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co for { msg, err := pgConn.receiveMessage() if err != nil { - pgConn.asyncClose() + pgConn.asyncClose(err) return CommandTag{}, normalizeTimeoutError(ctx, err) } @@ -1415,7 +1557,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) mrr.pgConn.contextWatcher.Unwatch() mrr.err = normalizeTimeoutError(mrr.ctx, err) mrr.closed = true - mrr.pgConn.asyncClose() + mrr.pgConn.asyncClose(err) return nil, mrr.err } @@ -1630,12 +1772,12 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error } if err != nil { - err = normalizeTimeoutError(rr.ctx, err) - rr.concludeCommand(CommandTag{}, err) + normalizedErr := normalizeTimeoutError(rr.ctx, err) + rr.concludeCommand(CommandTag{}, normalizedErr) rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { - rr.pgConn.asyncClose() + rr.pgConn.asyncClose(err) } return nil, rr.err @@ -2038,9 +2180,7 @@ func (p *Pipeline) Flush() error { err := p.conn.flushWithPotentialWriteReadDeadlock() if err != nil { err = normalizeTimeoutError(p.ctx, err) - - p.conn.asyncClose() - + p.conn.asyncClose(err) p.conn.contextWatcher.Unwatch() p.conn.unlock() p.closed = true @@ -2076,7 +2216,7 @@ func (p *Pipeline) GetResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { - p.conn.asyncClose() + p.conn.asyncClose(err) return nil, normalizeTimeoutError(p.ctx, err) } @@ -2123,7 +2263,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { for { msg, err := p.conn.receiveMessage() if err != nil { - p.conn.asyncClose() + p.conn.asyncClose(err) return nil, normalizeTimeoutError(p.ctx, err) } @@ -2145,11 +2285,13 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { pgErr := ErrorResponseToPgError(msg) return nil, pgErr case *pgproto3.CommandComplete: - p.conn.asyncClose() - return nil, errors.New("BUG: received CommandComplete while handling Describe") + err := errors.New("BUG: received CommandComplete while handling Describe") + p.conn.asyncClose(err) + return nil, err case *pgproto3.ReadyForQuery: - p.conn.asyncClose() - return nil, errors.New("BUG: received ReadyForQuery while handling Describe") + err := errors.New("BUG: received ReadyForQuery while handling Describe") + p.conn.asyncClose(err) + return nil, err } } } @@ -2162,8 +2304,8 @@ func (p *Pipeline) Close() error { p.closed = true if p.pendingSync { - p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") + p.conn.asyncClose(p.err) p.conn.contextWatcher.Unwatch() p.conn.unlock() @@ -2176,7 +2318,7 @@ func (p *Pipeline) Close() error { p.err = err var pgErr *PgError if !errors.As(err, &pgErr) { - p.conn.asyncClose() + p.conn.asyncClose(err) break } } diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go index 5659bc9ef..646d09732 100644 --- a/pgconn/pgconn_private_test.go +++ b/pgconn/pgconn_private_test.go @@ -1,9 +1,15 @@ package pgconn import ( + "context" + "fmt" + "net" + "os" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCommandTag(t *testing.T) { @@ -39,3 +45,225 @@ func TestCommandTag(t *testing.T) { assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) } } + +func mustConnectWithNoClosingMode(t *testing.T) *PgConn { + t.Helper() + + cfg, err := ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + cfg.NoClosingConnMode = true + + conn, err := ConnectConfig(context.Background(), cfg) + require.NoError(t, err) + + return conn +} + +func TestCleanup(t *testing.T) { + t.Parallel() + + var tests = []struct { + name string + testCase func(t *testing.T, conn *PgConn) + }{ + { + name: "success", + testCase: func(t *testing.T, conn *PgConn) { + ctx := context.Background() + + execCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + // expecting error because timeout is less than execution time + rr := conn.Exec(execCtx, `select pg_sleep(0.5)`) + err := rr.Close() + require.Error(t, err) + + // enough timeout for cleanup + cleanupCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + var cleanupSucceeded bool + // we expect that connection is in status `need cleanup` because previous request failed + launched := conn.LaunchCleanup(cleanupCtx, `select 'hello wold'`, func() { cleanupSucceeded = true }, nil) + require.True(t, launched) + + // enough time for cleanup + time.Sleep(time.Second) + require.True(t, conn.status == connStatusIdle) + require.True(t, cleanupSucceeded) + + execCtx, cancel = context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + // checking that socket is clean, and we are reading data from our request, not from the previous one + rr = conn.Exec(execCtx, `select 'goodbye world'`) + res, err := rr.ReadAll() + require.NoError(t, err) + + require.True(t, string(res[0].Rows[0][0]) == "goodbye world") + }, + }, + { + name: "failed cleanup timepout", + testCase: func(t *testing.T, conn *PgConn) { + ctx := context.Background() + + execCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + // expecting error because timeout is less than execution time + rr := conn.Exec(execCtx, `select pg_sleep(0.5)`) + err := rr.Close() + require.Error(t, err) + + // enough timeout for cleanup + cleanupCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + var launchErr error + // we expect that connection is in status `need cleanup` because previous request failed + launched := conn.LaunchCleanup(cleanupCtx, `select pg_sleep(0.5)`, nil, func(err error) { + launchErr = err + }) + require.True(t, launched) + + // enough time for cleanup + time.Sleep(time.Second) + + // we expect error as we failed to cleanup socket + require.NotNil(t, launchErr) + }, + }, + { + name: "failed to launch cleanup", + testCase: func(t *testing.T, conn *PgConn) { + ctx := context.Background() + + execCtx, cancel := context.WithTimeout(ctx, time.Millisecond*200) + defer cancel() + + // expecting error because timeout is less than execution time + rr := conn.Exec(execCtx, `select pg_sleep(0.5)`) + + // we expect not to launch cleanup when connection is not in `need cleanup` state + launched := conn.LaunchCleanup(ctx, "", nil, nil) + require.False(t, launched) + + _ = rr.Close() + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + conn := mustConnectWithNoClosingMode(t) + tt.testCase(t, conn) + }) + } +} + +var _ net.Error = &mockNetError{} + +type mockNetError struct { + safeToRetry bool + timeout bool +} + +func (m *mockNetError) Error() string { + return "mock error" +} + +func (m *mockNetError) Timeout() bool { + return m.timeout +} + +func (m *mockNetError) Temporary() bool { + return true +} + +func (m *mockNetError) SafeToRetry() bool { + return m.safeToRetry +} + +func TestCheckIfCleanupNeeded(t *testing.T) { + t.Parallel() + + var tests = []struct { + name string + reason error + expected bool + expectedStatus byte + expectedCleanupWithoutReset bool + builder func() *PgConn + }{ + { + name: "need to cleanup", + reason: &mockNetError{timeout: true}, + expected: true, + expectedStatus: connStatusNeedCleanup, + expectedCleanupWithoutReset: false, + builder: func() *PgConn { + return &PgConn{status: connStatusBusy} + }, + }, + { + name: "need to cleanup without reset", + reason: &mockNetError{timeout: true, safeToRetry: true}, + expected: true, + expectedStatus: connStatusNeedCleanup, + expectedCleanupWithoutReset: true, + builder: func() *PgConn { + return &PgConn{status: connStatusBusy} + }, + }, + { + name: "wrong status no need to cleanup", + reason: &mockNetError{}, + expected: false, + expectedStatus: connStatusClosed, + expectedCleanupWithoutReset: false, + builder: func() *PgConn { + return &PgConn{status: connStatusClosed} + }, + }, + { + name: "no timeout error", + reason: &mockNetError{timeout: false}, + expected: false, + expectedStatus: connStatusBusy, + expectedCleanupWithoutReset: false, + builder: func() *PgConn { + return &PgConn{status: connStatusBusy} + }, + }, + { + name: "not net error", + reason: fmt.Errorf("some error"), + expected: false, + expectedStatus: connStatusBusy, + expectedCleanupWithoutReset: false, + builder: func() *PgConn { + return &PgConn{status: connStatusBusy} + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + conn := tt.builder() + actual := conn.setCleanupNeeded(tt.reason) + + assert.Equal(t, tt.expected, actual) + assert.Equal(t, tt.expectedStatus, conn.status) + assert.Equal(t, tt.expectedCleanupWithoutReset, conn.cleanupWithoutReset) + }) + } +} diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 36f90969e..34daaa865 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -15,6 +15,13 @@ type Conn struct { p *Pool } +// some pg cleanups before connection is returned to the pool +const sessionResetQuery = ` +UNLISTEN *; +SELECT pg_advisory_unlock_all(); +DISCARD SEQUENCES; +DISCARD TEMP;` + // Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. // However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. func (c *Conn) Release() { @@ -25,8 +32,27 @@ func (c *Conn) Release() { conn := c.Conn() res := c.res c.res = nil + pgConn := conn.PgConn() + + cleanupCtx, cancel := context.WithTimeout(context.Background(), c.p.connCleanupTimeout) + defer cancel() + + cleanupLaunched := pgConn.LaunchCleanup( + cleanupCtx, + sessionResetQuery, + func() { + res.Release() + }, + func(error) { + res.Destroy() + c.p.triggerHealthCheck() + }, + ) + if cleanupLaunched { + return + } - if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { + if conn.IsClosed() || pgConn.IsBusy() || pgConn.TxStatus() != 'I' { res.Destroy() // Signal to the health check to run since we just destroyed a connections // and we might be below minConns now diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 9f74805e1..a5c6da0c6 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -20,6 +20,7 @@ var defaultMinConns = int32(0) var defaultMaxConnLifetime = time.Hour var defaultMaxConnIdleTime = time.Minute * 30 var defaultHealthCheckPeriod = time.Minute +var defaultConnCleanupTimeout = time.Millisecond * 100 type connResource struct { conn *pgx.Conn @@ -92,6 +93,7 @@ type Pool struct { maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration + connCleanupTimeout time.Duration healthCheckChan chan struct{} @@ -144,6 +146,12 @@ type Config struct { // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration + // NoClosingConnMode enables mode when connections are not closed when timeout happens but cleaned up + // and returned back to the pool + NoClosingConnMode bool + // ConnCleanupTimeout is the timeout for cleanup pool connection + ConnCleanupTimeout time.Duration + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -178,6 +186,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { panic("config must be created by ParseConfig") } + config.ConnConfig.NoClosingConnMode = config.NoClosingConnMode + p := &Pool{ config: config, beforeConnect: config.BeforeConnect, @@ -191,6 +201,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { maxConnLifetimeJitter: config.MaxConnLifetimeJitter, maxConnIdleTime: config.MaxConnIdleTime, healthCheckPeriod: config.HealthCheckPeriod, + connCleanupTimeout: config.ConnCleanupTimeout, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), } @@ -365,6 +376,26 @@ func ParseConfig(connString string) (*Config, error) { config.MaxConnLifetimeJitter = d } + if s, ok := config.ConnConfig.Config.RuntimeParams["conn_cleanup_timeout"]; ok { + delete(connConfig.Config.RuntimeParams, "conn_cleanup_timeout") + d, err := time.ParseDuration(s) + if err != nil { + return nil, fmt.Errorf("invalid conn_cleanup_timeout: %w", err) + } + config.ConnCleanupTimeout = d + } else { + config.ConnCleanupTimeout = defaultConnCleanupTimeout + } + + if s, ok := config.ConnConfig.Config.RuntimeParams["no_closing_conn_mode"]; ok { + delete(connConfig.Config.RuntimeParams, "no_closing_conn_mode") + isEnabled, err := strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("invalid conn_cleanup_timeout: %w", err) + } + config.NoClosingConnMode = isEnabled + } + return config, nil }