diff --git a/pgrx/src/spi.rs b/pgrx/src/spi.rs index 9aa260d040..c74d279dfd 100644 --- a/pgrx/src/spi.rs +++ b/pgrx/src/spi.rs @@ -21,8 +21,8 @@ mod client; mod cursor; mod query; mod tuple; -pub use client::SpiClient; use client::SpiConnection; +pub use client::{SpiClient, SpiTransaction}; pub use cursor::SpiCursor; pub use query::{OwnedPreparedStatement, PreparedStatement, Query}; pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable}; @@ -394,6 +394,24 @@ impl Spi { f(connection.client()) } + /// Execute SPI commands via the provided `SpiClient` on a non-atomic connection. + /// + /// While inside the provided closure, code executes under a short-lived "SPI Memory Context", + /// and Postgres will completely free that context when this function is finished. + /// + /// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are + /// automatically copied into the into the `CurrentMemoryContext` at the time of this + /// function call. + pub fn connect_non_atomic(f: F) -> R + where + F: FnOnce(SpiClient<'_>, SpiTransaction<'_>) -> R, + { + let connection = SpiConnection::connect_non_atomic() + .expect("SPI_connect_ext indicated an unexpected failure"); + + f(connection.client(), connection.transaction()) + } + #[track_caller] pub fn check_status(status_code: i32) -> std::result::Result { match SpiOkCodes::try_from(status_code) { diff --git a/pgrx/src/spi/client.rs b/pgrx/src/spi/client.rs index e32728d743..6d252cb77c 100644 --- a/pgrx/src/spi/client.rs +++ b/pgrx/src/spi/client.rs @@ -204,6 +204,11 @@ impl SpiConnection { Spi::check_status(unsafe { pg_sys::SPI_connect() })?; Ok(SpiConnection(PhantomData)) } + + pub(super) fn connect_non_atomic() -> SpiResult { + Spi::check_status(unsafe { pg_sys::SPI_connect_ext(pg_sys::SPI_OPT_NONATOMIC as i32) })?; + Ok(SpiConnection(PhantomData)) + } } impl Drop for SpiConnection { @@ -221,4 +226,39 @@ impl SpiConnection { pub(super) fn client(&self) -> SpiClient<'_> { SpiClient { __marker: PhantomData } } + + pub(super) fn transaction(&self) -> SpiTransaction<'_> { + SpiTransaction { _conn: PhantomData } + } +} + +/// Represents an SPI transaction. +pub struct SpiTransaction<'conn> { + _conn: PhantomData<&'conn SpiConnection>, +} + +impl<'conn> SpiTransaction<'conn> { + /// Commits back the transaction and starts a new `SpiTransaction` with default transaction characteristics. + pub fn commit(self) -> Self { + unsafe { pg_sys::SPI_commit() }; + self + } + + /// Commits back the transaction and starts a new `SpiTransaction` with the same characteristics as the just finished one. + pub fn commit_and_chain(self) -> Self { + unsafe { pg_sys::SPI_commit_and_chain() }; + self + } + + /// Rolls back the transaction and starts a new `SpiTransaction` with default transaction characteristics. + pub fn rollback(self) -> Self { + unsafe { pg_sys::SPI_rollback() }; + self + } + + /// Rolls back the transaction and starts a new `SpiTransaction` with the same characteristics as the just finished one. + pub fn rollback_and_chain(self) -> Self { + unsafe { pg_sys::SPI_rollback_and_chain() }; + self + } }