Skip to content

Commit 9e0393f

Browse files
committed
more work
1 parent a2aad25 commit 9e0393f

File tree

14 files changed

+355
-208
lines changed

14 files changed

+355
-208
lines changed

examples/get_metadata/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@ hex = "0"
2424
futures = "0"
2525
tokio = { version = "1", features = ["full"] }
2626
tokio-util = {version = "0", features = ["codec"]}
27+
tracing = "0"
28+
tracing-subscriber = "0"

examples/get_metadata/src/main.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,22 @@ async fn main() {
118118
extensions.add(Extension::ExtensionProtocol);
119119

120120
// Create a handshaker that can initiate connections with peers
121-
let (handshaker_send, mut handshaker_recv) = HandshakerBuilder::new()
121+
let (handshaker, mut tasks) = HandshakerBuilder::new()
122122
.with_extensions(extensions)
123123
.with_config(
124124
// Set a low handshake timeout so we don't wait on peers that aren't listening on tcp
125125
HandshakerConfig::default().with_connect_timeout(Duration::from_millis(500)),
126126
)
127127
.build(TcpTransport)
128128
.await
129-
.expect("it should build a handshaker pair")
130-
.into_parts();
129+
.expect("it should build a handshaker pair");
130+
let (handshaker_send, mut handshaker_recv) = handshaker.into_parts();
131131

132132
// Create a peer manager that will hold our peers and heartbeat/send messages to them
133133
let (mut peer_manager_send, peer_manager_recv) = PeerManagerBuilder::new().build().into_parts();
134134

135135
// Hook up a future that feeds incoming (handshaken) peers over to the peer manager
136-
tokio::spawn(async move {
136+
tasks.spawn(async move {
137137
while let Some(complete_msg) = handshaker_recv.next().await {
138138
let (_, extensions, hash, pid, addr, sock) = complete_msg.unwrap().into_parts();
139139

@@ -185,7 +185,7 @@ async fn main() {
185185
let mut merged_recv = futures::stream::select(peer_manager_recv.map(Either::Left), timer.map(Either::Right)).boxed();
186186

187187
// Hook up a future that receives messages from the peer manager
188-
tokio::spawn(async move {
188+
tasks.spawn(async move {
189189
let mut uber_send = uber_send.clone();
190190

191191
while let Some(item) = merged_recv.next().await {
@@ -268,4 +268,6 @@ async fn main() {
268268
.expect("Failed to create output file")
269269
.write_all(&metainfo.to_bytes())
270270
.expect("Failed to write metainfo to file");
271+
272+
tasks.shutdown().await;
271273
}

examples/simple_torrent/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,5 @@ clap = "4"
2424
futures = "0"
2525
tokio = { version = "1", features = ["full"] }
2626
tokio-util = {version = "0", features = ["codec"]}
27+
tracing = "0"
28+
tracing-subscriber = "0"

examples/simple_torrent/src/main.rs

Lines changed: 96 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cmp;
22
use std::collections::HashMap;
33
use std::fs::File;
44
use std::io::Read;
5-
use std::sync::Arc;
5+
use std::sync::{Arc, Once};
66

77
use disk::fs::NativeFileSystem;
88
use disk::fs_cache::FileHandleCache;
@@ -16,8 +16,7 @@ use futures::lock::Mutex;
1616
use futures::{stream, SinkExt as _, StreamExt as _};
1717
use handshake::transports::TcpTransport;
1818
use handshake::{
19-
Extensions, Handshaker, HandshakerBuilder, HandshakerConfig, HandshakerSink, HandshakerStream, InitiateMessage, PeerId,
20-
Protocol,
19+
Extensions, Handshaker, HandshakerBuilder, HandshakerConfig, HandshakerStream, InitiateMessage, PeerId, Protocol,
2120
};
2221
use metainfo::{Info, Metainfo};
2322
use peer::messages::{BitFieldMessage, HaveMessage, PeerWireProtocolMessage, PieceMessage, RequestMessage};
@@ -31,10 +30,13 @@ use tokio::signal;
3130
use tokio::task::JoinSet;
3231
use tokio_util::bytes::BytesMut;
3332
use tokio_util::codec::{Decoder, Framed};
33+
use tracing::level_filters::LevelFilter;
3434

3535
// Maximum number of requests that can be in flight at once.
3636
const MAX_PENDING_BLOCKS: usize = 50;
3737

38+
pub static INIT: Once = Once::new();
39+
3840
// Enum to store our selection state updates
3941
#[allow(dead_code)]
4042
#[derive(Debug)]
@@ -59,8 +61,30 @@ enum Downloader {
5961
Interrupted,
6062
}
6163

64+
enum Setup {
65+
Finished((NativeDiskManager, PeerManager, TcpHandshaker), JoinSet<()>),
66+
Interrupted,
67+
}
68+
69+
pub fn tracing_stdout_init(filter: LevelFilter) {
70+
let builder = tracing_subscriber::fmt().with_max_level(filter).with_ansi(true);
71+
72+
builder.pretty().with_file(true).init();
73+
74+
tracing::info!("Logging initialized");
75+
}
76+
77+
async fn ctrl_c() {
78+
signal::ctrl_c().await.expect("failed to listen for event");
79+
println!("Ctrl-C received, shutting down...");
80+
}
81+
6282
#[tokio::main]
6383
async fn main() {
84+
INIT.call_once(|| {
85+
tracing_stdout_init(LevelFilter::TRACE);
86+
});
87+
6488
// Parse command-line arguments
6589
let matched_arguments = parse_arguments();
6690
let (torrent_file_path, download_directory, peer_address) = extract_arguments(&matched_arguments);
@@ -71,28 +95,51 @@ async fn main() {
7195
// Create a JoinSet to manage background tasks
7296
let tasks = Arc::new(Mutex::new(JoinSet::new()));
7397

74-
let ctrl_c = async {
75-
signal::ctrl_c().await.expect("failed to listen for event");
76-
println!("Ctrl-C received, shutting down...");
98+
// Setup the managers.
99+
let setup = setup(download_directory);
100+
101+
// Await either the completion of the setup or the Ctrl-C signal
102+
let setup = tokio::select! {
103+
setup = setup => Setup::Finished(setup.0, setup.1),
104+
() = ctrl_c() => Setup::Interrupted,
105+
};
106+
107+
let (managers, mut handshaker_tasks) = match setup {
108+
Setup::Finished(managers, handshaker_tasks) => (managers, handshaker_tasks),
109+
Setup::Interrupted => {
110+
tracing::warn!("setup was canceled...");
111+
return;
112+
}
77113
};
78114

79-
let downloader = downloader(tasks.clone(), download_directory, peer_address, metainfo, info_hash);
115+
let downloader = downloader(tasks.clone(), managers, peer_address, metainfo, info_hash);
80116

81-
// Await either the completion of all tasks or the Ctrl-C signal
117+
// Await either the completion of the downloader or the Ctrl-C signal
82118
let status = tokio::select! {
83119
() = downloader => Downloader::Finished,
84-
() = ctrl_c => Downloader::Interrupted,
120+
() = ctrl_c() => Downloader::Interrupted,
85121
};
86122

87123
match status {
88124
Downloader::Finished => {
89-
while let Some(result) = tasks.lock().await.join_next().await {
125+
while let Some(result) = handshaker_tasks.try_join_next() {
90126
if let Err(e) = result {
91127
eprintln!("Task failed: {e:?}");
92128
}
93129
}
130+
handshaker_tasks.shutdown().await;
131+
132+
while let Some(result) = tasks.lock().await.try_join_next() {
133+
if let Err(e) = result {
134+
eprintln!("Task failed: {e:?}");
135+
}
136+
}
137+
tasks.lock().await.shutdown().await;
138+
}
139+
Downloader::Interrupted => {
140+
handshaker_tasks.shutdown().await;
141+
tasks.lock().await.shutdown().await;
94142
}
95-
Downloader::Interrupted => tasks.lock().await.shutdown().await,
96143
}
97144
}
98145

@@ -142,21 +189,36 @@ fn load_and_parse_torrent_file(torrent_file_path: &str) -> (Metainfo, InfoHash)
142189
(metainfo, info_hash)
143190
}
144191

192+
async fn setup(download_directory: String) -> ((NativeDiskManager, PeerManager, TcpHandshaker), JoinSet<()>) {
193+
// Setup disk manager for handling file operations
194+
let disk_manager = setup_disk_manager(&download_directory);
195+
196+
// Setup peer manager for managing peer communication
197+
let peer_manager = setup_peer_manager();
198+
199+
// Setup handshaker for managing peer connections
200+
let (handshaker, handshaker_tasks) = setup_handshaker().await;
201+
202+
((disk_manager, peer_manager, handshaker), handshaker_tasks)
203+
}
204+
145205
async fn downloader(
146206
tasks: Arc<Mutex<JoinSet<()>>>,
147-
download_directory: String,
207+
managers: (NativeDiskManager, PeerManager, TcpHandshaker),
148208
peer_address: String,
149209
metainfo: Metainfo,
150210
info_hash: InfoHash,
151211
) {
212+
let (disk_manager, peer_manager, handshaker) = managers;
213+
152214
// Setup disk manager for handling file operations
153-
let (mut disk_manager_sender, disk_manager_receiver) = setup_disk_manager(&download_directory);
215+
let (mut disk_manager_sender, disk_manager_receiver) = disk_manager.into_parts();
154216

155217
// Setup peer manager for managing peer communication
156-
let (peer_manager_sender, peer_manager_receiver) = setup_peer_manager();
218+
let (peer_manager_sender, peer_manager_receiver) = peer_manager.into_parts();
157219

158220
// Setup handshaker for managing peer connections
159-
let (mut handshaker_sender, handshaker_receiver) = setup_handshaker().await;
221+
let (mut handshaker_sender, handshaker_receiver) = handshaker.into_parts();
160222

161223
// Handle new incoming connections
162224
tasks
@@ -231,44 +293,39 @@ async fn downloader(
231293
.await;
232294
}
233295

234-
fn setup_disk_manager(download_directory: &str) -> (DiskManagerSink<FileHandleCache<NativeFileSystem>>, DiskManagerStream) {
296+
type NativeDiskManager = DiskManager<FileHandleCache<NativeFileSystem>>;
297+
298+
fn setup_disk_manager(download_directory: &str) -> DiskManager<FileHandleCache<NativeFileSystem>> {
235299
let filesystem = FileHandleCache::new(NativeFileSystem::with_directory(download_directory), 100);
236-
let disk_manager: DiskManager<FileHandleCache<NativeFileSystem>> = DiskManagerBuilder::new()
300+
301+
DiskManagerBuilder::new()
237302
.with_sink_buffer_capacity(1)
238303
.with_stream_buffer_capacity(0)
239-
.build(Arc::new(filesystem));
240-
241-
disk_manager.into_parts()
304+
.build(Arc::new(filesystem))
242305
}
243306

244-
async fn setup_handshaker() -> (HandshakerSink, HandshakerStream<TcpStream>) {
245-
let handshaker: Handshaker<TcpStream> = HandshakerBuilder::new()
307+
type TcpHandshaker = Handshaker<TcpStream>;
308+
309+
async fn setup_handshaker() -> (TcpHandshaker, JoinSet<()>) {
310+
HandshakerBuilder::new()
246311
.with_peer_id(PeerId::from_hash("-BI0000-000000000000".as_bytes()).unwrap())
247312
.with_config(HandshakerConfig::default().with_wait_buffer_size(0).with_done_buffer_size(0))
248313
.build(TcpTransport)
249314
.await
250-
.unwrap();
251-
252-
handshaker.into_parts()
315+
.unwrap()
253316
}
254317

318+
type PeerManager = peer::PeerManager<
319+
Framed<TcpStream, PeerProtocolCodec<PeerWireProtocol<NullProtocol>>>,
320+
PeerWireProtocolMessage<NullProtocol>,
321+
>;
322+
255323
#[allow(clippy::type_complexity)]
256-
fn setup_peer_manager() -> (
257-
PeerManagerSink<Framed<TcpStream, PeerProtocolCodec<PeerWireProtocol<NullProtocol>>>, PeerWireProtocolMessage<NullProtocol>>,
258-
PeerManagerStream<
259-
Framed<TcpStream, PeerProtocolCodec<PeerWireProtocol<NullProtocol>>>,
260-
PeerWireProtocolMessage<NullProtocol>,
261-
>,
262-
) {
263-
let peer_manager: peer::PeerManager<
264-
Framed<TcpStream, PeerProtocolCodec<PeerWireProtocol<NullProtocol>>>,
265-
PeerWireProtocolMessage<NullProtocol>,
266-
> = PeerManagerBuilder::new()
324+
fn setup_peer_manager() -> PeerManager {
325+
PeerManagerBuilder::new()
267326
.with_sink_buffer_capacity(0)
268327
.with_stream_buffer_capacity(0)
269-
.build();
270-
271-
peer_manager.into_parts()
328+
.build()
272329
}
273330

274331
async fn handle_new_connections(

packages/handshake/examples/handshake_torrent.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async fn main() -> std::io::Result<()> {
2525

2626
// Show up as a uTorrent client...
2727
let peer_id = (*b"-UT2060-000000000000").into();
28-
let mut handshaker = HandshakerBuilder::new()
28+
let (mut handshaker, mut tasks) = HandshakerBuilder::new()
2929
.with_peer_id(peer_id)
3030
.build(TcpTransport)
3131
.await
@@ -40,6 +40,8 @@ async fn main() -> std::io::Result<()> {
4040

4141
sleep(Duration::from_secs(10)).await;
4242

43+
tasks.shutdown().await;
44+
4345
Ok(())
4446
}
4547

packages/handshake/src/handshake/handshaker.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl HandshakerBuilder {
114114
/// # Errors
115115
///
116116
/// Returns a IO error if unable to build.
117-
pub async fn build<T>(&self, transport: T) -> std::io::Result<Handshaker<T::Socket>>
117+
pub async fn build<T>(&self, transport: T) -> std::io::Result<(Handshaker<T::Socket>, JoinSet<()>)>
118118
where
119119
T: Transport + Send + Sync + 'static,
120120
<T as Transport>::Socket: AsyncWrite + AsyncRead + std::fmt::Debug + Send + Sync,
@@ -157,7 +157,7 @@ impl<S> Handshaker<S>
157157
where
158158
S: AsyncRead + AsyncWrite + std::fmt::Debug + Send + Sync + Unpin + 'static,
159159
{
160-
async fn with_builder<T>(builder: &HandshakerBuilder, transport: T) -> std::io::Result<Handshaker<T::Socket>>
160+
async fn with_builder<T>(builder: &HandshakerBuilder, transport: T) -> std::io::Result<(Handshaker<T::Socket>, JoinSet<()>)>
161161
where
162162
T: Transport<Socket = S> + Send + Sync + 'static,
163163
<T as Transport>::Listener: Send,
@@ -208,7 +208,7 @@ where
208208
let sink = HandshakerSink::new(addr_send, open_port, builder.pid, filters);
209209
let stream = HandshakerStream::new(sock_recv);
210210

211-
Ok(Handshaker { sink, stream })
211+
Ok((Handshaker { sink, stream }, tasks))
212212
}
213213
}
214214

packages/handshake/src/message/complete.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::message::protocol::Protocol;
77

88
/// Message containing completed handshaking information.
99
#[allow(clippy::module_name_repetitions)]
10+
#[derive(Debug)]
1011
pub struct CompleteMessage<S> {
1112
prot: Protocol,
1213
ext: Extensions,

packages/handshake/tests/test_byte_after_handshake.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::io::{Read, Write};
22
use std::net::TcpStream;
3-
use std::thread;
43

54
use common::{tracing_stdout_init, INIT};
65
use futures::stream::StreamExt;
@@ -21,7 +20,7 @@ async fn positive_recover_bytes() {
2120
let mut handshaker_one_addr = "127.0.0.1:0".parse().unwrap();
2221
let handshaker_one_pid = [4u8; bt::PEER_ID_LEN].into();
2322

24-
let mut handshaker_one = HandshakerBuilder::new()
23+
let (mut handshaker_one, mut tasks_one) = HandshakerBuilder::new()
2524
.with_bind_addr(handshaker_one_addr)
2625
.with_peer_id(handshaker_one_pid)
2726
.build(TcpTransport)
@@ -30,7 +29,7 @@ async fn positive_recover_bytes() {
3029

3130
handshaker_one_addr.set_port(handshaker_one.port());
3231

33-
thread::spawn(move || {
32+
tasks_one.spawn_blocking(move || {
3433
let mut stream = TcpStream::connect(handshaker_one_addr).unwrap();
3534
let mut write_buffer = Vec::new();
3635

@@ -46,15 +45,21 @@ async fn positive_recover_bytes() {
4645
stream.read_exact(&mut vec![0u8; expect_read_length][..]).unwrap();
4746
});
4847

49-
if let Some(message) = handshaker_one.next().await {
50-
let (_, _, _, _, _, mut sock) = message.unwrap().into_parts();
48+
let test = tokio::spawn(async move {
49+
if let Some(message) = handshaker_one.next().await {
50+
let (_, _, _, _, _, mut sock) = message.unwrap().into_parts();
5151

52-
let mut recv_buffer = vec![0u8; 1];
53-
sock.read_exact(&mut recv_buffer).await.unwrap();
52+
let mut recv_buffer = vec![0u8; 1];
53+
sock.read_exact(&mut recv_buffer).await.unwrap();
5454

55-
// Assert that our buffer contains the bytes after the handshake
56-
assert_eq!(vec![55], recv_buffer);
57-
} else {
58-
panic!("Failed to receive handshake message");
59-
}
55+
// Assert that our buffer contains the bytes after the handshake
56+
assert_eq!(vec![55], recv_buffer);
57+
} else {
58+
panic!("Failed to receive handshake message");
59+
}
60+
});
61+
62+
let res = test.await;
63+
tasks_one.shutdown().await;
64+
res.unwrap();
6065
}

0 commit comments

Comments
 (0)