Skip to content

Commit 4d8aa98

Browse files
committed
Added connect_mut for data changing SPI operations
1 parent c1e5dd9 commit 4d8aa98

File tree

12 files changed

+94
-71
lines changed

12 files changed

+94
-71
lines changed

pgrx-examples/schemas/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ mod tests {
101101

102102
#[pg_test]
103103
fn test_my_some_schema_type() -> Result<(), spi::Error> {
104-
Spi::connect(|mut c| {
104+
Spi::connect_mut(|c| {
105105
// "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable
106106
c.update("SET search_path TO some_schema,public", None, &[])?;
107107
assert_eq!(

pgrx-tests/src/tests/bgworker_tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
2525
if arg > 0 {
2626
BackgroundWorker::transaction(|| {
2727
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?;
28-
Spi::connect(|mut client| {
28+
Spi::connect_mut(|client| {
2929
client
3030
.update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()])
3131
.map(|_| ())
@@ -66,7 +66,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
6666
};
6767
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
6868
BackgroundWorker::transaction(|| {
69-
Spi::connect(|mut c| {
69+
Spi::connect_mut(|c| {
7070
c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()])
7171
.map(|_| ())
7272
})

pgrx-tests/src/tests/guc_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ mod tests {
202202
Spi::run("SET test.no_show TO false;").expect("SPI failed");
203203
Spi::run("SET test.no_reset_all TO false;").expect("SPI failed");
204204
assert_eq!(GUC_NO_RESET_ALL.get(), false);
205-
Spi::connect(|mut client| {
205+
Spi::connect_mut(|client| {
206206
let r = client.update("SHOW ALL", None, &[]).expect("SPI failed");
207207

208208
let mut no_reset_guc_in_show_all = false;

pgrx-tests/src/tests/pg_cast_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ mod tests {
5757

5858
#[pg_test]
5959
fn test_pg_cast_assignment_type_cast() {
60-
let _ = Spi::connect(|mut client| {
60+
let _ = Spi::connect_mut(|client| {
6161
client.update("CREATE TABLE test_table(value int4);", None, &[])?;
6262
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?;
6363

pgrx-tests/src/tests/spi_tests.rs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ mod tests {
165165

166166
#[pg_test]
167167
fn test_inserting_null() -> Result<(), pgrx::spi::Error> {
168-
Spi::connect(|mut client| {
168+
Spi::connect_mut(|client| {
169169
client.update("CREATE TABLE tests.null_test (id uuid)", None, &[]).map(|_| ())
170170
})?;
171171
assert_eq!(
@@ -188,7 +188,7 @@ mod tests {
188188

189189
#[pg_test]
190190
fn test_cursor() -> Result<(), spi::Error> {
191-
Spi::connect(|mut client| {
191+
Spi::connect_mut(|client| {
192192
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
193193
client.update(
194194
"INSERT INTO tests.cursor_table (id) \
@@ -208,7 +208,7 @@ mod tests {
208208

209209
#[pg_test]
210210
fn test_cursor_prepared_statement() -> Result<(), pgrx::spi::Error> {
211-
Spi::connect(|mut client| {
211+
Spi::connect_mut(|client| {
212212
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
213213
client.update(
214214
"INSERT INTO tests.cursor_table (id) \
@@ -245,7 +245,7 @@ mod tests {
245245
fn test_cursor_prepared_statement_panics_impl(
246246
args: &[DatumWithOid],
247247
) -> Result<(), pgrx::spi::Error> {
248-
Spi::connect(|mut client| {
248+
Spi::connect_mut(|client| {
249249
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
250250
client.update(
251251
"INSERT INTO tests.cursor_table (id) \
@@ -264,7 +264,7 @@ mod tests {
264264

265265
#[pg_test]
266266
fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> {
267-
let cursor_name = Spi::connect(|mut client| {
267+
let cursor_name = Spi::connect_mut(|client| {
268268
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
269269
client.update(
270270
"INSERT INTO tests.cursor_table (id) \
@@ -318,7 +318,7 @@ mod tests {
318318
Ok::<_, spi::Error>(())
319319
})?;
320320

321-
Spi::connect(|mut client| {
321+
Spi::connect_mut(|client| {
322322
let res = client.update("SET TIME ZONE 'PST8PDT'", None, &[])?;
323323

324324
assert_eq!(Err(spi::Error::NoTupleTable), res.columns());
@@ -334,9 +334,8 @@ mod tests {
334334

335335
#[pg_test]
336336
fn test_spi_non_mut() -> Result<(), pgrx::spi::Error> {
337-
// Ensures update and cursor APIs do not need mutable reference to SpiClient
338-
Spi::connect(|mut client| {
339-
client.update("SELECT 1", None, &[]).expect("SPI failed");
337+
// Ensures cursor APIs do not need mutable reference to SpiClient
338+
Spi::connect(|client| {
340339
let cursor = client.open_cursor("SELECT 1", &[]).detach_into_name();
341340
client.find_cursor(&cursor).map(|_| ())
342341
})
@@ -428,7 +427,7 @@ mod tests {
428427

429428
#[pg_test]
430429
fn test_readwrite_in_select_readwrite() -> Result<(), spi::Error> {
431-
Spi::connect(|mut client| {
430+
Spi::connect_mut(|client| {
432431
// This is supposed to switch connection to read-write and run it there
433432
client.update("CREATE TABLE a (id INT)", None, &[])?;
434433
// This is supposed to run in read-write
@@ -459,7 +458,7 @@ mod tests {
459458

460459
#[pg_test]
461460
fn test_spi_select_sees_update() -> spi::Result<()> {
462-
let with_select = Spi::connect(|mut client| {
461+
let with_select = Spi::connect_mut(|client| {
463462
client.update("CREATE TABLE asd(id int)", None, &[])?;
464463
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
465464
client.select("SELECT COUNT(*) FROM asd", None, &[])?.first().get_one::<i64>()
@@ -485,7 +484,7 @@ mod tests {
485484

486485
#[pg_test]
487486
fn test_spi_select_sees_update_in_other_session() -> spi::Result<()> {
488-
Spi::connect::<spi::Result<()>, _>(|mut client| {
487+
Spi::connect_mut::<spi::Result<()>, _>(|client| {
489488
client.update("CREATE TABLE asd(id int)", None, &[])?;
490489
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
491490
Ok(())

pgrx-tests/src/tests/srf_tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ mod tests {
243243

244244
#[pg_test]
245245
fn test_srf_setof_datum_detoasting_with_borrow() {
246-
let cnt = Spi::connect(|mut client| {
246+
let cnt = Spi::connect_mut(|client| {
247247
// build up a table with one large column that Postgres will be forced to TOAST
248248
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;
249249

@@ -261,7 +261,7 @@ mod tests {
261261

262262
#[pg_test]
263263
fn test_srf_table_datum_detoasting_with_borrow() {
264-
let cnt = Spi::connect(|mut client| {
264+
let cnt = Spi::connect_mut(|client| {
265265
// build up a table with one large column that Postgres will be forced to TOAST
266266
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;
267267

pgrx-tests/src/tests/struct_type_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ mod tests {
5757

5858
#[pg_test]
5959
fn test_complex_storage_and_retrieval() -> Result<(), pgrx::spi::Error> {
60-
let complex = Spi::connect(|mut client| {
60+
let complex = Spi::connect_mut(|client| {
6161
client.update(
6262
"CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\
6363
SELECT value FROM complex_test ORDER BY id;", None, &[])?.first().get_one::<PgBox<Complex>>()

pgrx-tests/tests/compile-fail/escaping-spiclient-1209-cursor.stderr

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ error: lifetime may not live long enough
44
8 | let mut res = Spi::connect(|c| {
55
| -- return type of closure is SpiTupleTable<'2>
66
| |
7-
| has type `SpiClient<'1>`
7+
| has type `&SpiClient<'1>`
88
9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", &[])
99
10 | | .fetch(1000)
1010
11 | | .unwrap()
@@ -31,7 +31,7 @@ error: lifetime may not live long enough
3131
| -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
3232
| ||
3333
| |return type of closure is SpiTupleTable<'2>
34-
| has type `SpiClient<'1>`
34+
| has type `&SpiClient<'1>`
3535

3636
error[E0515]: cannot return value referencing temporary value
3737
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26

pgrx-tests/tests/compile-fail/escaping-spiclient-1209-prep-stmt.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ error: lifetime may not live long enough
55
| -- ^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
66
| ||
77
| |return type of closure is std::result::Result<pgrx::spi::PreparedStatement<'2>, pgrx::spi::SpiError>
8-
| has type `SpiClient<'1>`
8+
| has type `&SpiClient<'1>`

pgrx/src/spi.rs

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ mod cursor;
2121
mod query;
2222
mod tuple;
2323
pub use client::SpiClient;
24-
use client::SpiConnection;
2524
pub use cursor::SpiCursor;
2625
pub use query::{OwnedPreparedStatement, PreparedStatement, Query};
2726
pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable};
@@ -237,13 +236,13 @@ impl Spi {
237236
}
238237

239238
pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
240-
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_one())
239+
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_one())
241240
}
242241

243242
pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
244243
query: &str,
245244
) -> Result<(Option<A>, Option<B>)> {
246-
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
245+
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
247246
}
248247

249248
pub fn get_three<
@@ -253,7 +252,7 @@ impl Spi {
253252
>(
254253
query: &str,
255254
) -> Result<(Option<A>, Option<B>, Option<C>)> {
256-
Spi::connect(|mut client| {
255+
Spi::connect_mut(|client| {
257256
client.update(query, Some(1), &[])?.first().get_three::<A, B, C>()
258257
})
259258
}
@@ -262,14 +261,14 @@ impl Spi {
262261
query: &str,
263262
args: &[DatumWithOid<'mcx>],
264263
) -> Result<Option<A>> {
265-
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_one())
264+
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_one())
266265
}
267266

268267
pub fn get_two_with_args<'mcx, A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
269268
query: &str,
270269
args: &[DatumWithOid<'mcx>],
271270
) -> Result<(Option<A>, Option<B>)> {
272-
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
271+
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
273272
}
274273

275274
pub fn get_three_with_args<
@@ -281,12 +280,12 @@ impl Spi {
281280
query: &str,
282281
args: &[DatumWithOid<'mcx>],
283282
) -> Result<(Option<A>, Option<B>, Option<C>)> {
284-
Spi::connect(|mut client| {
283+
Spi::connect_mut(|client| {
285284
client.update(query, Some(1), args)?.first().get_three::<A, B, C>()
286285
})
287286
}
288287

289-
/// just run an arbitrary SQL statement.
288+
/// Just run an arbitrary SQL statement.
290289
///
291290
/// ## Safety
292291
///
@@ -304,7 +303,7 @@ impl Spi {
304303
query: &str,
305304
args: &[DatumWithOid<'mcx>],
306305
) -> std::result::Result<(), Error> {
307-
Spi::connect(|mut client| client.update(query, None, args).map(|_| ()))
306+
Spi::connect_mut(|client| client.update(query, None, args).map(|_| ()))
308307
}
309308

310309
/// explain a query, returning its result in json form
@@ -314,7 +313,7 @@ impl Spi {
314313

315314
/// explain a query with args, returning its result in json form
316315
pub fn explain_with_args<'mcx>(query: &str, args: &[DatumWithOid<'mcx>]) -> Result<Json> {
317-
Ok(Spi::connect(|mut client| {
316+
Ok(Spi::connect_mut(|client| {
318317
client
319318
.update(&format!("EXPLAIN (format json) {query}"), None, args)?
320319
.first()
@@ -323,7 +322,7 @@ impl Spi {
323322
.unwrap())
324323
}
325324

326-
/// Execute SPI commands via the provided `SpiClient`.
325+
/// Execute SPI read-only commands via the provided `SpiClient`.
327326
///
328327
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
329328
/// and Postgres will completely free that context when this function is finished.
@@ -360,10 +359,51 @@ impl Spi {
360359
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
361360
pub fn connect<R, F>(f: F) -> R
362361
where
363-
F: FnOnce(SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
364-
- 'conn ~= CurrentMemoryContext after connection
365-
- 'ret ~= SPI_palloc's context
366-
*/
362+
F: FnOnce(&SpiClient<'_>) -> R,
363+
{
364+
Self::connect_mut(|client| f(client))
365+
}
366+
367+
/// Execute SPI mutating commands via the provided `SpiClient`.
368+
///
369+
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
370+
/// and Postgres will completely free that context when this function is finished.
371+
///
372+
/// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are
373+
/// automatically copied into the into the `CurrentMemoryContext` at the time of this
374+
/// function call.
375+
///
376+
/// # Examples
377+
///
378+
/// ```rust,no_run
379+
/// use pgrx::prelude::*;
380+
/// # fn foo() -> spi::Result<()> {
381+
/// Spi::connect_mut(|client| {
382+
/// client.update("INSERT INTO users VALUES ('Bob')", None, &[])?;
383+
/// Ok(())
384+
/// })
385+
/// # }
386+
/// ```
387+
///
388+
/// Note that `SpiClient` is scoped to the connection lifetime and cannot be returned. The
389+
/// following code will not compile:
390+
///
391+
/// ```rust,compile_fail
392+
/// use pgrx::prelude::*;
393+
/// let cant_return_client = Spi::connect(|client| client);
394+
/// ```
395+
///
396+
/// # Panics
397+
///
398+
/// This function will panic if for some reason it's unable to "connect" to Postgres' SPI
399+
/// system. At the time of this writing, that's actually impossible as the underlying function
400+
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
401+
pub fn connect_mut<R, F>(f: F) -> R
402+
where
403+
F: FnOnce(&mut SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
404+
- 'conn ~= CurrentMemoryContext after connection
405+
- 'ret ~= SPI_palloc's context
406+
*/
367407
{
368408
// connect to SPI
369409
//
@@ -379,14 +419,13 @@ impl Spi {
379419
// otherwise this function would need to return a `Result<R, spi::Error>` and that's a
380420
// fucking nightmare for users to deal with. There's ample discussion around coming to
381421
// this decision at https://github.com/pgcentralfoundation/pgrx/pull/977
382-
let connection =
383-
SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
422+
let mut client = SpiClient::connect().expect("SPI_connect indicated an unexpected failure");
384423

385424
// run the provided closure within the memory context that SPI_connect()
386425
// just put us un. We'll disconnect from SPI when the closure is finished.
387426
// If there's a panic or elog(ERROR), we don't care about also disconnecting from
388427
// SPI b/c Postgres will do that for us automatically
389-
f(connection.client())
428+
f(&mut client)
390429
}
391430

392431
#[track_caller]

0 commit comments

Comments
 (0)