diff --git a/Cargo.lock b/Cargo.lock index 12400cb7182..ce1298cbc3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3417,7 +3417,7 @@ checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" +source = "git+https://github.com/tmm1/rust-postgres?branch=execute-typed#6debe5b6ba7f1e3904eb58f87a829b18e1735003" dependencies = [ "native-tls", "tokio", @@ -3428,7 +3428,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.7" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" +source = "git+https://github.com/tmm1/rust-postgres?branch=execute-typed#6debe5b6ba7f1e3904eb58f87a829b18e1735003" dependencies = [ "base64 0.22.1", "byteorder", @@ -3445,7 +3445,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.8" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" +source = "git+https://github.com/tmm1/rust-postgres?branch=execute-typed#6debe5b6ba7f1e3904eb58f87a829b18e1735003" dependencies = [ "bit-vec", "bytes", @@ -5764,7 +5764,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.12" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" +source = "git+https://github.com/tmm1/rust-postgres?branch=execute-typed#6debe5b6ba7f1e3904eb58f87a829b18e1735003" dependencies = [ "async-trait", "byteorder", diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b1880dcad34..f9e9c74b943 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -168,8 +168,8 @@ features = [ "with-serde_json-1", "with-bit-vec-0_6", ] -git = "https://github.com/prisma/rust-postgres" -branch = "pgbouncer-mode" +git = "https://github.com/tmm1/rust-postgres" +branch = "execute-typed" optional = true [dependencies.postgres-types] @@ -179,13 +179,13 @@ features = [ "with-serde_json-1", "with-bit-vec-0_6", ] -git = "https://github.com/prisma/rust-postgres" -branch = "pgbouncer-mode" +git = "https://github.com/tmm1/rust-postgres" +branch = "execute-typed" optional = true [dependencies.postgres-native-tls] -git = "https://github.com/prisma/rust-postgres" -branch = "pgbouncer-mode" +git = "https://github.com/tmm1/rust-postgres" +branch = "execute-typed" optional = true [dependencies.tokio] diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index c294733fc51..c3c088dec75 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -27,7 +27,7 @@ use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; -use postgres_types::{Kind as PostgresKind, Type as PostgresType}; +use postgres_types::{Kind as PostgresKind, Type as PostgresType, ToSql}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ fmt::{Debug, Display}, @@ -540,29 +540,37 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); + // Execute the query using `query_typed` let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.query_typed(sql, params_with_types.as_slice())) .await?; - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + // Extract column information from the first row, if available + let (col_types, column_names) = if let Some(row) = rows.first() { + let columns = row.columns(); + let col_types = columns + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let column_names = columns.iter().map(|c| c.name().to_string()).collect(); + + (col_types, column_names) + } else { + (Vec::new(), Vec::new()) + }; + let mut result = ResultSet::new(column_names, col_types, Vec::new()); + + // Process each row in the result set for row in rows { result.rows.push(row.get_result_row()?); } @@ -582,28 +590,35 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let col_types = stmt - .columns() + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); + + // Execute the query using `query_typed` let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.query_typed(sql, params_with_types.as_slice())) .await?; - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + // Extract column information from the first row, if available + let (col_types, column_names) = if let Some(row) = rows.first() { + let columns = row.columns(); + let col_types = columns + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let column_names = columns.iter().map(|c| c.name().to_string()).collect(); + + (col_types, column_names) + } else { + (Vec::new(), Vec::new()) + }; + + let mut result = ResultSet::new(column_names, col_types, Vec::new()); for row in rows { result.rows.push(row.get_result_row()?); @@ -705,19 +720,16 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.execute_typed(sql, params_with_types.as_slice())) .await?; Ok(changes) @@ -735,19 +747,16 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.execute_typed(sql, params_with_types.as_slice())) .await?; Ok(changes)