From b1631e8e35b681452a9a8787757e5dbeebf987b7 Mon Sep 17 00:00:00 2001 From: James Hartig <me@jameshartig.com> Date: Tue, 5 Dec 2023 11:28:06 -0600 Subject: [PATCH] pgconn: add OnPGError to Config for error handling OnPGError is called on every error response received from Postgres and can be used to close connections on specific errors. Defaults to closing on FATAL-severity errors. Fixes #1803 --- pgconn/config.go | 12 ++++++++++++ pgconn/pgconn.go | 11 +++++++++-- pgconn/pgconn_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/pgconn/config.go b/pgconn/config.go index db0170e02..157b80984 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -60,6 +60,11 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnPGError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPGError ErrorPGHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -261,6 +266,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + OnPGError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1ccdc4db9..71d8e50e1 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// ErrorPGHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type ErrorPGHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPGError != nil && !pgConn.config.OnPGError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 806219302..c1d9ae18c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { require.EqualError(t, err, "pipeline has unsynced requests") } +func TestConnOnPGError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnPGError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + func Example() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()