Skip to content

Commit

Permalink
pgconn: add LastError()
Browse files Browse the repository at this point in the history
LastError() returns the last error encountered by the underlying connection
or received from postgres. It is cleared when a new request is initiated.

Fixes #1803
  • Loading branch information
jameshartig committed Dec 4, 2023
1 parent 913e4c8 commit 969b244
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 30 deletions.
101 changes: 74 additions & 27 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ type PgConn struct {

peekedMsg pgproto3.BackendMessage

lastError error

// Reusable / preallocated resources
resultReader ResultReader
multiResultReader MultiResultReader
Expand Down Expand Up @@ -460,6 +462,12 @@ func (pgConn *PgConn) signalMessage() chan struct{} {
return ch
}

// setLastError stores the received error for the LastError function to return.
// This MUST be called before the connection is closed.
func (pgConn *PgConn) setLastError(err error) {
pgConn.lastError = err
}

// ReceiveMessage receives one wire protocol message from the PostgreSQL server. It must only be used when the
// connection is not busy. e.g. It is an error to call ReceiveMessage while reading the result of a query. The messages
// are still handled by the core pgconn message handling system so receiving a NotificationResponse will still trigger
Expand All @@ -485,6 +493,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa

msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.setLastError(err)
err = &pgconnError{
msg: "receive message failed",
err: normalizeTimeoutError(ctx, err),
Expand Down Expand Up @@ -523,9 +532,8 @@ func (pgConn *PgConn) peekMessage() (pgproto3.BackendMessage, error) {
var netErr net.Error
isNetErr := errors.As(err, &netErr)
if !(isNetErr && netErr.Timeout()) {
pgConn.asyncClose()
pgConn.asyncClose(err)
}

return nil, err
}

Expand All @@ -547,11 +555,14 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
err := ErrorResponseToPgError(msg)
if msg.Severity == "FATAL" {
// call setLastError before we close but otherwise leave it up to the caller
pgConn.setLastError(err)
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg)
return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
Expand Down Expand Up @@ -590,6 +601,16 @@ func (pgConn *PgConn) TxStatus() byte {
return pgConn.txStatus
}

// LastError returns the last error caused either by the underlying connection
// or returned from Postgres. If the error was returned from Postgres then the
// error will be *pgconn.PgError.
//
// When a new request is initiated (via Exec, CopyFrom, CopyTo, etc) any previous
// error will be cleared.
func (pgConn *PgConn) LastError() error {
return pgConn.lastError
}

// SecretKey returns the backend secret key used to send a cancel query message to the server.
func (pgConn *PgConn) SecretKey() uint32 {
return pgConn.secretKey
Expand Down Expand Up @@ -637,11 +658,12 @@ func (pgConn *PgConn) Close(ctx context.Context) error {

// asyncClose marks the connection as closed and asynchronously sends a cancel query message and closes the underlying
// connection.
func (pgConn *PgConn) asyncClose() {
func (pgConn *PgConn) asyncClose(err error) {
if pgConn.status == connStatusClosed {
return
}
pgConn.status = connStatusClosed
pgConn.setLastError(err)

go func() {
defer close(pgConn.cleanupDone)
Expand Down Expand Up @@ -832,12 +854,14 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
defer pgConn.contextWatcher.Unwatch()
}

// clear the last error since we are sending a new command
pgConn.setLastError(nil)
pgConn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
pgConn.frontend.SendSync(&pgproto3.Sync{})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return nil, err
}

Expand All @@ -849,7 +873,7 @@ readloop:
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return nil, normalizeTimeoutError(ctx, err)
}

Expand All @@ -861,6 +885,7 @@ readloop:
psd.Fields = pgConn.convertRowDescription(nil, msg)
case *pgproto3.ErrorResponse:
parseErr = ErrorResponseToPgError(msg)
pgConn.setLastError(parseErr)
case *pgproto3.ReadyForQuery:
break readloop
}
Expand Down Expand Up @@ -1027,7 +1052,9 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
for {
msg, err := pgConn.receiveMessage()
if err != nil {
return normalizeTimeoutError(ctx, err)
err = normalizeTimeoutError(ctx, err)
pgConn.setLastError(err)
return err
}

switch msg.(type) {
Expand Down Expand Up @@ -1067,10 +1094,12 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
pgConn.contextWatcher.Watch(ctx)
}

// clear the last error since we are sending a new command
pgConn.setLastError(nil)
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
pgConn.contextWatcher.Unwatch()
multiResult.closed = true
multiResult.err = err
Expand Down Expand Up @@ -1175,13 +1204,15 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
}

func (pgConn *PgConn) execExtendedSuffix(result *ResultReader) {
// clear the last error since we are sending a new command
pgConn.setLastError(nil)
pgConn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
pgConn.frontend.SendExecute(&pgproto3.Execute{})
pgConn.frontend.SendSync(&pgproto3.Sync{})

err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
result.concludeCommand(CommandTag{}, err)
pgConn.contextWatcher.Unwatch()
result.closed = true
Expand Down Expand Up @@ -1209,12 +1240,13 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
defer pgConn.contextWatcher.Unwatch()
}

// clear the last error since we are sending a new command
pgConn.setLastError(nil)
// Send copy to command
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})

err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
pgConn.unlock()
return CommandTag{}, err
}
Expand All @@ -1225,7 +1257,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}

Expand All @@ -1234,7 +1266,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
case *pgproto3.CopyData:
_, err := w.Write(msg.Data)
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}
case *pgproto3.ReadyForQuery:
Expand All @@ -1244,6 +1276,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
pgConn.setLastError(pgErr)
}
}
}
Expand All @@ -1268,11 +1301,13 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
defer pgConn.contextWatcher.Unwatch()
}

// clear the last error since we are sending a new command
pgConn.setLastError(nil)
// Send copy from query
pgConn.frontend.SendQuery(&pgproto3.Query{String: sql})
err := pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}

Expand All @@ -1297,6 +1332,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co

writeErr := pgConn.frontend.SendUnbufferedEncodedCopyData(*buf)
if writeErr != nil {
pgConn.setLastError(writeErr)
// Write errors are always fatal, but we can't use asyncClose because we are in a different goroutine. Not
// setting pgConn.status or closing pgConn.cleanupDone for the same reason.
pgConn.conn.Close()
Expand All @@ -1306,6 +1342,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
}
if readErr != nil {
pgConn.setLastError(readErr)
copyErrChan <- readErr
return
}
Expand All @@ -1328,6 +1365,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
// the goroutine. So instead check pgConn.bufferingReceiveErr which will have been set by the signalMessage. If an
// error is found then forcibly close the connection without sending the Terminate message.
if err := pgConn.bufferingReceiveErr; err != nil {
pgConn.setLastError(err)
pgConn.status = connStatusClosed
pgConn.conn.Close()
close(pgConn.cleanupDone)
Expand All @@ -1338,6 +1376,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
switch msg := msg.(type) {
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
pgConn.setLastError(pgErr)
default:
signalMessageChan = pgConn.signalMessage()
}
Expand All @@ -1354,7 +1393,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
err = pgConn.flushWithPotentialWriteReadDeadlock()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, err
}

Expand All @@ -1363,7 +1402,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
for {
msg, err := pgConn.receiveMessage()
if err != nil {
pgConn.asyncClose()
pgConn.asyncClose(err)
return CommandTag{}, normalizeTimeoutError(ctx, err)
}

Expand All @@ -1374,6 +1413,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
commandTag = pgConn.makeCommandTag(msg.CommandTag)
case *pgproto3.ErrorResponse:
pgErr = ErrorResponseToPgError(msg)
pgConn.setLastError(pgErr)
}
}
}
Expand Down Expand Up @@ -1408,7 +1448,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
mrr.pgConn.contextWatcher.Unwatch()
mrr.err = normalizeTimeoutError(mrr.ctx, err)
mrr.closed = true
mrr.pgConn.asyncClose()
mrr.pgConn.asyncClose(err)
return nil, mrr.err
}

Expand All @@ -1423,6 +1463,7 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
}
case *pgproto3.ErrorResponse:
mrr.err = ErrorResponseToPgError(msg)
mrr.pgConn.setLastError(mrr.err)
}

return msg, nil
Expand Down Expand Up @@ -1584,6 +1625,7 @@ func (rr *ResultReader) Close() (CommandTag, error) {
// Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete.
case *pgproto3.ErrorResponse:
rr.err = ErrorResponseToPgError(msg)
rr.pgConn.setLastError(rr.err)
case *pgproto3.ReadyForQuery:
rr.pgConn.contextWatcher.Unwatch()
rr.pgConn.unlock()
Expand Down Expand Up @@ -1628,7 +1670,7 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
rr.pgConn.contextWatcher.Unwatch()
rr.closed = true
if rr.multiResultReader == nil {
rr.pgConn.asyncClose()
rr.pgConn.asyncClose(err)
}

return nil, rr.err
Expand All @@ -1652,6 +1694,7 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Keep the first error that is recorded. Store the error before checking if the command is already concluded to
// allow for receiving an error after CommandComplete but before ReadyForQuery.
if err != nil && rr.err == nil {
rr.pgConn.setLastError(err)
rr.err = err
}

Expand Down Expand Up @@ -1713,10 +1756,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR

batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)

// clear the last error since we are sending a new command
pgConn.setLastError(nil)
pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
_, err := pgConn.conn.Write(batch.buf)
if err != nil {
pgConn.setLastError(err)
multiResult.closed = true
multiResult.err = err
pgConn.unlock()
Expand Down Expand Up @@ -2032,7 +2078,7 @@ func (p *Pipeline) Flush() error {
if err != nil {
err = normalizeTimeoutError(p.ctx, err)

p.conn.asyncClose()
p.conn.asyncClose(err)

p.conn.contextWatcher.Unwatch()
p.conn.unlock()
Expand Down Expand Up @@ -2116,7 +2162,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
p.conn.asyncClose()
p.conn.asyncClose(err)
return nil, normalizeTimeoutError(p.ctx, err)
}

Expand All @@ -2138,11 +2184,13 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
pgErr := ErrorResponseToPgError(msg)
return nil, pgErr
case *pgproto3.CommandComplete:
p.conn.asyncClose()
return nil, errors.New("BUG: received CommandComplete while handling Describe")
err := errors.New("BUG: received CommandComplete while handling Describe")
p.conn.asyncClose(err)
return nil, err
case *pgproto3.ReadyForQuery:
p.conn.asyncClose()
return nil, errors.New("BUG: received ReadyForQuery while handling Describe")
err := errors.New("BUG: received ReadyForQuery while handling Describe")
p.conn.asyncClose(err)
return nil, err
}
}
}
Expand All @@ -2155,11 +2203,10 @@ func (p *Pipeline) Close() error {
p.closed = true

if p.pendingSync {
p.conn.asyncClose()
p.err = errors.New("pipeline has unsynced requests")
p.conn.asyncClose(p.err)
p.conn.contextWatcher.Unwatch()
p.conn.unlock()

return p.err
}

Expand All @@ -2169,7 +2216,7 @@ func (p *Pipeline) Close() error {
p.err = err
var pgErr *PgError
if !errors.As(err, &pgErr) {
p.conn.asyncClose()
p.conn.asyncClose(err)
break
}
}
Expand Down
Loading

0 comments on commit 969b244

Please sign in to comment.