Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added connect_mut for data changing SPI operations #1913

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgrx-examples/schemas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ mod tests {

#[pg_test]
fn test_my_some_schema_type() -> Result<(), spi::Error> {
Spi::connect(|mut c| {
Spi::connect_mut(|c| {
// "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable
c.update("SET search_path TO some_schema,public", None, &[])?;
assert_eq!(
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/bgworker_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
if arg > 0 {
BackgroundWorker::transaction(|| {
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?;
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client
.update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()])
.map(|_| ())
Expand Down Expand Up @@ -66,7 +66,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
};
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
BackgroundWorker::transaction(|| {
Spi::connect(|mut c| {
Spi::connect_mut(|c| {
c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()])
.map(|_| ())
})
Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/guc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ mod tests {
Spi::run("SET test.no_show TO false;").expect("SPI failed");
Spi::run("SET test.no_reset_all TO false;").expect("SPI failed");
assert_eq!(GUC_NO_RESET_ALL.get(), false);
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let r = client.update("SHOW ALL", None, &[]).expect("SPI failed");

let mut no_reset_guc_in_show_all = false;
Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/pg_cast_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_pg_cast_assignment_type_cast() {
let _ = Spi::connect(|mut client| {
let _ = Spi::connect_mut(|client| {
client.update("CREATE TABLE test_table(value int4);", None, &[])?;
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?;

Expand Down
23 changes: 11 additions & 12 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ mod tests {

#[pg_test]
fn test_inserting_null() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.null_test (id uuid)", None, &[]).map(|_| ())
})?;
assert_eq!(
Expand All @@ -188,7 +188,7 @@ mod tests {

#[pg_test]
fn test_cursor() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -208,7 +208,7 @@ mod tests {

#[pg_test]
fn test_cursor_prepared_statement() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand Down Expand Up @@ -245,7 +245,7 @@ mod tests {
fn test_cursor_prepared_statement_panics_impl(
args: &[DatumWithOid],
) -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -264,7 +264,7 @@ mod tests {

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> {
let cursor_name = Spi::connect(|mut client| {
let cursor_name = Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand Down Expand Up @@ -318,7 +318,7 @@ mod tests {
Ok::<_, spi::Error>(())
})?;

Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let res = client.update("SET TIME ZONE 'PST8PDT'", None, &[])?;

assert_eq!(Err(spi::Error::NoTupleTable), res.columns());
Expand All @@ -334,9 +334,8 @@ mod tests {

#[pg_test]
fn test_spi_non_mut() -> Result<(), pgrx::spi::Error> {
// Ensures update and cursor APIs do not need mutable reference to SpiClient
Spi::connect(|mut client| {
client.update("SELECT 1", None, &[]).expect("SPI failed");
// Ensures cursor APIs do not need mutable reference to SpiClient
Spi::connect(|client| {
let cursor = client.open_cursor("SELECT 1", &[]).detach_into_name();
client.find_cursor(&cursor).map(|_| ())
})
Expand Down Expand Up @@ -428,7 +427,7 @@ mod tests {

#[pg_test]
fn test_readwrite_in_select_readwrite() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
// This is supposed to switch connection to read-write and run it there
client.update("CREATE TABLE a (id INT)", None, &[])?;
// This is supposed to run in read-write
Expand Down Expand Up @@ -459,7 +458,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update() -> spi::Result<()> {
let with_select = Spi::connect(|mut client| {
let with_select = Spi::connect_mut(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
client.select("SELECT COUNT(*) FROM asd", None, &[])?.first().get_one::<i64>()
Expand All @@ -485,7 +484,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update_in_other_session() -> spi::Result<()> {
Spi::connect::<spi::Result<()>, _>(|mut client| {
Spi::connect_mut::<spi::Result<()>, _>(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/srf_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ mod tests {

#[pg_test]
fn test_srf_setof_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

Expand All @@ -261,7 +261,7 @@ mod tests {

#[pg_test]
fn test_srf_table_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

Expand Down
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/struct_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_complex_storage_and_retrieval() -> Result<(), pgrx::spi::Error> {
let complex = Spi::connect(|mut client| {
let complex = Spi::connect_mut(|client| {
client.update(
"CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\
SELECT value FROM complex_test ORDER BY id;", None, &[])?.first().get_one::<PgBox<Complex>>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ error: lifetime may not live long enough
8 | let mut res = Spi::connect(|c| {
| -- return type of closure is SpiTupleTable<'2>
| |
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", &[])
10 | | .fetch(1000)
11 | | .unwrap()
Expand All @@ -31,7 +31,7 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is SpiTupleTable<'2>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`

error[E0515]: cannot return value referencing temporary value
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is std::result::Result<pgrx::spi::PreparedStatement<'2>, pgrx::spi::SpiError>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
75 changes: 57 additions & 18 deletions pgrx/src/spi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ mod cursor;
mod query;
mod tuple;
pub use client::SpiClient;
use client::SpiConnection;
pub use cursor::SpiCursor;
pub use query::{OwnedPreparedStatement, PreparedStatement, Query};
pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable};
Expand Down Expand Up @@ -237,13 +236,13 @@ impl Spi {
}

pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_one())
}

pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
}

pub fn get_three<
Expand All @@ -253,7 +252,7 @@ impl Spi {
>(
query: &str,
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), &[])?.first().get_three::<A, B, C>()
})
}
Expand All @@ -262,14 +261,14 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_one())
}

pub fn get_two_with_args<'mcx, A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
}

pub fn get_three_with_args<
Expand All @@ -281,12 +280,12 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), args)?.first().get_three::<A, B, C>()
})
}

/// just run an arbitrary SQL statement.
/// Just run an arbitrary SQL statement.
///
/// ## Safety
///
Expand All @@ -304,7 +303,7 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> std::result::Result<(), Error> {
Spi::connect(|mut client| client.update(query, None, args).map(|_| ()))
Spi::connect_mut(|client| client.update(query, None, args).map(|_| ()))
}

/// explain a query, returning its result in json form
Expand All @@ -314,7 +313,7 @@ impl Spi {

/// explain a query with args, returning its result in json form
pub fn explain_with_args<'mcx>(query: &str, args: &[DatumWithOid<'mcx>]) -> Result<Json> {
Ok(Spi::connect(|mut client| {
Ok(Spi::connect_mut(|client| {
client
.update(&format!("EXPLAIN (format json) {query}"), None, args)?
.first()
Expand All @@ -323,7 +322,7 @@ impl Spi {
.unwrap())
}

/// Execute SPI commands via the provided `SpiClient`.
/// Execute SPI read-only commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
Expand Down Expand Up @@ -360,10 +359,51 @@ impl Spi {
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect<R, F>(f: F) -> R
where
F: FnOnce(SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
F: FnOnce(&SpiClient<'_>) -> R,
{
Self::connect_mut(|client| f(client))
}

/// Execute SPI mutating commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
///
/// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are
/// automatically copied into the into the `CurrentMemoryContext` at the time of this
/// function call.
///
/// # Examples
///
/// ```rust,no_run
/// use pgrx::prelude::*;
/// # fn foo() -> spi::Result<()> {
/// Spi::connect_mut(|client| {
/// client.update("INSERT INTO users VALUES ('Bob')", None, &[])?;
/// Ok(())
/// })
/// # }
/// ```
///
/// Note that `SpiClient` is scoped to the connection lifetime and cannot be returned. The
/// following code will not compile:
///
/// ```rust,compile_fail
/// use pgrx::prelude::*;
/// let cant_return_client = Spi::connect(|client| client);
/// ```
///
/// # Panics
///
/// This function will panic if for some reason it's unable to "connect" to Postgres' SPI
/// system. At the time of this writing, that's actually impossible as the underlying function
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect_mut<R, F>(f: F) -> R
where
F: FnOnce(&mut SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
{
// connect to SPI
//
Expand All @@ -379,14 +419,13 @@ impl Spi {
// otherwise this function would need to return a `Result<R, spi::Error>` and that's a
// fucking nightmare for users to deal with. There's ample discussion around coming to
// this decision at https://github.com/pgcentralfoundation/pgrx/pull/977
let connection =
SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
let mut client = SpiClient::connect().expect("SPI_connect indicated an unexpected failure");

// run the provided closure within the memory context that SPI_connect()
// just put us un. We'll disconnect from SPI when the closure is finished.
// If there's a panic or elog(ERROR), we don't care about also disconnecting from
// SPI b/c Postgres will do that for us automatically
f(connection.client())
f(&mut client)
}

#[track_caller]
Expand Down
Loading
Loading