Skip to content

Commit

Permalink
pgconn: add OnError to Config for error handling
Browse files Browse the repository at this point in the history
OnError is called on every error response received from Postgres and can
be used to close connections on specific errors. Defaults to closing on
FATAL-severity errors.

Fixes #1803
  • Loading branch information
jameshartig committed Dec 5, 2023
1 parent 913e4c8 commit 02eb559
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
11 changes: 11 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler

// OnError is a callback function called when an error is received. The default handler will close the connection
// on any FATAL errors.
OnError ErrorHandler

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -261,6 +265,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
OnError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}

if connectTimeoutSetting, present := settings["connect_timeout"]; present {
Expand Down
11 changes: 9 additions & 2 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend

// ErrorHandler is a function that handles errors returned from postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type ErrorHandler func(*PgConn, *PgError) bool

// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
Expand Down Expand Up @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" {
err := ErrorResponseToPgError(msg)
if pgConn.config.OnError != nil && !pgConn.config.OnError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg)
return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
Expand Down
35 changes: 35 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
require.EqualError(t, err, "pipeline has unsynced requests")
}

func TestConnOnError(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.OnError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool {
require.NotNil(t, c)
require.NotNil(t, pgErr)
// close connection on undefined tables only
if pgErr.Code == "42P01" {
return false
}
return true
}

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)

_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.NoError(t, err)
assert.False(t, pgConn.IsClosed())

_, err = pgConn.Exec(ctx, "select 1/0").ReadAll()
assert.Error(t, err)
assert.False(t, pgConn.IsClosed())

_, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll()
assert.Error(t, err)
assert.True(t, pgConn.IsClosed())
}

func Example() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
Expand Down

0 comments on commit 02eb559

Please sign in to comment.