diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index c2c8755bd..cc3863ff3 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -935,13 +935,18 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) - // Postgres will process the request and close the connection - // so when don't need to read the reply - // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.6.7.10 - _, err = cancelConn.Write(buf) - return err + binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) + binary.BigEndian.PutUint32(buf[12:16], pgConn.secretKey) + + if _, err := cancelConn.Write(buf); err != nil { + return fmt.Errorf("write to connection for cancellation: %w", err) + } + + // Wait for the cancel request to be acknowledged by the server. + // It copies the behavior of the libpq: https://github.com/postgres/postgres/blob/REL_16_0/src/interfaces/libpq/fe-connect.c#L4946-L4960 + _, _ = cancelConn.Read(buf) + + return nil } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index 57832a5f3..30cf62ffc 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -17,16 +17,19 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgmock" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) +const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING" + func TestConnect(t *testing.T) { tests := []struct { name string @@ -2256,18 +2259,44 @@ func TestConnCancelRequest(t *testing.T) { func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { t.Parallel() + t.Run("postgres", func(t *testing.T) { + t.Parallel() + + testConnContextCanceledCancelsRunningQueryOnServer(t, os.Getenv("PGX_TEST_DATABASE"), "postgres") + }) + + t.Run("pgbouncer", func(t *testing.T) { + t.Parallel() + + connString := os.Getenv(pgbouncerConnStringEnvVar) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", pgbouncerConnStringEnvVar) + } + + testConnContextCanceledCancelsRunningQueryOnServer(t, connString, "pgbouncer") + }) +} + +func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString, dbType string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + pgConn, err := pgconn.Connect(ctx, connString) require.NoError(t, err) defer closeConn(t, pgConn) - pid := pgConn.PID() - ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() - multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)") + + // Getting the actual PostgreSQL server process ID (PID) from a query executed through pgbouncer is not straightforward + // because pgbouncer abstracts the underlying database connections, and it doesn't expose the PID of the PostgreSQL + // server process to clients. However, we can check if the query is running by checking the generated query ID. + queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano()) + + multiResult := pgConn.Exec(ctx, fmt.Sprintf(` + -- %v + select 'Hello, world', pg_sleep(30) + `, queryID)) for multiResult.NextResult() { } @@ -2283,7 +2312,7 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - otherConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + otherConn, err := pgconn.Connect(ctx, connString) require.NoError(t, err) defer closeConn(t, otherConn) @@ -2292,8 +2321,8 @@ func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) { for { result := otherConn.ExecParams(ctx, - `select 1 from pg_stat_activity where pid=$1`, - [][]byte{[]byte(strconv.FormatInt(int64(pid), 10))}, + `select 1 from pg_stat_activity where query like $1`, + [][]byte{[]byte("%" + queryID + "%")}, nil, nil, nil,