From e32cb13db4395656bd526b9fc381d180177b90c5 Mon Sep 17 00:00:00 2001
From: Boris S
Date: Tue, 28 May 2024 16:29:22 +0300
Subject: [PATCH] Eliminate backpressure

Implementing backpressure for TCP connections turned out to be less
straightforward than expected. Revert the TCP backpressure logic to make
the existing logic easier to debug.
---
 src/main.rs            |   2 +-
 src/net/tcp.rs         | 203 +++++++++++++++++++++--------------------
 src/server/handlers.rs |   5 +-
 src/server/mod.rs      |   9 +-
 tests/common.rs        |   4 +-
 tests/integration.rs   |   4 +-
 6 files changed, 115 insertions(+), 112 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 9ebc31f..7d3725e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -17,7 +17,7 @@ async fn main() -> Result<()> {
 
     // Create proxy server instance. It will handle incoming connections in an async fashion.
     let server_addr = SocketAddr::new(IpAddr::V4(config.bind_ipv4()), config.bind_port());
-    let server = LurkServer::new(server_addr, config.tcp_conn_limit());
+    let server = LurkServer::new(server_addr);
 
     // Bind and serve clients "forever"
     server.run().await?;
diff --git a/src/net/tcp.rs b/src/net/tcp.rs
index 9f346a1..41d329f 100644
--- a/src/net/tcp.rs
+++ b/src/net/tcp.rs
@@ -60,71 +60,90 @@ pub mod listener {
     use super::connection::{LurkTcpConnection, LurkTcpConnectionFactory, LurkTcpConnectionLabel};
     use anyhow::Result;
-    use async_listen::{backpressure, backpressure::Backpressure, ListenExt};
-    use std::net::SocketAddr;
-    use tokio::net::{TcpListener, ToSocketAddrs};
-    use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
+    use socket2::{Domain, Socket, Type};
+    use std::{
+        future, io,
+        net::{SocketAddr, ToSocketAddrs},
+        task::{self, ready, Poll},
+    };
+    use tokio::net::{TcpListener, TcpStream};
+
+    const TCP_LISTEN_BACKLOG: i32 = 1024;
 
     /// Custom implementation of TCP listener.
     #[allow(dead_code)]
     pub struct LurkTcpListener {
-        incoming: Backpressure<TcpListenerStream>,
-        factory: LurkTcpConnectionFactory,
-        local_addr: SocketAddr,
+        inner: TcpListener,
+        conn_factory: LurkTcpConnectionFactory,
     }
 
     impl LurkTcpListener {
         /// Binds TCP listener to passed `addr`.
-        ///
-        /// Argument `conn_limit` sets the limit of open TCP connections. Thus accepting of new connections
-        /// on returned `LurkTcpListener` will be paused, when number of open TCP connections will reach
-        /// the `conn_limit`.
-        pub async fn bind(addr: impl ToSocketAddrs, conn_limit: usize) -> Result<LurkTcpListener> {
-            // Bind TCP listener.
-            let listener = TcpListener::bind(addr).await?;
-            let local_addr = listener.local_addr()?;
-
-            // Create backpressure limit and supply the receiver to the created stream.
-            let (bp_tx, bp_rx) = backpressure::new(conn_limit);
-            let incoming = TcpListenerStream::new(listener).apply_backpressure(bp_rx);
-
-            Ok(LurkTcpListener {
-                incoming,
-                factory: LurkTcpConnectionFactory::new(bp_tx),
-                local_addr,
-            })
+        pub async fn bind(addr: impl ToSocketAddrs) -> Result<LurkTcpListener> {
+            let bind_addr = LurkTcpListener::resolve_bind_addr(addr);
+
+            // Create TCP socket
+            let socket = Socket::new(Domain::for_address(bind_addr), Type::STREAM, None)?;
+
+            // Bind TCP socket and mark it ready to accept incoming connections
+            socket.bind(&bind_addr.into())?;
+            socket.listen(TCP_LISTEN_BACKLOG)?;
+
+            // Put the socket into non-blocking mode, as `TcpListener::from_std` requires
+            socket.set_nonblocking(true)?;
+
+            // Create tokio TCP listener from TCP socket
+            let inner: TcpListener = TcpListener::from_std(socket.into())?;
+
+            // Create TCP connections factory
+            let conn_factory = LurkTcpConnectionFactory::new();
+
+            Ok(LurkTcpListener { inner, conn_factory })
         }
 
         /// Accepts an incoming TCP connection.
        pub async fn accept(&mut self) -> Result<LurkTcpConnection> {
-            let err_msg: &str = "Incoming TCP listener should never return empty option";
-            let tcp_stream = self.incoming.next().await.expect(err_msg)?;
+            let (tcp_stream, _) = future::poll_fn(|cx| self.poll_inner_accept(cx)).await?;
             let tcp_label = LurkTcpConnectionLabel::from_tcp_stream(&tcp_stream).await?;
 
-            self.factory.create_connection(tcp_stream, tcp_label)
+            self.conn_factory.create_connection(tcp_stream, tcp_label)
         }
 
         /// Returns the local address that this listener is bound to.
         #[allow(dead_code)]
         pub fn local_addr(&self) -> SocketAddr {
-            self.local_addr
+            self.inner.local_addr().expect("Expect local listener address")
+        }
+
+        /// Polls inner TCP listener to accept a new connection.
+        fn poll_inner_accept(&self, cx: &mut task::Context<'_>) -> Poll<io::Result<(TcpStream, SocketAddr)>> {
+            let (tcp_stream, peer_addr) = ready!(self.inner.poll_accept(cx))?;
+
+            Poll::Ready(Ok((tcp_stream, peer_addr)))
+        }
+
+        fn resolve_bind_addr(addr: impl ToSocketAddrs) -> SocketAddr {
+            let mut bind_addr = addr.to_socket_addrs().unwrap();
+
+            // Return first resolved socket address
+            bind_addr.next().expect("Expect valid address to bind")
         }
     }
 
     #[cfg(test)]
     mod tests {
-        use super::*;
-        use futures::{stream::FuturesUnordered, TryFutureExt};
-        use std::time::Duration;
-        use tokio::{
-            io::AsyncWriteExt,
-            net::TcpStream,
-            time::{sleep, timeout},
-        };
+        // use super::*;
+        // use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt};
+        // use std::time::Duration;
+        // use tokio::{
+        //     io::AsyncWriteExt,
+        //     net::TcpStream,
+        //     time::{sleep, timeout},
+        // };
 
-        // :0 tells the OS to pick an open port.
-        const TEST_BIND_IPV4: &str = "127.0.0.1:0";
+        // // :0 tells the OS to pick an open port.
+        // const TEST_BIND_IPV4: &str = "127.0.0.1:0";
 
         /// This tests backpressure limit set on listener.
         /// Number of connections intentionally exceeds the limit. Thus listener
@@ -133,46 +152,46 @@ pub mod listener {
         #[tokio::test]
         async fn limit_tcp_connections() {
-            let conn_limit = 5;
-            let num_clients = 20;
-
-            let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, 5).await.expect("Expect binded listener");
-            let listener_addr = listener.local_addr();
-
-            let client_tasks: FuturesUnordered<_> = (0..num_clients)
-                .map(|_| async move {
-                    TcpStream::connect(listener_addr)
-                        .and_then(|mut s| async move { s.write_all(&[0x05]).await })
-                        .await
-                        .unwrap()
-                })
-                .collect();
-
-            // Await all clients to complete.
-            client_tasks.collect::<()>().await;
-
-            // We have to handle all clients, but only `conn_limit`
-            // could be handled in parallel.
-            for _ in 0..num_clients {
-                let conn = timeout(Duration::from_secs(2), listener.accept())
-                    .await
-                    .expect("Expect acceptied connection before expired timeout")
-                    .expect("Expect accepted TCP connection");
-
-                assert_eq!(LurkTcpConnectionLabel::SOCKS5, conn.label());
-                assert!(
-                    listener.factory.get_active_tokens() <= conn_limit,
-                    "Number of opened connections must not exceed the limit"
-                );
-
-                tokio::spawn(async move {
-                    // Some client handling ...
-                    sleep(Duration::from_millis(300)).await;
-                    // Drop the connection after sleep, hence one
-                    // slot should become available for the next client
-                    drop(conn)
-                });
-            }
+            // let conn_limit = 5;
+            // let num_clients = 20;
+
+            // let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, 5).await.expect("Expect bound listener");
+            // let listener_addr = listener.local_addr();
+
+            // let client_tasks: FuturesUnordered<_> = (0..num_clients)
+            //     .map(|_| async move {
+            //         TcpStream::connect(listener_addr)
+            //             .and_then(|mut s| async move { s.write_all(&[0x05]).await })
+            //             .await
+            //             .unwrap()
+            //     })
+            //     .collect();
+
+            // // Await all clients to complete.
+            // client_tasks.collect::<()>().await;
+
+            // // We have to handle all clients, but only `conn_limit`
+            // // could be handled in parallel.
+            // for _ in 0..num_clients {
+            //     let conn = timeout(Duration::from_secs(2), listener.accept())
+            //         .await
+            //         .expect("Expect accepted connection before the timeout expires")
+            //         .expect("Expect accepted TCP connection");
+
+            //     assert_eq!(LurkTcpConnectionLabel::SOCKS5, conn.label());
+            //     assert!(
+            //         listener.factory.get_active_tokens() <= conn_limit,
+            //         "Number of opened connections must not exceed the limit"
+            //     );
+
+            //     tokio::spawn(async move {
+            //         // Some client handling ...
+            //         sleep(Duration::from_millis(300)).await;
+            //         // Drop the connection after sleep, hence one
+            //         // slot should become available for the next client
+            //         drop(conn)
+            //     });
+            // }
         }
     }
 }
@@ -184,7 +203,6 @@ pub mod connection {
         io::stream::{LurkStream, LurkTcpStream},
     };
     use anyhow::{bail, Result};
-    use async_listen::backpressure::{Sender, Token};
     use std::{fmt::Display, io, net::SocketAddr};
     use tokio::net::TcpStream;
@@ -229,39 +247,27 @@ pub mod connection {
     }
 
     /// Factory that produces new TCP connection instances.
-    ///
-    /// For each new instance, factory uses backpressure 'sender' to create the token that
-    /// should be destroyed on TCP connection drop.
-    ///
-    pub struct LurkTcpConnectionFactory {
-        /// Backpressure sender instance.
-        /// This will produce tokens for created TCP connections.
-        bp_tx: Sender,
-    }
+    pub struct LurkTcpConnectionFactory {}
 
     impl LurkTcpConnectionFactory {
-        pub fn new(bp_tx: Sender) -> LurkTcpConnectionFactory {
-            LurkTcpConnectionFactory { bp_tx }
+        pub fn new() -> LurkTcpConnectionFactory {
+            LurkTcpConnectionFactory {}
         }
 
         /// Returns the number of currently active tokens.
         #[allow(dead_code)]
        pub fn get_active_tokens(&self) -> usize {
-            self.bp_tx.get_active_tokens()
+            0
         }
 
         pub fn create_connection(&self, tcp_stream: TcpStream, label: LurkTcpConnectionLabel) -> Result<LurkTcpConnection> {
-            // Wrap raw TcpStream to the stream wrapper and generate new backpressure token
-            // that must be dropped on connection destruction.
-            LurkTcpConnection::new(tcp_stream, label, self.bp_tx.token())
+            LurkTcpConnection::new(tcp_stream, label)
         }
     }
 
     pub struct LurkTcpConnection {
         /// Lurk wrapper of TcpStream
         stream: LurkTcpStream,
-        /// Backpressure token
-        _token: Token,
         /// Label describing traffic in this TCP connection
         label: LurkTcpConnectionLabel,
         /// Remote address that this connection is connected to
@@ -271,12 +277,11 @@ pub mod connection {
     }
 
     impl LurkTcpConnection {
-        fn new(tcp_stream: TcpStream, label: LurkTcpConnectionLabel, token: Token) -> Result<LurkTcpConnection> {
+        fn new(tcp_stream: TcpStream, label: LurkTcpConnectionLabel) -> Result<LurkTcpConnection> {
             Ok(LurkTcpConnection {
                 peer_addr: tcp_stream.peer_addr()?,
                 local_addr: tcp_stream.local_addr()?,
                 stream: LurkStream::new(tcp_stream),
-                _token: token,
                 label,
             })
         }
diff --git a/src/server/handlers.rs b/src/server/handlers.rs
index cf88073..273246e 100644
--- a/src/server/handlers.rs
+++ b/src/server/handlers.rs
@@ -161,11 +161,10 @@ mod tests {
 
     // :0 tells the OS to pick an open port.
     const TEST_BIND_IPV4: &str = "127.0.0.1:0";
-    const TEST_CONN_LIMIT: usize = 1024;
 
     #[tokio::test]
     async fn socks5_handshake_with_auth_method() {
-        let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, TEST_CONN_LIMIT)
+        let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4)
             .await
             .expect("Expect bound listener");
@@ -206,7 +205,7 @@ mod tests {
 
     #[tokio::test]
     async fn socks5_handshake_with_non_accepatable_method() {
-        let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, TEST_CONN_LIMIT)
+        let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4)
             .await
             .expect("Expect bound listener");
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 591de1d..40e1f85 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -16,7 +16,6 @@ mod handlers;
 
 pub struct LurkServer {
     bind_addr: SocketAddr,
-    conn_limit: usize,
 }
 
 impl LurkServer {
@@ -24,13 +23,13 @@ impl LurkServer {
     /// handle resource exhaustion errors.
     const DELAY_AFTER_ERROR_MILLIS: u64 = 500;
 
-    pub fn new(bind_addr: SocketAddr, conn_limit: usize) -> LurkServer {
-        LurkServer { bind_addr, conn_limit }
+    pub fn new(bind_addr: SocketAddr) -> LurkServer {
+        LurkServer { bind_addr }
     }
 
     pub async fn run(&self) -> Result<()> {
-        let mut tcp_listener = LurkTcpListener::bind(self.bind_addr, self.conn_limit).await?;
-        info!("Listening on {} (TCP connections limit {})", self.bind_addr, self.conn_limit);
+        let mut tcp_listener = LurkTcpListener::bind(self.bind_addr).await?;
+        info!("Listening on {}", self.bind_addr);
 
         loop {
             match tcp_listener.accept().await {
diff --git a/tests/common.rs b/tests/common.rs
index a87220d..18319d0 100644
--- a/tests/common.rs
+++ b/tests/common.rs
@@ -14,10 +14,10 @@ pub fn init_logging() {
 }
 
 /// Spawn Lurk proxy instance.
-pub async fn spawn_lurk_server(addr: SocketAddr, tcp_conn_limit: usize) -> tokio::task::JoinHandle<()> {
+pub async fn spawn_lurk_server(addr: SocketAddr) -> tokio::task::JoinHandle<()> {
     // Run proxy
     let handle = tokio::spawn(async move {
-        LurkServer::new(addr, tcp_conn_limit)
+        LurkServer::new(addr)
             .run()
             .await
             .expect("Error during proxy server run")
diff --git a/tests/integration.rs b/tests/integration.rs
index a152ca5..6bbedad 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -15,7 +15,7 @@ async fn http_server_single_client() {
     let http_server_addr = "127.0.0.1:32002".parse::<SocketAddr>().unwrap();
 
     // Run proxy
-    let lurk_handle = common::spawn_lurk_server(lurk_server_addr, 1024).await;
+    let lurk_handle = common::spawn_lurk_server(lurk_server_addr).await;
 
     // Run HTTP server in the background
     let http_server = ServerBuilder::new()
@@ -55,7 +55,7 @@ async fn echo_server_multiple_clients() {
     let echo_server_addr = "127.0.0.1:32004".parse::<SocketAddr>().unwrap();
 
     // Run Lurk proxy.
-    let lurk_handle = common::spawn_lurk_server(lurk_server_addr, 1024).await;
+    let lurk_handle = common::spawn_lurk_server(lurk_server_addr).await;
 
     // Run echo server. Data sent to this server will be proxied through Lurk
     // instance spawned above.
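
Note on reintroducing the limit (editor's sketch, not part of this patch): if a cap on
concurrent connections is wanted again later without pulling `async-listen` back in, a
counting semaphore at the accept loop achieves a similar effect. The sketch below is a
minimal illustration under that assumption; `MAX_CONNECTIONS`, `serve`, and `handle_conn`
are hypothetical names, not APIs from this repository.

    use std::sync::Arc;
    use tokio::{net::TcpListener, sync::Semaphore};

    /// Hypothetical cap, analogous to the removed `conn_limit`.
    const MAX_CONNECTIONS: usize = 1024;

    async fn serve(listener: TcpListener) -> std::io::Result<()> {
        let limiter = Arc::new(Semaphore::new(MAX_CONNECTIONS));

        loop {
            // Take a permit before accepting: once MAX_CONNECTIONS permits are
            // held by in-flight connections, this await pauses further accepts.
            let permit = limiter.clone().acquire_owned().await.expect("semaphore is never closed");
            let (stream, _peer_addr) = listener.accept().await?;

            tokio::spawn(async move {
                // `handle_conn` stands in for the per-connection handling logic.
                let _ = handle_conn(stream).await;
                // Dropping the permit frees a slot for the next client.
                drop(permit);
            });
        }
    }

    async fn handle_conn(_stream: tokio::net::TcpStream) -> std::io::Result<()> {
        Ok(())
    }

Each accepted connection holds an owned permit for its lifetime, so accepting pauses once
the cap is reached and resumes as connections drop, which is roughly the behavior the
removed `conn_limit` provided, without the extra backpressure stream machinery.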