Skip to content

Commit

Permalink
no closing connections mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sleygin committed Dec 25, 2023
1 parent 9ab9e3c commit 8e99ffb
Show file tree
Hide file tree
Showing 5 changed files with 462 additions and 31 deletions.
4 changes: 4 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler

// NoClosingConnMode enables mode when connections are not closed when timeout happens but cleaned up
// and returned back to the pool.
NoClosingConnMode bool

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down
202 changes: 172 additions & 30 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
connStatusClosed
connStatusIdle
connStatusBusy
connStatusNeedCleanup
)

// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
Expand Down Expand Up @@ -84,7 +85,8 @@ type PgConn struct {

config *Config

status byte // One of connStatus* constants
cleanupWithoutReset bool
status byte // One of connStatus* constants

bufferingReceive bool
bufferingReceiveMux sync.Mutex
Expand Down Expand Up @@ -529,7 +531,7 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
pgConn.asyncClose(err)
}

return nil, err
Expand Down Expand Up @@ -642,12 +644,150 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
return pgConn.conn.Close()
}

func (pgConn *PgConn) execRoundTrip(ctx context.Context, sql string) error {
rr := pgConn.Exec(ctx, sql)
if rr.err != nil {
return fmt.Errorf("exec: %w", rr.err)
}

// reading all that is left in socket
if err := rr.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}

return nil
}

func (pgConn *PgConn) reset() error {
for {
msg, err := pgConn.receiveMessage()
if err != nil {
return fmt.Errorf("receive message: %w", err)
}

// every request is ended with ReadyForQuery message
// so we just read till it comes or error occurs
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
return ErrorResponseToPgError(msg)
case *pgproto3.ReadyForQuery:
return nil
}
}
}

func (pgConn *PgConn) cleanSocket() error {
// If this option is set, then there is nothing
// to read from the socket and there is no need
// to hang forever on reading.
if pgConn.cleanupWithoutReset {
return nil
}

// Read all the data left from the previous request.
if err := pgConn.reset(); err != nil {
return fmt.Errorf("reset: %w", err)
}

return nil
}

func (pgConn *PgConn) cleanup(ctx context.Context, resetQuery string) error {
deadline, ok := ctx.Deadline()
if ok {
pgConn.conn.SetDeadline(deadline)
} else {
pgConn.conn.SetDeadline(time.Time{})
}

if err := pgConn.cleanSocket(); err != nil {
return fmt.Errorf("clean socket: %w", err)
}

// Switch status to idle to not receive an error
// while locking connection on exec operation.
pgConn.status = connStatusIdle

// Rollback if there is an active transaction,
// Checking TxStatus to prevent overhead
if pgConn.TxStatus() != 'I' {
if err := pgConn.execRoundTrip(ctx, "ROLLBACK;"); err != nil {
return fmt.Errorf("rollback: %w", err)
}
}

// Full session reset
if resetQuery != "" {
if err := pgConn.execRoundTrip(ctx, resetQuery); err != nil {
return fmt.Errorf("discard all: %w", err)
}
}

// Reset everything.
pgConn.conn.SetDeadline(time.Time{})
pgConn.status = connStatusIdle
pgConn.cleanupWithoutReset = false
pgConn.contextWatcher.Unwatch()

return nil
}

func (pgConn *PgConn) LaunchCleanup(ctx context.Context, resetQuery string, onCleanupSucceeded func(), onCleanupFailed func(error)) (cleanupLaunched bool) {
if pgConn.status != connStatusNeedCleanup {
return false
}

go func() {
if err := pgConn.cleanup(ctx, resetQuery); err != nil {
if onCleanupFailed != nil {
onCleanupFailed(err)
}

return
}

if onCleanupSucceeded != nil {
onCleanupSucceeded()
}
}()

return true
}

func (pgConn *PgConn) setCleanupNeeded(reason error) bool {
if pgConn.status != connStatusBusy {
return false
}

// Set a connStatusNeedCleanup status to reset connection later.
var netErr net.Error
if isNetErr := errors.As(reason, &netErr); isNetErr && netErr.Timeout() {
pgConn.status = connStatusNeedCleanup

// if there was no data send to the server, then there is nothing to read back
// so we dont need to reset
if SafeToRetry(reason) {
pgConn.cleanupWithoutReset = true
}

return true
}

return false
}

// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying
// connection.
func (pgConn *PgConn) asyncClose() {
func (pgConn *PgConn) asyncClose(reason error) {
if pgConn.status == connStatusClosed {
return
}

// Set a connStatusNeedCleanup status to reset connection later.
if pgConn.config.NoClosingConnMode && pgConn.setCleanupNeeded(reason) {
return
}

pgConn.status = connStatusClosed

go func() {
Expand Down Expand Up @@ -701,6 +841,8 @@ func (pgConn *PgConn) lock() error {
return &connLockError{status: "conn closed"}
case connStatusUninitialized:
return &connLockError{status: "conn uninitialized"}
case connStatusNeedCleanup:
return &connLockError{status: "conn need to cleanup"}
}
pgConn.status = connStatusBusy
return nil
Expand Down Expand Up @@ -844,7 +986,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return nil, err
}

Expand All @@ -856,7 +998,7 @@ readloop:
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return nil, normalizeTimeoutError(ctx, err)
}

Expand Down Expand Up @@ -905,14 +1047,14 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error {
pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return err
}

for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return normalizeTimeoutError(ctx, err)
}

Expand Down Expand Up @@ -1077,7 +1219,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
multiResult.err = err
Expand Down Expand Up @@ -1188,7 +1330,7 @@ func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {

err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
result.concludeCommand(CommandTag{}, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
Expand Down Expand Up @@ -1221,7 +1363,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm

err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
pgConn.unlock()
return CommandTag{}, err
}
Expand All @@ -1232,7 +1374,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}

Expand All @@ -1241,7 +1383,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}
case *pgproto3.ReadyForQuery:
Expand Down Expand Up @@ -1279,7 +1421,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}

Expand Down Expand Up @@ -1361,7 +1503,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
err = pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}

Expand All @@ -1370,7 +1512,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}

Expand Down Expand Up @@ -1415,7 +1557,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = normalizeTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.asyncClose()
mrr.pgConn.asyncClose(err)
return nil, mrr.err
}

Expand Down Expand Up @@ -1630,12 +1772,12 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
}

if err != nil {
err = normalizeTimeoutError(rr.ctx, err)
rr.concludeCommand(CommandTag{}, err)
normalizedErr := normalizeTimeoutError(rr.ctx, err)
rr.concludeCommand(CommandTag{}, normalizedErr)
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true
if rr.multiResultReader == nil {
rr.pgConn.asyncClose()
rr.pgConn.asyncClose(err)
}

return nil, rr.err
Expand Down Expand Up @@ -2038,9 +2180,7 @@ func (p *Pipeline) Flush() error {
err := p.conn.flushWithPotentialWriteReadDeadlock()
if err != nil {
err = normalizeTimeoutError(p.ctx, err)

p.conn.asyncClose()

p.conn.asyncClose(err)
p.conn.contextWatcher.Unwatch()
p.conn.unlock()
p.closed = true
Expand Down Expand Up @@ -2076,7 +2216,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
p.conn.asyncClose()
p.conn.asyncClose(err)
return nil, normalizeTimeoutError(p.ctx, err)
}

Expand Down Expand Up @@ -2123,7 +2263,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
p.conn.asyncClose()
p.conn.asyncClose(err)
return nil, normalizeTimeoutError(p.ctx, err)
}

Expand All @@ -2145,11 +2285,13 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
pgErr := ErrorResponseToPgError(msg)
return nil, pgErr
case *pgproto3.CommandComplete:
p.conn.asyncClose()
return nil, errors.New("BUG: received CommandComplete while handling Describe")
err := errors.New("BUG: received CommandComplete while handling Describe")
p.conn.asyncClose(err)
return nil, err
case *pgproto3.ReadyForQuery:
p.conn.asyncClose()
return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
err := errors.New("BUG: received ReadyForQuery while handling Describe")
p.conn.asyncClose(err)
return nil, err
}
}
}
Expand All @@ -2162,8 +2304,8 @@ func (p *Pipeline) Close() error {
p.closed = true

if p.pendingSync {
p.conn.asyncClose()
p.err = errors.New("pipeline has unsynced requests")
p.conn.asyncClose(p.err)
p.conn.contextWatcher.Unwatch()
p.conn.unlock()

Expand All @@ -2176,7 +2318,7 @@ func (p *Pipeline) Close() error {
p.err = err
var pgErr *PgError
if !errors.As(err, &pgErr) {
p.conn.asyncClose()
p.conn.asyncClose(err)
break
}
}
Expand Down
Loading

0 comments on commit 8e99ffb

Please sign in to comment.