Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No closing connections mode #1845

Closed
wants to merge 22 commits into from
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ _testmain.go
/.testdb

.DS_Store

.idea
63 changes: 45 additions & 18 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ type Config struct {
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler

// OnRecover is a callback function called when recover of connection is started. By default OnRecover sends cancel request
// to cancel running query. It may be set to nil to completely disable this functionю.
OnRecover RecoverHandler

// RecoverTimeout is a timeout for connection to recover. If connection wasnt able to recover during this time
// it will be closed. By default it is 500 ms.
RecoverTimeout time.Duration

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -273,6 +281,14 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
}
return true
},
OnRecover: func(ctx context.Context, conn *PgConn) error {
if err := conn.CancelRequest(ctx); err != nil {
return err
}

time.Sleep(time.Millisecond * 100)
return nil
},
}

if connectTimeoutSetting, present := settings["connect_timeout"]; present {
Expand All @@ -287,27 +303,38 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.DialFunc = defaultDialer.DialContext
}

if connectRecoverTimeout, present := settings["connect_recover_timeout"]; present {
d, err := time.ParseDuration(connectRecoverTimeout)
if err != nil {
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_recover_timeout", err: err}
}
config.RecoverTimeout = d
} else {
config.RecoverTimeout = time.Millisecond * 500
}

config.LookupFunc = makeDefaultResolver().LookupHost

notRuntimeParams := map[string]struct{}{
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"service": {},
"servicefile": {},
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"connect_recover_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"service": {},
"servicefile": {},
}

// Adding kerberos configuration
Expand Down
75 changes: 52 additions & 23 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/jackc/pgx/v5/internal/iobufpool"
Expand All @@ -29,6 +30,7 @@ const (
connStatusClosed
connStatusIdle
connStatusBusy
connStatusRecovering
)

// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
Expand Down Expand Up @@ -70,6 +72,9 @@ type NoticeHandler func(*PgConn, *Notice)
// notice event.
type NotificationHandler func(*PgConn, *Notification)

// RecoverHandler is a function that may be used when connection is at recover stage.
type RecoverHandler func(context.Context, *PgConn) error

// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct {
conn net.Conn
Expand All @@ -84,7 +89,7 @@ type PgConn struct {

config *Config

status byte // One of connStatus* constants
status atomic.Uint32 // One of connStatus* constants

bufferingReceive bool
bufferingReceiveMux sync.Mutex
Expand All @@ -101,6 +106,7 @@ type PgConn struct {
fieldDescriptions [16]FieldDescription

cleanupDone chan struct{}
recoverWg sync.WaitGroup
}

// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format)
Expand Down Expand Up @@ -306,7 +312,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
defer pgConn.contextWatcher.Unwatch()

pgConn.parameterStatuses = make(map[string]string)
pgConn.status = connStatusConnecting
pgConn.status.Store(connStatusConnecting)
pgConn.bgReader = bgreader.New(pgConn.conn)
pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64),
func() {
Expand Down Expand Up @@ -381,7 +387,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
pgConn.status.Store(connStatusIdle)
if config.ValidateConnect != nil {
// ValidateConnect may execute commands that cause the context to be watched again. Unwatch first to avoid
// the watch already in progress panic. This is that last thing done by this method so there is no need to
Expand Down Expand Up @@ -555,7 +561,7 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ErrorResponse:
err := ErrorResponseToPgError(msg)
if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.status.Store(connStatusClosed)
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, err
Expand Down Expand Up @@ -594,6 +600,12 @@ func (pgConn *PgConn) PID() uint32 {
//
// See https://www.postgresql.org/docs/current/protocol-message-formats.html.
func (pgConn *PgConn) TxStatus() byte {
// if we are in recovering state transaction status is undefined
// so we have to wait until we recover
if pgConn.IsRecovering() {
pgConn.WaitForRecover()
}

return pgConn.txStatus
}

Expand All @@ -611,10 +623,10 @@ func (pgConn *PgConn) Frontend() *pgproto3.Frontend {
// sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The
// underlying net.Conn.Close() will always be called regardless of any other errors.
func (pgConn *PgConn) Close(ctx context.Context) error {
if pgConn.status == connStatusClosed {
if pgConn.status.Load() == connStatusClosed {
return nil
}
pgConn.status = connStatusClosed
pgConn.status.Store(connStatusClosed)

defer close(pgConn.cleanupDone)
defer pgConn.conn.Close()
Expand Down Expand Up @@ -645,10 +657,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying
// connection.
func (pgConn *PgConn) asyncClose() {
if pgConn.status == connStatusClosed {
if pgConn.status.Load() == connStatusClosed {
return
}
pgConn.status = connStatusClosed
pgConn.status.Store(connStatusClosed)

go func() {
defer close(pgConn.cleanupDone)
Expand Down Expand Up @@ -684,32 +696,42 @@ func (pgConn *PgConn) CleanupDone() chan (struct{}) {
//
// CleanupDone() can be used to determine if all cleanup has been completed.
func (pgConn *PgConn) IsClosed() bool {
return pgConn.status < connStatusIdle
return pgConn.status.Load() < connStatusIdle
}

// IsBusy reports if the connection is busy.
func (pgConn *PgConn) IsBusy() bool {
return pgConn.status == connStatusBusy
return pgConn.status.Load() == connStatusBusy
}

// IsRecovering reports if the connection is in recovering state.
func (pgConn *PgConn) IsRecovering() bool {
return pgConn.status.Load() == connStatusRecovering
}

// lock locks the connection.
func (pgConn *PgConn) lock() error {
switch pgConn.status {
func (pgConn *PgConn) lock() (err error) {
if pgConn.status.Load() == connStatusRecovering {
pgConn.recoverWg.Wait()
}

switch pgConn.status.Load() {
case connStatusBusy:
return &connLockError{status: "conn busy"} // This only should be possible in case of an application bug.
case connStatusClosed:
return &connLockError{status: "conn closed"}
case connStatusUninitialized:
return &connLockError{status: "conn uninitialized"}
}
pgConn.status = connStatusBusy

pgConn.status.Store(connStatusBusy)
return nil
}

func (pgConn *PgConn) unlock() {
switch pgConn.status {
case connStatusBusy:
pgConn.status = connStatusIdle
switch pgConn.status.Load() {
case connStatusBusy, connStatusRecovering:
pgConn.status.Store(connStatusIdle)
case connStatusClosed:
default:
panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package.
Expand All @@ -719,6 +741,12 @@ func (pgConn *PgConn) unlock() {
// ParameterStatus returns the value of a parameter reported by the server (e.g.
// server_version). Returns an empty string for unknown parameters.
func (pgConn *PgConn) ParameterStatus(key string) string {
// if we are in recovering state parameter statuses are undefined
// so we have to wait until we recover
if pgConn.IsRecovering() {
pgConn.WaitForRecover()
}

return pgConn.parameterStatuses[key]
}

Expand Down Expand Up @@ -1335,7 +1363,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
// error is found then forcibly close the connection without sending the Terminate message.
if err := pgConn.bufferingReceiveErr; err != nil {
pgConn.status = connStatusClosed
pgConn.status.Store(connStatusClosed)
pgConn.conn.Close()
close(pgConn.cleanupDone)
return CommandTag{}, normalizeTimeoutError(ctx, err)
Expand Down Expand Up @@ -1415,7 +1443,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = normalizeTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.asyncClose()
mrr.pgConn.handleConnectionError(err)
return nil, mrr.err
}

Expand Down Expand Up @@ -1630,12 +1658,12 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
}

if err != nil {
err = normalizeTimeoutError(rr.ctx, err)
rr.concludeCommand(CommandTag{}, err)
normalizedErr := normalizeTimeoutError(rr.ctx, err)
rr.concludeCommand(CommandTag{}, normalizedErr)
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true
if rr.multiResultReader == nil {
rr.pgConn.asyncClose()
rr.pgConn.handleConnectionError(err)
}

return nil, rr.err
Expand Down Expand Up @@ -1867,7 +1895,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
if err := pgConn.lock(); err != nil {
return nil, err
}
pgConn.status = connStatusClosed
pgConn.status.Store(connStatusClosed)

return &HijackedConn{
Conn: pgConn.conn,
Expand Down Expand Up @@ -1897,10 +1925,11 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
frontend: hc.Frontend,
config: hc.Config,

status: connStatusIdle,
status: atomic.Uint32{},

cleanupDone: make(chan struct{}),
}
pgConn.status.Store(connStatusIdle)

pgConn.contextWatcher = newContextWatcher(pgConn.conn)
pgConn.bgReader = bgreader.New(pgConn.conn)
Expand Down
Loading