Skip to content

Commit

Permalink
dev: use os-selected port for utracker tests
Browse files Browse the repository at this point in the history
  • Loading branch information
da2ce7 committed Aug 22, 2024
1 parent cb2e84d commit a1b157e
Show file tree
Hide file tree
Showing 11 changed files with 127 additions and 82 deletions.
9 changes: 7 additions & 2 deletions packages/utracker/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub struct TrackerClient {
}

impl TrackerClient {
/// Create a new `TrackerClient` with the given message capacity.
/// Run a new `TrackerClient` with the given message capacity.
///
/// Panics if capacity == `usize::max_value`().
///
Expand All @@ -140,7 +140,7 @@ impl TrackerClient {
///
/// It would panic if the desired capacity is too large.
#[instrument(skip())]
pub fn new<H>(bind: SocketAddr, handshaker: H, capacity_or_default: Option<usize>) -> std::io::Result<TrackerClient>
pub fn run<H>(bind: SocketAddr, handshaker: H, capacity_or_default: Option<usize>) -> std::io::Result<TrackerClient>
where
H: Sink<std::io::Result<HandshakerMessage>> + std::fmt::Debug + DiscoveryInfo + Send + Unpin + 'static,
H::Error: std::fmt::Display,
Expand Down Expand Up @@ -206,6 +206,11 @@ impl TrackerClient {
None
}
}

#[must_use]
pub fn local_addr(&self) -> SocketAddr {
self.bound_socket
}
}

impl Drop for TrackerClient {
Expand Down
5 changes: 5 additions & 0 deletions packages/utracker/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ impl TrackerServer {
shutdown_handle,
})
}

#[must_use]
pub fn local_addr(&self) -> SocketAddr {
self.bound_socket
}
}

impl Drop for TrackerServer {
Expand Down
5 changes: 4 additions & 1 deletion packages/utracker/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::sync::{Arc, Mutex, Once};
use std::time::Duration;

Expand All @@ -20,6 +20,9 @@ use utracker::{HandshakerMessage, ServerHandler, ServerResult};
#[allow(dead_code)]
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(1000);

#[allow(dead_code)]
pub const LOOPBACK_IPV4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));

const NUM_PEERS_RETURNED: usize = 20;

#[allow(dead_code)]
Expand Down
11 changes: 5 additions & 6 deletions packages/utracker/tests/test_announce_start.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::net::SocketAddr;
use std::time::Duration;

use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use handshake::Protocol;
use tracing::level_filters::LevelFilter;
Expand All @@ -14,25 +14,24 @@ mod common;
#[tokio::test]
async fn positive_announce_started() {
INIT.call_once(|| {
tracing_stderr_init(LevelFilter::INFO);
tracing_stderr_init(LevelFilter::ERROR);
});

let (handshaker_sender, mut handshaker_receiver) = handshaker();

let server_addr = "127.0.0.1:3501".parse().unwrap();
let mock_handler = MockTrackerHandler::new();
let _server = TrackerServer::run(server_addr, mock_handler).unwrap();
let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();

std::thread::sleep(Duration::from_millis(100));

let mut client = TrackerClient::new("127.0.0.1:4501".parse().unwrap(), handshaker_sender, None).unwrap();
let mut client = TrackerClient::run(LOOPBACK_IPV4, handshaker_sender, None).unwrap();

let hash = [0u8; bt::INFO_HASH_LEN].into();

tracing::debug!("sending announce");
let _send_token = client
.request(
server_addr,
server.local_addr(),
ClientRequest::Announce(hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)),
)
.unwrap();
Expand Down
11 changes: 5 additions & 6 deletions packages/utracker/tests/test_announce_stop.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;

use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use tracing::level_filters::LevelFilter;
use util::bt::{self};
Expand All @@ -17,21 +17,20 @@ async fn positive_announce_stopped() {

let (sink, mut stream) = handshaker();

let server_addr = "127.0.0.1:3502".parse().unwrap();
let mock_handler = MockTrackerHandler::new();
let _server = TrackerServer::run(server_addr, mock_handler).unwrap();
let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();

std::thread::sleep(Duration::from_millis(100));

let mut client = TrackerClient::new("127.0.0.1:4502".parse().unwrap(), sink, None).unwrap();
let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();

let info_hash = [0u8; bt::INFO_HASH_LEN].into();

// Started
{
let _send_token = client
.request(
server_addr,
server.local_addr(),
ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Started)),
)
.unwrap();
Expand Down Expand Up @@ -66,7 +65,7 @@ async fn positive_announce_stopped() {
{
let _send_token = client
.request(
server_addr,
server.local_addr(),
ClientRequest::Announce(info_hash, ClientState::new(0, 0, 0, AnnounceEvent::Stopped)),
)
.unwrap();
Expand Down
56 changes: 35 additions & 21 deletions packages/utracker/tests/test_client_drop.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::net::SocketAddr;

use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT};
use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use tokio::net::UdpSocket;
use tracing::level_filters::LevelFilter;
use util::bt::{self};
use utracker::announce::{AnnounceEvent, ClientState};
Expand All @@ -12,43 +13,56 @@ mod common;
#[tokio::test]
async fn positive_client_request_failed() {
INIT.call_once(|| {
tracing_stderr_init(LevelFilter::INFO);
tracing_stderr_init(LevelFilter::ERROR);
});

let (sink, mut stream) = handshaker();
let (sink, stream) = handshaker();

let server_addr: SocketAddr = "127.0.0.1:3503".parse().unwrap();
// Don't actually create the server since we want the request to wait for a little bit until we drop
// We bind a temp socket, then drop it...
let disconnected_addr: SocketAddr = {
let socket = UdpSocket::bind(LOOPBACK_IPV4).await.unwrap();
socket.local_addr().unwrap()
};

tokio::task::yield_now().await;

let send_token = {
let mut client = TrackerClient::new("127.0.0.1:4503".parse().unwrap(), sink, None).unwrap();
let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();

client
let token = client
.request(
server_addr,
disconnected_addr,
ClientRequest::Announce(
[0u8; bt::INFO_HASH_LEN].into(),
ClientState::new(0, 0, 0, AnnounceEvent::None),
),
)
.unwrap()
.unwrap();

// yield to allow for the request to be sent before the client is shutdown
for _ in 0..100 {
tokio::task::yield_now().await;
}

token
};
// Client is now dropped

let metadata = match tokio::time::timeout(DEFAULT_TIMEOUT, stream.next())
let mut messages: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect())
.await
.unwrap()
.unwrap()
.unwrap()
{
HandshakerMessage::InitiateMessage(_) => unreachable!(),
HandshakerMessage::ClientMetadata(metadata) => metadata,
};
.expect("it should not time out");

while let Some(message) = messages.pop() {
let metadata = match message.expect("it should be a handshake message") {
HandshakerMessage::InitiateMessage(_) => unreachable!(),
HandshakerMessage::ClientMetadata(metadata) => metadata,
};

assert_eq!(send_token, metadata.token());
assert_eq!(send_token, metadata.token());

match metadata.result() {
Err(ClientError::ClientShutdown) => (),
_ => panic!("Did Not Receive ClientShutdown..."),
match metadata.result() {
Err(ClientError::ClientShutdown) => (),
_ => panic!("Did Not Receive ClientShutdown..."),
}
}
}
65 changes: 42 additions & 23 deletions packages/utracker/tests/test_client_full.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT};
use std::net::SocketAddr;

use common::{handshaker, tracing_stderr_init, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use tokio::net::UdpSocket;
use tracing::level_filters::LevelFilter;
use util::bt::{self};
use utracker::announce::{AnnounceEvent, ClientState};
use utracker::{ClientRequest, TrackerClient};

mod common;

#[ignore = "race condition with shutdown of client"]
#[tokio::test]
async fn positive_client_request_dropped() {
INIT.call_once(|| {
Expand All @@ -15,39 +19,54 @@ async fn positive_client_request_dropped() {

let (sink, stream) = handshaker();

let server_addr = "127.0.0.1:3504".parse().unwrap();
// We bind a temp socket, then drop it...
let disconnected_addr: SocketAddr = {
let socket = UdpSocket::bind(LOOPBACK_IPV4).await.unwrap();
socket.local_addr().unwrap()
};

tokio::task::yield_now().await;

let request_capacity = 10;

let mut client = TrackerClient::new("127.0.0.1:4504".parse().unwrap(), sink, Some(request_capacity)).unwrap();
{
let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, Some(request_capacity)).unwrap();

tracing::warn!("sending announce requests to fill buffer");
for i in 1..=request_capacity {
tracing::warn!("request {i} of {request_capacity}");
tracing::warn!("sending announce requests to fill buffer");
for i in 1..=request_capacity {
tracing::warn!("request {i} of {request_capacity}");

client
client
.request(
disconnected_addr,
ClientRequest::Announce(
[0u8; bt::INFO_HASH_LEN].into(),
ClientState::new(0, 0, 0, AnnounceEvent::Started),
),
)
.unwrap();
}

tracing::warn!("sending one more announce request, it should fail");
assert!(client
.request(
server_addr,
disconnected_addr,
ClientRequest::Announce(
[0u8; bt::INFO_HASH_LEN].into(),
ClientState::new(0, 0, 0, AnnounceEvent::Started),
),
ClientState::new(0, 0, 0, AnnounceEvent::Started)
)
)
.unwrap();
}
.is_none());

tracing::warn!("sending one more announce request, it should fail");
assert!(client
.request(
server_addr,
ClientRequest::Announce(
[0u8; bt::INFO_HASH_LEN].into(),
ClientState::new(0, 0, 0, AnnounceEvent::Started)
)
)
.is_none());
// yield to allow for the request to be sent before the client is shutdown

// todo: somehow there is a race-condition here, if the number of yields is too large, the test fails.
// that is very strange...

std::mem::drop(client);
for _ in 0..100 {
tokio::task::yield_now().await;
}
}

let buffer: Vec<_> = tokio::time::timeout(DEFAULT_TIMEOUT, stream.collect()).await.unwrap();
assert_eq!(request_capacity, buffer.len());
Expand Down
9 changes: 4 additions & 5 deletions packages/utracker/tests/test_connect.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;

use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
use common::{handshaker, tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use tracing::level_filters::LevelFilter;
use util::bt::{self};
Expand All @@ -17,17 +17,16 @@ async fn positive_receive_connect_id() {

let (sink, mut stream) = handshaker();

let server_addr = "127.0.0.1:3505".parse().unwrap();
let mock_handler = MockTrackerHandler::new();
let _server = TrackerServer::run(server_addr, mock_handler).unwrap();
let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler).unwrap();

std::thread::sleep(Duration::from_millis(100));

let mut client = TrackerClient::new("127.0.0.1:4505".parse().unwrap(), sink, None).unwrap();
let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();

let send_token = client
.request(
server_addr,
server.local_addr(),
ClientRequest::Announce(
[0u8; bt::INFO_HASH_LEN].into(),
ClientState::new(0, 0, 0, AnnounceEvent::None),
Expand Down
15 changes: 9 additions & 6 deletions packages/utracker/tests/test_connect_cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Duration;

use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT};
use common::{tracing_stderr_init, MockTrackerHandler, DEFAULT_TIMEOUT, INIT, LOOPBACK_IPV4};
use futures::StreamExt as _;
use tracing::level_filters::LevelFilter;
use util::bt::{self};
Expand All @@ -16,18 +16,19 @@ async fn positive_connection_id_cache() {

let (sink, mut stream) = common::handshaker();

let server_addr = "127.0.0.1:3506".parse().unwrap();
let mock_handler = MockTrackerHandler::new();
let _server = TrackerServer::run(server_addr, mock_handler.clone()).unwrap();
let server = TrackerServer::run(LOOPBACK_IPV4, mock_handler.clone()).unwrap();

std::thread::sleep(Duration::from_millis(100));

let mut client = TrackerClient::new("127.0.0.1:4506".parse().unwrap(), sink, None).unwrap();
let mut client = TrackerClient::run(LOOPBACK_IPV4, sink, None).unwrap();

let first_hash = [0u8; bt::INFO_HASH_LEN].into();
let second_hash = [1u8; bt::INFO_HASH_LEN].into();

client.request(server_addr, ClientRequest::Scrape(first_hash)).unwrap();
client
.request(server.local_addr(), ClientRequest::Scrape(first_hash))
.unwrap();
tokio::time::timeout(DEFAULT_TIMEOUT, stream.next())
.await
.unwrap()
Expand All @@ -37,7 +38,9 @@ async fn positive_connection_id_cache() {
assert_eq!(mock_handler.num_active_connect_ids(), 1);

for _ in 0..10 {
client.request(server_addr, ClientRequest::Scrape(second_hash)).unwrap();
client
.request(server.local_addr(), ClientRequest::Scrape(second_hash))
.unwrap();
}

for _ in 0..10 {
Expand Down
Loading

0 comments on commit a1b157e

Please sign in to comment.