From 02eb559eb64b9b1b3fe6eea2bc4395dd77410b40 Mon Sep 17 00:00:00 2001 From: James Hartig <me@jameshartig.com> Date: Tue, 5 Dec 2023 11:28:06 -0600 Subject: [PATCH] pgconn: add OnError to Config for error handling OnError is called on every error response received from Postgres and can be used to close connections on specific errors. Defaults to closing on FATAL-severity errors. Fixes #1803 --- pgconn/config.go | 11 +++++++++++ pgconn/pgconn.go | 11 +++++++++-- pgconn/pgconn_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/pgconn/config.go b/pgconn/config.go index db0170e02..1540fb4fc 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -60,6 +60,10 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnError is a callback function called when an error is received. The default handler will close the connection + // on any FATAL errors. + OnError ErrorHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -261,6 +265,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + OnError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1ccdc4db9..655447c20 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// ErrorHandler is a function that handles errors returned from postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type ErrorHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnError != nil && !pgConn.config.OnError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 806219302..d7a3c2608 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { require.EqualError(t, err, "pipeline has unsynced requests") } +func TestConnOnError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + func Example() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()