diff --git a/README.md b/README.md index 6878c4a..6c2f8da 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,10 @@ assert_eq!( # Testing +You can set the address and port of the test instance using `TOKIO_ZOOKEEPER_TEST_HOST` and `TOKIO_ZOOKEEPER_TEST_PORT` respectively. + +The the default is `127.0.0.1:2181`. + 1. Start a Zookeeper instance, e.g. using `docker run -p 2181:2181 zookeeper` 2. Run `cargo test` diff --git a/src/lib.rs b/src/lib.rs index 69720c4..d82e7cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -279,7 +279,18 @@ impl ZooKeeperBuilder { let stream = tokio::net::TcpStream::connect(addr) .await .whatever_context("connect failed")?; - Ok((self.handshake(*addr, stream, tx).await?, rx)) + Ok((self.handshake(*addr, stream, Some(tx)).await?, rx)) + } + + /// Connect to a ZooKeeper server instance at the given address, but without returning a + /// watcher stream. + /// + /// See [`ZooKeeperBuilder::connect_without_watcher`]. + pub async fn connect_without_watcher(self, addr: &SocketAddr) -> Result { + let stream = tokio::net::TcpStream::connect(addr) + .await + .whatever_context("connect failed")?; + self.handshake(*addr, stream, None).await } /// Set the ZooKeeper [session expiry @@ -294,7 +305,7 @@ impl ZooKeeperBuilder { self, addr: SocketAddr, stream: tokio::net::TcpStream, - default_watcher: futures::channel::mpsc::UnboundedSender, + default_watcher: Option>, ) -> Result { let request = proto::Request::Connect { protocol_version: 0, @@ -327,6 +338,16 @@ impl ZooKeeper { ZooKeeperBuilder::default().connect(addr).await } + /// Connect to a ZooKeeper server instance at the given address with default parameters, + /// but without returning a watcher stream. + /// + /// See [`ZooKeeperBuilder::connect_without_watcher`]. + pub async fn connect_without_watcher(addr: &SocketAddr) -> Result { + ZooKeeperBuilder::default() + .connect_without_watcher(addr) + .await + } + /// Create a node with the given `path` with `data` as its contents. /// /// The `mode` argument specifies additional options for the newly created node. @@ -754,6 +775,8 @@ mod tests { use super::*; use futures::StreamExt; + use std::env; + use std::net::ToSocketAddrs; use tracing::Level; fn init_tracing_subscriber() { @@ -762,12 +785,30 @@ mod tests { .try_init(); } + // Use environment variables to override default connection otherwise + // default to localhost:127.0.0.1:2181 + fn get_test_zookeeper_addr() -> SocketAddr { + let host = + env::var("TOKIO_ZOOKEEPER_TEST_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); + + let port: u16 = env::var("TOKIO_ZOOKEEPER_TEST_PORT") + .unwrap_or_else(|_| "2181".to_string()) + .parse() + .expect("TOKIO_ZOOKEEPER_TEST_PORT must be a valid u16"); + + format!("{host}:{port}") + .to_socket_addrs() + .expect("Invalid host:port") + .next() + .expect("Host resolved but returned no addresses") + } + #[tokio::test] async fn it_works() { init_tracing_subscriber(); let builder = ZooKeeperBuilder::default(); - let connect_addr = "127.0.0.1:2181".parse().unwrap(); + let connect_addr = get_test_zookeeper_addr(); let (zk, w) = builder.connect(&connect_addr).await.unwrap(); let (exists_w, stat) = zk.with_watcher().exists("/foo").await.unwrap(); assert_eq!(stat, None); @@ -873,7 +914,7 @@ mod tests { #[tokio::test] async fn example() { - let connect_addr = "127.0.0.1:2181".parse().unwrap(); + let connect_addr = get_test_zookeeper_addr(); let (zk, default_watcher) = ZooKeeper::connect(&connect_addr).await.unwrap(); // let's first check if /example exists. the .watch() causes us to be notified @@ -958,10 +999,9 @@ mod tests { async fn acl_test() { init_tracing_subscriber(); let builder = ZooKeeperBuilder::default(); + let connect_addr = get_test_zookeeper_addr(); - let (zk, _) = (builder.connect(&"127.0.0.1:2181".parse().unwrap())) - .await - .unwrap(); + let (zk, _) = (builder.connect(&connect_addr)).await.unwrap(); let _ = zk .create( "/acl_test", @@ -1021,11 +1061,9 @@ mod tests { } Result::<_, Error>::Ok(res) } + let connect_addr = get_test_zookeeper_addr(); - let (zk, _) = builder - .connect(&"127.0.0.1:2181".parse().unwrap()) - .await - .unwrap(); + let (zk, _) = builder.connect(&connect_addr).await.unwrap(); let res = zk .multi() @@ -1133,4 +1171,28 @@ mod tests { drop(zk); // make Packetizer idle } + + #[tokio::test] + async fn connect_without_watcher_test() { + init_tracing_subscriber(); + let connect_addr = get_test_zookeeper_addr(); + + let zk = ZooKeeper::connect_without_watcher(&connect_addr).await.unwrap(); + + let path = zk + .create( + "/no_watcher_test", + &b"Hello world"[..], + Acl::open_unsafe(), + CreateMode::Persistent, + ) + .await + .unwrap(); + assert_eq!(path.as_deref(), Ok("/no_watcher_test")); + + let res = zk.delete("/no_watcher_test", None).await.unwrap(); + assert_eq!(res, Ok(())); + + drop(zk); // make Packetizer idle + } } diff --git a/src/proto/active_packetizer.rs b/src/proto/active_packetizer.rs index c34978f..be8daed 100644 --- a/src/proto/active_packetizer.rs +++ b/src/proto/active_packetizer.rs @@ -229,7 +229,7 @@ where fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, - default_watcher: &mut mpsc::UnboundedSender, + default_watcher: &mut Option>, ) -> Poll> where S: AsyncRead, @@ -362,8 +362,11 @@ where .expect("tried to remove watcher that didn't exist"); } - // NOTE: ignoring error, because the user may not care about events - let _ = default_watcher.unbounded_send(e); + // Handle optional watcher stream for connect_without_watcher + if let Some(w) = &default_watcher { + // NOTE: ignoring error, because the user may not care about events + let _ = w.unbounded_send(e); + } } else if xid == -2 { // response to ping -- empty response trace!("got response to heartbeat"); @@ -445,7 +448,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context, exiting: bool, - default_watcher: &mut mpsc::UnboundedSender, + default_watcher: &mut Option>, ) -> Poll> { let r = self.as_mut().poll_read(cx, default_watcher)?; diff --git a/src/proto/packetizer.rs b/src/proto/packetizer.rs index 6c922df..9094707 100644 --- a/src/proto/packetizer.rs +++ b/src/proto/packetizer.rs @@ -34,7 +34,7 @@ where state: PacketizerState, /// Watcher to send watch events to. - default_watcher: mpsc::UnboundedSender, + default_watcher: Option>, /// Incoming requests rx: mpsc::UnboundedReceiver<(Request, oneshot::Sender>)>, @@ -54,7 +54,7 @@ where pub(crate) fn new( addr: S::Addr, stream: S, - default_watcher: mpsc::UnboundedSender, + default_watcher: Option>, ) -> Enqueuer where S: Send + 'static + AsyncRead + AsyncWrite, @@ -98,7 +98,7 @@ where mut self: Pin<&mut Self>, cx: &mut Context, exiting: bool, - default_watcher: &mut mpsc::UnboundedSender, + default_watcher: &mut Option>, ) -> Poll> { let ap = match self.as_mut().project() { PacketizerStateProj::Connected(ref mut ap) => {