Skip to content

Commit

Permalink
change behaviour to keep conn open by default
Browse files Browse the repository at this point in the history
and add a test for control-c
  • Loading branch information
rklaehn committed Dec 5, 2023
1 parent e3903a5 commit c435f97
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 42 deletions.
32 changes: 25 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }

[dev-dependencies]
duct = "0.13.6"
nix = { version = "0.27", features = ["signal", "process"] }
64 changes: 45 additions & 19 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,35 +113,52 @@ pub struct ConnectArgs {
pub secret: Option<iroh_net::key::SecretKey>,
}

/// Copy from a reader to a quinn stream, calling finish() on the stream when
/// the reader is done or the operation is cancelled.
/// Copy from a reader to a quinn stream.
///
/// Will send a reset to the other side if the operation is cancelled, and fail
/// with an error.
///
/// Returns the number of bytes copied in case of success.
async fn copy_to_quinn(
mut from: impl AsyncRead + Unpin,
mut send: quinn::SendStream,
token: CancellationToken,
) -> anyhow::Result<()> {
) -> io::Result<u64> {
tracing::trace!("copying to quinn");
tokio::select! {
_ = tokio::io::copy(&mut from, &mut send) => {}
_ = token.cancelled() => {}
res = tokio::io::copy(&mut from, &mut send) => {
let size = res?;
send.finish().await?;
Ok(size)
}
_ = token.cancelled() => {
// send a reset to the other side immediately
send.reset(0u8.into()).ok();
Err(io::Error::new(io::ErrorKind::Other, "cancelled"))
}
}
send.finish().await?;
Ok(())
}

/// Copy from a quinn stream to a writer, calling stop() on the stream when
/// the writer is done or the operation is cancelled.
/// Copy from a quinn stream to a writer.
///
/// Will send stop to the other side if the operation is cancelled, and fail
/// with an error.
///
/// Returns the number of bytes copied in case of success.
async fn copy_from_quinn(
mut recv: quinn::RecvStream,
mut to: impl AsyncWrite + Unpin,
token: CancellationToken,
) -> anyhow::Result<()> {
) -> io::Result<u64> {
tokio::select! {
_ = tokio::io::copy(&mut recv, &mut to) => {}
res = tokio::io::copy(&mut recv, &mut to) => {
Ok(res?)
},
_ = token.cancelled() => {
recv.stop(0u8.into())?;
recv.stop(0u8.into()).ok();
Err(io::Error::new(io::ErrorKind::Other, "cancelled"))
}
}
Ok(())
}

/// Get the secret key or generate a new one.
Expand All @@ -155,6 +172,13 @@ fn get_or_create_secret(secret: Option<SecretKey>) -> SecretKey {
})
}

fn cancel_token<T>(token: CancellationToken) -> impl Fn(T) -> T {
move |x| {
token.cancel();
x
}
}

/// Bidirectionally forward data from a quinn stream and an arbitrary tokio
/// reader/writer pair, aborting both sides when either one forwarder is done,
/// or when control-c is pressed.
Expand All @@ -168,20 +192,22 @@ async fn forward_bidi(
let token2 = token1.clone();
let token3 = token1.clone();
let forward_from_stdin = tokio::spawn(async move {
copy_to_quinn(from1, to2, token1.clone()).await.ok();
token1.cancel();
copy_to_quinn(from1, to2, token1.clone())
.await
.map_err(cancel_token(token1))
});
let forward_to_stdout = tokio::spawn(async move {
copy_from_quinn(from2, to1, token2.clone()).await.ok();
token2.cancel();
copy_from_quinn(from2, to1, token2.clone())
.await
.map_err(cancel_token(token2))
});
let _control_c = tokio::spawn(async move {
tokio::signal::ctrl_c().await?;
token3.cancel();
io::Result::Ok(())
});
forward_to_stdout.await?;
forward_from_stdin.await?;
forward_to_stdout.await??;
forward_from_stdin.await??;
Ok(())
}

Expand Down
49 changes: 33 additions & 16 deletions tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn dumbpipe_bin() -> &'static str {
/// Read `n` lines from `reader`, returning the bytes read including the newlines.
///
/// This assumes that the header lines are ASCII and can be parsed byte by byte.
fn read_header_lines(mut n: usize, reader: &mut impl Read) -> io::Result<Vec<u8>> {
fn read_ascii_lines(mut n: usize, reader: &mut impl Read) -> io::Result<Vec<u8>> {
let mut buf = [0u8; 1];
let mut res = Vec::new();
loop {
Expand All @@ -35,61 +35,78 @@ fn read_header_lines(mut n: usize, reader: &mut impl Read) -> io::Result<Vec<u8>
Ok(res)
}

/// Tests the basic functionality of the connect and listen pair
///
/// Connect and listen both write a limited amount of data and then EOF.
/// The interaction should stop when both sides have EOF'd.
#[test]
#[ignore]
fn connect_accept_1() {
fn connect_listen_happy() {
// the bytes provided by the listen command
let listen_to_connect = b"hello from listen";
let connect_to_listen = b"hello from connect";
let mut listen = duct::cmd(dumbpipe_bin(), ["listen"])
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(listen_to_connect)
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_header_lines(3, &mut listen).unwrap();
let header = read_ascii_lines(3, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
println!("{}", header);
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();

let connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(connect_to_listen)
.stderr_null()
.stdout_capture()
.run()
.unwrap();

assert!(connect.status.success());
assert_eq!(&connect.stdout, listen_to_connect);

let mut listen_stdout = Vec::new();
listen.read_to_end(&mut listen_stdout).unwrap();
assert_eq!(&listen_stdout, connect_to_listen);
}

#[cfg(unix)]
#[test]
fn connect_accept_2() {
fn connect_listen_interrupt_connect() {
use nix::{
sys::signal::{self, Signal},
unistd::Pid,
};
// the bytes provided by the listen command
let connect_to_listen = b"hello from connect";
let mut listen = duct::cmd(dumbpipe_bin(), ["listen"])
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(b"hello from listen\n")
.stderr_to_stdout() //
.reader()
.unwrap();
// read the first 3 lines of the header, and parse the last token as a ticket
let header = read_header_lines(3, &mut listen).unwrap();
let header = read_ascii_lines(3, &mut listen).unwrap();
let header = String::from_utf8(header).unwrap();
let ticket = header.split_ascii_whitespace().last().unwrap();
let ticket = NodeTicket::from_str(ticket).unwrap();

let connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
let mut connect = duct::cmd(dumbpipe_bin(), ["connect", &ticket.to_string()])
.env_remove("RUST_LOG") // disable tracing
.stdin_bytes(connect_to_listen)
.stderr_null()
.stdout_capture()
.run()
.reader()
.unwrap();
// wait until we get a line from the listen process
read_ascii_lines(1, &mut connect).unwrap();
for pid in connect.pids() {
signal::kill(Pid::from_raw(pid as i32), Signal::SIGINT).unwrap();
}

assert!(connect.status.success());
assert_eq!(&connect.stdout, b"");
let mut listen_stdout = Vec::new();
listen.read_to_end(&mut listen_stdout).unwrap();
assert_eq!(listen_stdout, connect_to_listen);
let mut tmp = Vec::new();
// we don't care about the results. This test is just to make sure that the
// listen command stops when the connect command stops.
listen.read_to_end(&mut tmp).ok();
connect.read_to_end(&mut tmp).ok();
}

0 comments on commit c435f97

Please sign in to comment.