diff --git a/src/query.rs b/src/query.rs index 39d0299..198f015 100644 --- a/src/query.rs +++ b/src/query.rs @@ -28,7 +28,7 @@ impl Query { } } - /// Bind `value` to the first `?` in the query. + /// Binds `value` to the next `?` in the query. /// /// The `value`, which must either implement [`Serialize`](serde::Serialize) /// or be an [`Identifier`], will be appropriately escaped. @@ -40,15 +40,17 @@ impl Query { self } - /// Execute the query. + /// Executes the query. pub async fn execute(self) -> Result<()> { self.do_execute(false)?.finish().await } - /// Execute the query, returning a [`RowCursor`] to obtain results. + /// Executes the query, returning a [`RowCursor`] to obtain results. /// /// # Example - /// ```norun + /// + /// ``` + /// # async fn example() -> clickhouse::error::Result<()> { /// #[derive(clickhouse::Row, serde::Deserialize)] /// struct MyRow<'a> { /// no: u32, @@ -62,6 +64,7 @@ impl Query { /// while let Some(MyRow { name, no }) = cursor.next().await? { /// println!("{name}: {no}"); /// } + /// # Ok(()) } /// ``` pub fn fetch(mut self) -> Result> { self.sql.bind_fields::(); @@ -85,6 +88,16 @@ impl Query { } } + /// Executes the query and returns at most one row. + /// + /// Note that `T` must be owned. + pub async fn fetch_optional(self) -> Result> + where + T: Row + for<'b> Deserialize<'b>, + { + self.fetch()?.next().await + } + /// Executes the query and returns all the generated results, collected into a Vec. /// /// Note that `T` must be owned. diff --git a/tests/test_query.rs b/tests/test_query.rs index be9c6b1..04d566f 100644 --- a/tests/test_query.rs +++ b/tests/test_query.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use clickhouse::Row; +use clickhouse::{error::Error, Row}; mod common; @@ -54,6 +54,41 @@ async fn smoke() { } } +#[common::named] +#[tokio::test] +async fn fetch_one_and_optional() { + let client = common::prepare_database!(); + + client + .query("CREATE TABLE test(n String) ENGINE = MergeTree ORDER BY n") + .execute() + .await + .unwrap(); + + let q = "SELECT * FROM test"; + let got_string = client.query(q).fetch_optional::().await.unwrap(); + assert_eq!(got_string, None); + + let got_string = client.query(q).fetch_one::().await; + assert!(matches!(got_string, Err(Error::RowNotFound))); + + #[derive(Serialize, Row)] + struct Row { + n: String, + } + + let mut insert = client.insert("test").unwrap(); + insert.write(&Row { n: "foo".into() }).await.unwrap(); + insert.write(&Row { n: "bar".into() }).await.unwrap(); + insert.end().await.unwrap(); + + let got_string = client.query(q).fetch_optional::().await.unwrap(); + assert_eq!(got_string, Some("bar".into())); + + let got_string = client.query(q).fetch_one::().await.unwrap(); + assert_eq!(got_string, "bar"); +} + // See #19. #[common::named] #[tokio::test]