Skip to content

Commit

Permalink
Added connect_mut for data changing SPI operations
Browse files Browse the repository at this point in the history
  • Loading branch information
YohDeadfall committed Nov 21, 2024
1 parent c1e5dd9 commit 309ab16
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 70 deletions.
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/bgworker_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
if arg > 0 {
BackgroundWorker::transaction(|| {
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?;
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client
.update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()])
.map(|_| ())
Expand Down Expand Up @@ -66,7 +66,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
};
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
BackgroundWorker::transaction(|| {
Spi::connect(|mut c| {
Spi::connect_mut(|c| {
c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()])
.map(|_| ())
})
Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/guc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ mod tests {
Spi::run("SET test.no_show TO false;").expect("SPI failed");
Spi::run("SET test.no_reset_all TO false;").expect("SPI failed");
assert_eq!(GUC_NO_RESET_ALL.get(), false);
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let r = client.update("SHOW ALL", None, &[]).expect("SPI failed");

let mut no_reset_guc_in_show_all = false;
Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/pg_cast_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_pg_cast_assignment_type_cast() {
let _ = Spi::connect(|mut client| {
let _ = Spi::connect_mut(|client| {
client.update("CREATE TABLE test_table(value int4);", None, &[])?;
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?;

Expand Down
23 changes: 11 additions & 12 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ mod tests {

#[pg_test]
fn test_inserting_null() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.null_test (id uuid)", None, &[]).map(|_| ())
})?;
assert_eq!(
Expand All @@ -188,7 +188,7 @@ mod tests {

#[pg_test]
fn test_cursor() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -208,7 +208,7 @@ mod tests {

#[pg_test]
fn test_cursor_prepared_statement() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand Down Expand Up @@ -245,7 +245,7 @@ mod tests {
fn test_cursor_prepared_statement_panics_impl(
args: &[DatumWithOid],
) -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -264,7 +264,7 @@ mod tests {

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> {
let cursor_name = Spi::connect(|mut client| {
let cursor_name = Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand Down Expand Up @@ -318,7 +318,7 @@ mod tests {
Ok::<_, spi::Error>(())
})?;

Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let res = client.update("SET TIME ZONE 'PST8PDT'", None, &[])?;

assert_eq!(Err(spi::Error::NoTupleTable), res.columns());
Expand All @@ -334,9 +334,8 @@ mod tests {

#[pg_test]
fn test_spi_non_mut() -> Result<(), pgrx::spi::Error> {
// Ensures update and cursor APIs do not need mutable reference to SpiClient
Spi::connect(|mut client| {
client.update("SELECT 1", None, &[]).expect("SPI failed");
// Ensures cursor APIs do not need mutable reference to SpiClient
Spi::connect(|client| {
let cursor = client.open_cursor("SELECT 1", &[]).detach_into_name();
client.find_cursor(&cursor).map(|_| ())
})
Expand Down Expand Up @@ -428,7 +427,7 @@ mod tests {

#[pg_test]
fn test_readwrite_in_select_readwrite() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
// This is supposed to switch connection to read-write and run it there
client.update("CREATE TABLE a (id INT)", None, &[])?;
// This is supposed to run in read-write
Expand Down Expand Up @@ -459,7 +458,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update() -> spi::Result<()> {
let with_select = Spi::connect(|mut client| {
let with_select = Spi::connect_mut(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
client.select("SELECT COUNT(*) FROM asd", None, &[])?.first().get_one::<i64>()
Expand All @@ -485,7 +484,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update_in_other_session() -> spi::Result<()> {
Spi::connect::<spi::Result<()>, _>(|mut client| {
Spi::connect_mut::<spi::Result<()>, _>(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/srf_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ mod tests {

#[pg_test]
fn test_srf_setof_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

Expand All @@ -261,7 +261,7 @@ mod tests {

#[pg_test]
fn test_srf_table_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/struct_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_complex_storage_and_retrieval() -> Result<(), pgrx::spi::Error> {
let complex = Spi::connect(|mut client| {
let complex = Spi::connect_mut(|client| {
client.update(
"CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\
SELECT value FROM complex_test ORDER BY id;", None, &[])?.first().get_one::<PgBox<Complex>>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ error: lifetime may not live long enough
8 | let mut res = Spi::connect(|c| {
| -- return type of closure is SpiTupleTable<'2>
| |
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", &[])
10 | | .fetch(1000)
11 | | .unwrap()
Expand All @@ -31,7 +31,7 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is SpiTupleTable<'2>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`

error[E0515]: cannot return value referencing temporary value
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is std::result::Result<pgrx::spi::PreparedStatement<'2>, pgrx::spi::SpiError>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
75 changes: 57 additions & 18 deletions pgrx/src/spi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ mod cursor;
mod query;
mod tuple;
pub use client::SpiClient;
use client::SpiConnection;
pub use cursor::SpiCursor;
pub use query::{OwnedPreparedStatement, PreparedStatement, Query};
pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable};
Expand Down Expand Up @@ -237,13 +236,13 @@ impl Spi {
}

pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_one())
}

pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
}

pub fn get_three<
Expand All @@ -253,7 +252,7 @@ impl Spi {
>(
query: &str,
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), &[])?.first().get_three::<A, B, C>()
})
}
Expand All @@ -262,14 +261,14 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_one())
}

pub fn get_two_with_args<'mcx, A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
}

pub fn get_three_with_args<
Expand All @@ -281,12 +280,12 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), args)?.first().get_three::<A, B, C>()
})
}

/// just run an arbitrary SQL statement.
/// Just run an arbitrary SQL statement.
///
/// ## Safety
///
Expand All @@ -304,7 +303,7 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> std::result::Result<(), Error> {
Spi::connect(|mut client| client.update(query, None, args).map(|_| ()))
Spi::connect_mut(|client| client.update(query, None, args).map(|_| ()))
}

/// explain a query, returning its result in json form
Expand All @@ -314,7 +313,7 @@ impl Spi {

/// explain a query with args, returning its result in json form
pub fn explain_with_args<'mcx>(query: &str, args: &[DatumWithOid<'mcx>]) -> Result<Json> {
Ok(Spi::connect(|mut client| {
Ok(Spi::connect_mut(|client| {
client
.update(&format!("EXPLAIN (format json) {query}"), None, args)?
.first()
Expand All @@ -323,7 +322,7 @@ impl Spi {
.unwrap())
}

/// Execute SPI commands via the provided `SpiClient`.
/// Execute SPI read-only commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
Expand Down Expand Up @@ -360,10 +359,51 @@ impl Spi {
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect<R, F>(f: F) -> R
where
F: FnOnce(SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
F: FnOnce(&SpiClient<'_>) -> R,
{
Self::connect_mut(|client| f(client))
}

/// Execute SPI mutating commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
///
/// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are
/// automatically copied into the into the `CurrentMemoryContext` at the time of this
/// function call.
///
/// # Examples
///
/// ```rust,no_run
/// use pgrx::prelude::*;
/// # fn foo() -> spi::Result<()> {
/// Spi::connect_mut(|client| {
/// client.update("INSERT INTO users VALUES ('Bob')", None, &[])?;
/// Ok(())
/// })
/// # }
/// ```
///
/// Note that `SpiClient` is scoped to the connection lifetime and cannot be returned. The
/// following code will not compile:
///
/// ```rust,compile_fail
/// use pgrx::prelude::*;
/// let cant_return_client = Spi::connect(|client| client);
/// ```
///
/// # Panics
///
/// This function will panic if for some reason it's unable to "connect" to Postgres' SPI
/// system. At the time of this writing, that's actually impossible as the underlying function
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect_mut<R, F>(f: F) -> R
where
F: FnOnce(&mut SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
{
// connect to SPI
//
Expand All @@ -379,14 +419,13 @@ impl Spi {
// otherwise this function would need to return a `Result<R, spi::Error>` and that's a
// fucking nightmare for users to deal with. There's ample discussion around coming to
// this decision at https://github.com/pgcentralfoundation/pgrx/pull/977
let connection =
SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
let mut client = SpiClient::connect().expect("SPI_connect indicated an unexpected failure");

// run the provided closure within the memory context that SPI_connect()
// just put us un. We'll disconnect from SPI when the closure is finished.
// If there's a panic or elog(ERROR), we don't care about also disconnecting from
// SPI b/c Postgres will do that for us automatically
f(connection.client())
f(&mut client)
}

#[track_caller]
Expand Down
Loading

0 comments on commit 309ab16

Please sign in to comment.