Skip to content

Commit

Permalink
Eliminate backpressure
Browse files Browse the repository at this point in the history
It turned out to be not straightforward to implement backpressure of TCP connections.
Reverting TCP backpressure logic to ease debuggability of existing logic.
  • Loading branch information
boris-sinyapkin committed May 29, 2024
1 parent b520d07 commit e32cb13
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 112 deletions.
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<()> {

// Create proxy server instance. It will handle incoming connection in async. fashion.
let server_addr = SocketAddr::new(IpAddr::V4(config.bind_ipv4()), config.bind_port());
let server = LurkServer::new(server_addr, config.tcp_conn_limit());
let server = LurkServer::new(server_addr);

// Bind and serve clients "forever"
server.run().await?;
Expand Down
203 changes: 104 additions & 99 deletions src/net/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,71 +60,90 @@ pub mod listener {

use super::connection::{LurkTcpConnection, LurkTcpConnectionFactory, LurkTcpConnectionLabel};
use anyhow::Result;
use async_listen::{backpressure, backpressure::Backpressure, ListenExt};
use std::net::SocketAddr;
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
use socket2::{Domain, Socket, Type};
use std::{
future,
net::{SocketAddr, ToSocketAddrs},
task::{self, ready, Poll},
};
use tokio::net::{TcpListener, TcpStream};

const TCP_LISTEN_BACKLOG: i32 = 1024;

/// Custom implementation of TCP listener.
#[allow(dead_code)]
pub struct LurkTcpListener {
incoming: Backpressure<TcpListenerStream>,
factory: LurkTcpConnectionFactory,
local_addr: SocketAddr,
inner: TcpListener,
conn_factory: LurkTcpConnectionFactory,
}

impl LurkTcpListener {
/// Binds TCP listener to passed `addr`.
///
/// Argument `conn_limit` sets the limit of open TCP connections. Thus accepting of new connections
/// on returned `LurkTcpListener` will be paused, when number of open TCP connections will reach
/// the `conn_limit`.
pub async fn bind(addr: impl ToSocketAddrs, conn_limit: usize) -> Result<LurkTcpListener> {
// Bind TCP listener.
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;

// Create backpressure limit and supply the receiver to the created stream.
let (bp_tx, bp_rx) = backpressure::new(conn_limit);
let incoming = TcpListenerStream::new(listener).apply_backpressure(bp_rx);

Ok(LurkTcpListener {
incoming,
factory: LurkTcpConnectionFactory::new(bp_tx),
local_addr,
})
pub async fn bind(addr: impl ToSocketAddrs) -> Result<LurkTcpListener> {
let bind_addr = LurkTcpListener::resolve_bind_addr(addr);

// Create TCP socket
let socket = Socket::new(Domain::for_address(bind_addr), Type::STREAM, None)?;

// Bind TCP socket and mark it ready to accept incoming connections
socket.bind(&bind_addr.into())?;
socket.listen(TCP_LISTEN_BACKLOG)?;

// Set TCP options
socket.set_nonblocking(true)?;

// Create tokio TCP listener from TCP socket
let inner: TcpListener = TcpListener::from_std(socket.into())?;

// Create TCP connections factory
let conn_factory = LurkTcpConnectionFactory::new();

Ok(LurkTcpListener { inner, conn_factory })
}

/// Accept incoming TCP connection.
pub async fn accept(&mut self) -> Result<LurkTcpConnection> {
let err_msg: &str = "Incoming TCP listener should never return empty option";
let tcp_stream = self.incoming.next().await.expect(err_msg)?;
let (tcp_stream, _) = future::poll_fn(|cx| self.poll_inner_accept(cx)).await?;
let tcp_label = LurkTcpConnectionLabel::from_tcp_stream(&tcp_stream).await?;

self.factory.create_connection(tcp_stream, tcp_label)
self.conn_factory.create_connection(tcp_stream, tcp_label)
}

/// Returns local address that this listener is binded to.
#[allow(dead_code)]
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
self.inner.local_addr().expect("Expect inbound TCP address")
}

/// Polls inner TCP listener to accept new connection
fn poll_inner_accept(&self, cx: &mut task::Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
let (tcp_stream, peer_addr) = ready!(self.inner.poll_accept(cx))?;

Poll::Ready(Ok((tcp_stream, peer_addr)))
}

fn resolve_bind_addr(addr: impl ToSocketAddrs) -> SocketAddr {
let mut bind_addr = addr.to_socket_addrs().unwrap();

// Return first resolved socket address
bind_addr.next().expect("Expect benign address to bind")
}
}

#[cfg(test)]
mod tests {

use super::*;
use futures::{stream::FuturesUnordered, TryFutureExt};
use std::time::Duration;
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
time::{sleep, timeout},
};
// use super::*;
// use futures::{stream::FuturesUnordered, StreamExt, TryFutureExt};
// use std::time::Duration;
// use tokio::{
// io::AsyncWriteExt,
// net::TcpStream,
// time::{sleep, timeout},
// };

// :0 tells the OS to pick an open port.
const TEST_BIND_IPV4: &str = "127.0.0.1:0";
// // :0 tells the OS to pick an open port.
// const TEST_BIND_IPV4: &str = "127.0.0.1:0";

/// This tests backpressure limit set on listener.
/// Number of connections intentionally exceeds the limit. Thus listener
Expand All @@ -133,46 +152,46 @@ pub mod listener {
#[tokio::test]

async fn limit_tcp_connections() {
let conn_limit = 5;
let num_clients = 20;

let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, 5).await.expect("Expect binded listener");
let listener_addr = listener.local_addr();

let client_tasks: FuturesUnordered<_> = (0..num_clients)
.map(|_| async move {
TcpStream::connect(listener_addr)
.and_then(|mut s| async move { s.write_all(&[0x05]).await })
.await
.unwrap()
})
.collect();

// Await all clients to complete.
client_tasks.collect::<()>().await;

// We have to handle all clients, but only `conn_limit`
// could be handled in parallel.
for _ in 0..num_clients {
let conn = timeout(Duration::from_secs(2), listener.accept())
.await
.expect("Expect acceptied connection before expired timeout")
.expect("Expect accepted TCP connection");

assert_eq!(LurkTcpConnectionLabel::SOCKS5, conn.label());
assert!(
listener.factory.get_active_tokens() <= conn_limit,
"Number of opened connections must not exceed the limit"
);

tokio::spawn(async move {
// Some client handling ...
sleep(Duration::from_millis(300)).await;
// Drop the connection after sleep, hence one
// slot should become available for the next client
drop(conn)
});
}
// let conn_limit = 5;
// let num_clients = 20;

// let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, 5).await.expect("Expect binded listener");
// let listener_addr = listener.local_addr();

// let client_tasks: FuturesUnordered<_> = (0..num_clients)
// .map(|_| async move {
// TcpStream::connect(listener_addr)
// .and_then(|mut s| async move { s.write_all(&[0x05]).await })
// .await
// .unwrap()
// })
// .collect();

// // Await all clients to complete.
// client_tasks.collect::<()>().await;

// // We have to handle all clients, but only `conn_limit`
// // could be handled in parallel.
// for _ in 0..num_clients {
// let conn = timeout(Duration::from_secs(2), listener.accept())
// .await
// .expect("Expect acceptied connection before expired timeout")
// .expect("Expect accepted TCP connection");

// assert_eq!(LurkTcpConnectionLabel::SOCKS5, conn.label());
// assert!(
// listener.factory.get_active_tokens() <= conn_limit,
// "Number of opened connections must not exceed the limit"
// );

// tokio::spawn(async move {
// // Some client handling ...
// sleep(Duration::from_millis(300)).await;
// // Drop the connection after sleep, hence one
// // slot should become available for the next client
// drop(conn)
// });
// }
}
}
}
Expand All @@ -184,7 +203,6 @@ pub mod connection {
io::stream::{LurkStream, LurkTcpStream},
};
use anyhow::{bail, Result};
use async_listen::backpressure::{Sender, Token};
use std::{fmt::Display, io, net::SocketAddr};
use tokio::net::TcpStream;

Expand Down Expand Up @@ -229,39 +247,27 @@ pub mod connection {
}

/// Factory that produces new TCP connection instances.
///
/// For each new instance, factory uses backpressure 'sender' to create the token that
/// should be destroyed on TCP connection drop.
///
pub struct LurkTcpConnectionFactory {
/// Backpressure sender instance.
/// This will produce tokens for created TCP connections.
bp_tx: Sender,
}
pub struct LurkTcpConnectionFactory {}

impl LurkTcpConnectionFactory {
pub fn new(bp_tx: Sender) -> LurkTcpConnectionFactory {
LurkTcpConnectionFactory { bp_tx }
pub fn new() -> LurkTcpConnectionFactory {
LurkTcpConnectionFactory {}
}

/// Returns the number of currently active tokens.
#[allow(dead_code)]
pub fn get_active_tokens(&self) -> usize {
self.bp_tx.get_active_tokens()
0
}

pub fn create_connection(&self, tcp_stream: TcpStream, label: LurkTcpConnectionLabel) -> Result<LurkTcpConnection> {
// Wrap raw TcpStream to the stream wrapper and generate new backpressure token
// that must be dropped on connection destruction.
LurkTcpConnection::new(tcp_stream, label, self.bp_tx.token())
LurkTcpConnection::new(tcp_stream, label)
}
}

pub struct LurkTcpConnection {
/// Lurk wrapper of TcpStream
stream: LurkTcpStream,
/// Backpressure token
_token: Token,
/// Label describing traffic in this TCP connection
label: LurkTcpConnectionLabel,
/// Remote address that this connection is connected to
Expand All @@ -271,12 +277,11 @@ pub mod connection {
}

impl LurkTcpConnection {
fn new(tcp_stream: TcpStream, label: LurkTcpConnectionLabel, token: Token) -> Result<LurkTcpConnection> {
fn new(tcp_stream: TcpStream, label: LurkTcpConnectionLabel) -> Result<LurkTcpConnection> {
Ok(LurkTcpConnection {
peer_addr: tcp_stream.peer_addr()?,
local_addr: tcp_stream.local_addr()?,
stream: LurkStream::new(tcp_stream),
_token: token,
label,
})
}
Expand Down
5 changes: 2 additions & 3 deletions src/server/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,10 @@ mod tests {

// :0 tells the OS to pick an open port.
const TEST_BIND_IPV4: &str = "127.0.0.1:0";
const TEST_CONN_LIMIT: usize = 1024;

#[tokio::test]
async fn socks5_handshake_with_auth_method() {
let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, TEST_CONN_LIMIT)
let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4)
.await
.expect("Expect binded listener");

Expand Down Expand Up @@ -206,7 +205,7 @@ mod tests {

#[tokio::test]
async fn socks5_handshake_with_non_accepatable_method() {
let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4, TEST_CONN_LIMIT)
let mut listener = LurkTcpListener::bind(TEST_BIND_IPV4)
.await
.expect("Expect binded listener");

Expand Down
9 changes: 4 additions & 5 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@ mod handlers;

pub struct LurkServer {
bind_addr: SocketAddr,
conn_limit: usize,
}

impl LurkServer {
/// Delay after non-transient TCP acception failure, e.g.
/// handle resource exhaustion errors.
const DELAY_AFTER_ERROR_MILLIS: u64 = 500;

pub fn new(bind_addr: SocketAddr, conn_limit: usize) -> LurkServer {
LurkServer { bind_addr, conn_limit }
pub fn new(bind_addr: SocketAddr) -> LurkServer {
LurkServer { bind_addr }
}

pub async fn run(&self) -> Result<()> {
let mut tcp_listener = LurkTcpListener::bind(self.bind_addr, self.conn_limit).await?;
info!("Listening on {} (TCP connections limit {})", self.bind_addr, self.conn_limit);
let mut tcp_listener = LurkTcpListener::bind(self.bind_addr).await?;
info!("Listening on {}", self.bind_addr);

loop {
match tcp_listener.accept().await {
Expand Down
4 changes: 2 additions & 2 deletions tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ pub fn init_logging() {
}

/// Spawn Lurk proxy instance.
pub async fn spawn_lurk_server(addr: SocketAddr, tcp_conn_limit: usize) -> tokio::task::JoinHandle<()> {
pub async fn spawn_lurk_server(addr: SocketAddr) -> tokio::task::JoinHandle<()> {
// Run proxy
let handle = tokio::spawn(async move {
LurkServer::new(addr, tcp_conn_limit)
LurkServer::new(addr)
.run()
.await
.expect("Error during proxy server run")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async fn http_server_single_client() {
let http_server_addr = "127.0.0.1:32002".parse::<SocketAddr>().unwrap();

// Run proxy
let lurk_handle = common::spawn_lurk_server(lurk_server_addr, 1024).await;
let lurk_handle = common::spawn_lurk_server(lurk_server_addr).await;

// Run HTTP server in the background
let http_server = ServerBuilder::new()
Expand Down Expand Up @@ -55,7 +55,7 @@ async fn echo_server_multiple_clients() {
let echo_server_addr = "127.0.0.1:32004".parse::<SocketAddr>().unwrap();

// Run Lurk proxy.
let lurk_handle = common::spawn_lurk_server(lurk_server_addr, 1024).await;
let lurk_handle = common::spawn_lurk_server(lurk_server_addr).await;

// Run echo server. Data sent to this server will be proxied through Lurk
// instance spawned above.
Expand Down

0 comments on commit e32cb13

Please sign in to comment.