Skip to content

Commit

Permalink
pgconn: add OnPGError to Config for error handling
Browse files Browse the repository at this point in the history
OnPGError is called on every error response received from Postgres and can
be used to close connections on specific errors. Defaults to closing on
FATAL-severity errors.

Fixes #1803
  • Loading branch information
jameshartig committed Dec 11, 2023
1 parent 913e4c8 commit c39842d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
12 changes: 12 additions & 0 deletions pgconn/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler

// OnPGError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPGError ErrorPGHandler

createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

Expand Down Expand Up @@ -261,6 +266,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
OnPGError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}

if connectTimeoutSetting, present := settings["connect_timeout"]; present {
Expand Down
11 changes: 9 additions & 2 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend

// ErrorPGHandler is a function that handles errors returned from Postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type ErrorPGHandler func(*PgConn, *PgError) bool

// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
Expand Down Expand Up @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" {
err := ErrorResponseToPgError(msg)
if pgConn.config.OnPGError != nil && !pgConn.config.OnPGError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg)
return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
Expand Down
35 changes: 35 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
require.EqualError(t, err, "pipeline has unsynced requests")
}

func TestConnOnPGError(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
config.OnPGError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool {
require.NotNil(t, c)
require.NotNil(t, pgErr)
// close connection on undefined tables only
if pgErr.Code == "42P01" {
return false
}
return true
}

pgConn, err := pgconn.ConnectConfig(ctx, config)
require.NoError(t, err)
defer closeConn(t, pgConn)

_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
assert.NoError(t, err)
assert.False(t, pgConn.IsClosed())

_, err = pgConn.Exec(ctx, "select 1/0").ReadAll()
assert.Error(t, err)
assert.False(t, pgConn.IsClosed())

_, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll()
assert.Error(t, err)
assert.True(t, pgConn.IsClosed())
}

func Example() {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
Expand Down

0 comments on commit c39842d

Please sign in to comment.