diff --git a/pgconn/config.go b/pgconn/config.go index db0170e02..157b80984 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -60,6 +60,11 @@ type Config struct { // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. OnNotification NotificationHandler + // OnPGError is a callback function called when a Postgres error is received by the server. The default handler will close + // the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure + // that you close on FATAL errors by returning false. + OnPGError ErrorPGHandler + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } @@ -261,6 +266,13 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return pgproto3.NewFrontend(r, w) }, + OnPGError: func(_ *PgConn, pgErr *PgError) bool { + // we want to automatically close any fatal errors + if strings.EqualFold(pgErr.Severity, "FATAL") { + return false + } + return true + }, } if connectTimeoutSetting, present := settings["connect_timeout"]; present { diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 1ccdc4db9..71d8e50e1 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend +// ErrorPGHandler is a function that handles errors returned from Postgres. This function must return true to keep +// the connection open. Returning false will cause the connection to be closed immediately. You should return +// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is +// aware of the origin of the error, but it must not invoke any query method. +type ErrorPGHandler func(*PgConn, *PgError) bool + // NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at // any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin // of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY @@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) { case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: - if msg.Severity == "FATAL" { + err := ErrorResponseToPgError(msg) + if pgConn.config.OnPGError != nil && !pgConn.config.OnPGError(pgConn, err) { pgConn.status = connStatusClosed pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return. close(pgConn.cleanupDone) - return nil, ErrorResponseToPgError(msg) + return nil, err } case *pgproto3.NoticeResponse: if pgConn.config.OnNotice != nil { diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 806219302..c1d9ae18c 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3148,6 +3148,41 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { require.EqualError(t, err, "pipeline has unsynced requests") } +func TestConnOnPGError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnPGError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + func Example() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel()