diff --git a/src/request.rs b/src/request.rs index 958208cdb..729059896 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,5 +1,6 @@ use std::io::Read; use std::sync::{Arc, Mutex}; +use std::time; use lazy_static::lazy_static; use qstring::QString; @@ -46,6 +47,7 @@ pub struct Request { pub(crate) timeout_connect: u64, pub(crate) timeout_read: u64, pub(crate) timeout_write: u64, + pub(crate) timeout: Option, pub(crate) redirects: u32, pub(crate) proxy: Option, #[cfg(feature = "tls")] @@ -336,6 +338,8 @@ impl Request { } /// Timeout for the socket connection to be successful. + /// If both this and .timeout() are both set, .timeout_connect() + /// takes precedence. /// /// The default is `0`, which means a request can block forever. /// @@ -351,6 +355,8 @@ impl Request { } /// Timeout for the individual reads of the socket. + /// If both this and .timeout() are both set, .timeout() + /// takes precedence. /// /// The default is `0`, which means it can block forever. /// @@ -360,12 +366,15 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` + #[deprecated(note = "Please use the timeout() function instead")] pub fn timeout_read(&mut self, millis: u64) -> &mut Request { self.timeout_read = millis; self } /// Timeout for the individual writes to the socket. + /// If both this and .timeout() are both set, .timeout() + /// takes precedence. /// /// The default is `0`, which means it can block forever. /// @@ -375,11 +384,32 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` + #[deprecated(note = "Please use the timeout() function instead")] pub fn timeout_write(&mut self, millis: u64) -> &mut Request { self.timeout_write = millis; self } + /// Timeout for the overall request, including DNS resolution, connection + /// time, redirects, and reading the response body. Slow DNS resolution + /// may cause a request to exceed the timeout, because the DNS request + /// cannot be interrupted with the available APIs. + /// + /// This takes precedence over .timeout_read() and .timeout_write(), but + /// not .timeout_connect(). + /// + /// ``` + /// // wait max 1 second for whole request to complete. + /// let r = ureq::get("/my_page") + /// .timeout(std::time::Duration::from_secs(1)) + /// .call(); + /// println!("{:?}", r); + /// ``` + pub fn timeout(&mut self, timeout: time::Duration) -> &mut Request { + self.timeout = Some(timeout); + self + } + /// Basic auth. /// /// These are the same diff --git a/src/response.rs b/src/response.rs index 50e2ab279..ba147628c 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,12 +1,13 @@ use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult}; use std::str::FromStr; +use std::time::Instant; use chunked_transfer::Decoder as ChunkDecoder; use crate::error::Error; use crate::header::Header; use crate::pool::PoolReturnRead; -use crate::stream::Stream; +use crate::stream::{DeadlineStream, Stream}; use crate::unit::Unit; #[cfg(feature = "json")] @@ -46,6 +47,7 @@ pub struct Response { headers: Vec
, unit: Option, stream: Option, + deadline: Option, } /// index into status_line where we split: HTTP/1.1 200 OK @@ -273,7 +275,6 @@ impl Response { /// ``` pub fn into_reader(self) -> impl Read { // - let is_http10 = self.http_version().eq_ignore_ascii_case("HTTP/1.0"); let is_close = self .header("connection") @@ -306,6 +307,8 @@ impl Response { let stream = self.stream.expect("No reader in response?!"); let unit = self.unit; + let deadline = unit.as_ref().and_then(|u| u.deadline); + let stream = DeadlineStream::new(stream, deadline); match (use_chunked, limit_bytes) { (true, _) => { @@ -472,6 +475,7 @@ impl Response { headers, unit: None, stream: None, + deadline: None, }) } @@ -551,6 +555,9 @@ impl Into for Error { /// *Internal API* pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, stream: Stream) { resp.url = Some(url); + if let Some(unit) = &unit { + resp.deadline = unit.deadline; + } resp.unit = unit; resp.stream = Some(stream); } @@ -586,7 +593,7 @@ struct LimitedRead { position: usize, } -impl LimitedRead { +impl LimitedRead { fn new(reader: R, limit: usize) -> Self { LimitedRead { reader, @@ -617,7 +624,7 @@ impl Read for LimitedRead { } } -impl From> for Stream +impl From> for Stream where Stream: From, { diff --git a/src/stream.rs b/src/stream.rs index b6b999e11..bd6c856f4 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -35,6 +35,61 @@ pub enum Stream { Test(Box, Vec), } +// DeadlineStream wraps a stream such that read() will return an error +// after the provided deadline, and sets timeouts on the underlying +// TcpStream to ensure read() doesn't block beyond the deadline. +// When the From trait is used to turn a DeadlineStream back into a +// Stream (by PoolReturningRead), the timeouts are removed. +pub struct DeadlineStream { + stream: Stream, + deadline: Option, +} + +impl DeadlineStream { + pub(crate) fn new(stream: Stream, deadline: Option) -> Self { + DeadlineStream { stream, deadline } + } +} + +impl From for Stream { + fn from(deadline_stream: DeadlineStream) -> Stream { + // Since we are turning this back into a regular, non-deadline Stream, + // remove any timeouts we set. + let stream = deadline_stream.stream; + if let Some(socket) = stream.socket() { + socket.set_read_timeout(None).unwrap(); + socket.set_write_timeout(None).unwrap(); + } + stream + } +} + +impl Read for DeadlineStream { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + if let Some(deadline) = self.deadline { + let timeout = time_until_deadline(deadline)?; + if let Some(socket) = self.stream.socket() { + socket.set_read_timeout(Some(timeout))?; + socket.set_write_timeout(Some(timeout))?; + } + } + self.stream.read(buf) + } +} + +// If the deadline is in the future, return the remaining time until +// then. Otherwise return a TimedOut error. +fn time_until_deadline(deadline: Instant) -> IoResult { + let now = Instant::now(); + match now.checked_duration_since(deadline) { + Some(_) => Err(IoError::new( + ErrorKind::TimedOut, + "timed out reading response", + )), + None => Ok(deadline - now), + } +} + impl ::std::fmt::Debug for Stream { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> { write!( @@ -77,10 +132,9 @@ impl Stream { } // Return true if the server has closed this connection. pub(crate) fn server_closed(&self) -> IoResult { - match self { - Stream::Http(tcpstream) => Stream::serverclosed_stream(tcpstream), - Stream::Https(rustls_stream) => Stream::serverclosed_stream(&rustls_stream.sock), - _ => Ok(false), + match self.socket() { + Some(socket) => Stream::serverclosed_stream(socket), + None => Ok(false), } } pub fn is_poolable(&self) -> bool { @@ -95,6 +149,15 @@ impl Stream { } } + pub(crate) fn socket(&self) -> Option<&TcpStream> { + match self { + Stream::Http(tcpstream) => Some(tcpstream), + #[cfg(feature = "tls")] + Stream::Https(rustls_stream) => Some(&rustls_stream.sock), + _ => None, + } + } + #[cfg(test)] pub fn to_write_vec(&self) -> Vec { match self { @@ -261,7 +324,13 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { } pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result { - // + let deadline: Option = if unit.timeout_connect > 0 { + Instant::now().checked_add(Duration::from_millis(unit.timeout_connect)) + } else { + unit.deadline + }; + + // TODO: Find a way to apply deadline to DNS lookup. let sock_addrs: Vec = match unit.proxy { Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port), None => format!("{}:{}", hostname, port), @@ -282,34 +351,24 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result 0; - // Find the first sock_addr that accepts a connection for sock_addr in sock_addrs { - // ensure connect timeout isn't hit overall. - if has_timeout { - let lapsed = (Instant::now() - start_time).as_millis() as u64; - if lapsed >= unit.timeout_connect { - any_err = Some(IoError::new(ErrorKind::TimedOut, "Didn't connect in time")); - break; - } else { - timeout_connect = unit.timeout_connect - lapsed; - } - } + // ensure connect timeout or overall timeout aren't yet hit. + let timeout = match deadline { + Some(deadline) => Some(time_until_deadline(deadline)?), + None => None, + }; // connect with a configured timeout. let stream = if Some(Proto::SOCKS5) == proto { connect_socks5( unit.proxy.to_owned().unwrap(), - timeout_connect, + deadline, sock_addr, hostname, port, ) - } else if has_timeout { - let timeout = Duration::from_millis(timeout_connect); + } else if let Some(timeout) = timeout { TcpStream::connect_timeout(&sock_addr, timeout) } else { TcpStream::connect(&sock_addr) @@ -332,7 +391,11 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result 0 { + if let Some(deadline) = deadline { + stream + .set_read_timeout(Some(deadline - Instant::now())) + .ok(); + } else if unit.timeout_read > 0 { stream .set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64))) .ok(); @@ -340,7 +403,11 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result 0 { + if let Some(deadline) = deadline { + stream + .set_write_timeout(Some(deadline - Instant::now())) + .ok(); + } else if unit.timeout_write > 0 { stream .set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64))) .ok(); @@ -399,7 +466,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result, proxy_addr: SocketAddr, host: &str, port: u16, @@ -430,7 +497,7 @@ fn connect_socks5( // 1) In the event of a timeout, a thread may be left running in the background. // TODO: explore supporting timeouts upstream in Socks5Proxy. #[allow(clippy::mutex_atomic)] - let stream = if timeout_connect > 0 { + let stream = if let Some(deadline) = deadline { use std::sync::mpsc::channel; use std::sync::{Arc, Condvar, Mutex}; use std::thread; @@ -455,9 +522,7 @@ fn connect_socks5( let (lock, cvar) = &*master_signal; let done = lock.lock().unwrap(); - let done_result = cvar - .wait_timeout(done, Duration::from_millis(timeout_connect)) - .unwrap(); + let done_result = cvar.wait_timeout(done, deadline - Instant::now()).unwrap(); let done = done_result.0; if *done { rx.recv().unwrap()? @@ -504,7 +569,7 @@ fn get_socks5_stream( #[cfg(not(feature = "socks-proxy"))] fn connect_socks5( _proxy: Proxy, - _timeout_connect: u64, + _deadline: Option, _proxy_addr: SocketAddr, _hostname: &str, _port: u16, diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 35e335a5f..fa70f9ee1 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -1,7 +1,9 @@ use crate::test; +use std::io::{BufRead, BufReader, Read, Write}; +use std::thread; +use std::time::Duration; use super::super::*; -use std::thread; #[test] fn agent_reuse_headers() { @@ -57,8 +59,6 @@ fn agent_cookies() { // Start a test server on an available port, that times out idle connections at 2 seconds. // Return the port this server is listening on. fn start_idle_timeout_server() -> u16 { - use std::io::{BufRead, BufReader, Write}; - use std::time::Duration; let listener = std::net::TcpListener::bind("localhost:0").unwrap(); let port = listener.local_addr().unwrap().port(); thread::spawn(move || { @@ -88,9 +88,6 @@ fn start_idle_timeout_server() -> u16 { #[test] fn connection_reuse() { - use std::io::Read; - use std::time::Duration; - let port = start_idle_timeout_server(); let url = format!("http://localhost:{}", port); let agent = Agent::default().build(); diff --git a/src/test/mod.rs b/src/test/mod.rs index 5d12af72d..2baa08ff8 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -14,6 +14,7 @@ mod query_string; mod range; mod redirect; mod simple; +mod timeout; type RequestHandler = dyn Fn(&Unit) -> Result + Send + 'static; diff --git a/src/test/timeout.rs b/src/test/timeout.rs new file mode 100644 index 000000000..1ee33e7f1 --- /dev/null +++ b/src/test/timeout.rs @@ -0,0 +1,125 @@ + +use crate::test; +use std::io::{self, BufRead, BufReader, Read, Write}; +use std::net::TcpStream; +use std::thread; +use std::time::Duration; + +use super::super::*; + +// Send an HTTP response on the TcpStream at a rate of two bytes every 10 +// milliseconds, for a total of 600 bytes. +fn dribble_body_respond(stream: &mut TcpStream) -> io::Result<()> { + let contents = [b'a'; 300]; + let headers = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", + contents.len() * 2 + ); + stream.write_all(headers.as_bytes())?; + for i in 0..contents.len() { + stream.write_all(&contents[i..i + 1])?; + stream.write_all(&[b'\n'; 1])?; + stream.flush()?; + thread::sleep(Duration::from_millis(10)); + } + Ok(()) +} + +// Read a stream until reaching a blank line, in order to consume +// request headers. +fn read_headers(stream: &TcpStream) { + for line in BufReader::new(stream).lines() { + let line = match line { + Ok(x) => x, + Err(_) => return, + }; + if line == "" { + break; + } + } +} + +// Start a test server on an available port, that dribbles out a response at 1 write per 10ms. +// Return the port this server is listening on. +fn start_dribble_body_server() -> u16 { + let listener = std::net::TcpListener::bind("localhost:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let dribble_handler = |mut stream: TcpStream| { + read_headers(&stream); + if let Err(e) = dribble_body_respond(&mut stream) { + eprintln!("sending dribble repsonse: {}", e); + } + }; + thread::spawn(move || { + for stream in listener.incoming() { + thread::spawn(move || dribble_handler(stream.unwrap())); + } + }); + port +} + +fn get_and_expect_timeout(url: String) { + let agent = Agent::default().build(); + let timeout = Duration::from_millis(500); + let resp = agent.get(&url).timeout(timeout).call(); + + let mut reader = resp.into_reader(); + let mut bytes = vec![]; + let result = reader.read_to_end(&mut bytes); + + match result { + Err(io_error) => match io_error.kind() { + io::ErrorKind::WouldBlock => Ok(()), + io::ErrorKind::TimedOut => Ok(()), + _ => Err(format!("{:?}", io_error)), + }, + Ok(_) => Err("successful response".to_string()), + } + .expect("expected timeout but got something else"); +} + +#[test] +fn overall_timeout_during_body() { + let port = start_dribble_body_server(); + let url = format!("http://localhost:{}/", port); + + get_and_expect_timeout(url); +} + +// Send HTTP headers on the TcpStream at a rate of one header every 100 +// milliseconds, for a total of 30 headers. +fn dribble_headers_respond(stream: &mut TcpStream) -> io::Result<()> { + stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?; + for _ in 0..30 { + stream.write_all(b"a: b\n")?; + stream.flush()?; + thread::sleep(Duration::from_millis(100)); + } + Ok(()) +} + +// Start a test server on an available port, that dribbles out response *headers* at 1 write per 10ms. +// Return the port this server is listening on. +fn start_dribble_headers_server() -> u16 { + let listener = std::net::TcpListener::bind("localhost:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let dribble_handler = |mut stream: TcpStream| { + read_headers(&stream); + if let Err(e) = dribble_headers_respond(&mut stream) { + eprintln!("sending dribble repsonse: {}", e); + } + }; + thread::spawn(move || { + for stream in listener.incoming() { + thread::spawn(move || dribble_handler(stream.unwrap())); + } + }); + port +} + +#[test] +fn overall_timeout_during_headers() { + let port = start_dribble_headers_server(); + let url = format!("http://localhost:{}/", port); + get_and_expect_timeout(url); +} diff --git a/src/unit.rs b/src/unit.rs index 4828cdc4f..9a2a331a5 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -1,5 +1,6 @@ use std::io::{Result as IoResult, Write}; use std::sync::{Arc, Mutex}; +use std::time; use qstring::QString; use url::Url; @@ -38,6 +39,7 @@ pub(crate) struct Unit { pub timeout_connect: u64, pub timeout_read: u64, pub timeout_write: u64, + pub deadline: Option, pub method: String, pub proxy: Option, #[cfg(feature = "tls")] @@ -89,6 +91,14 @@ impl Unit { .cloned() .collect(); + let deadline = match req.timeout { + None => None, + Some(timeout) => { + let now = time::Instant::now(); + Some(now.checked_add(timeout).unwrap()) + } + }; + Unit { agent: Arc::clone(&req.agent), url: url.clone(), @@ -98,6 +108,7 @@ impl Unit { timeout_connect: req.timeout_connect, timeout_read: req.timeout_read, timeout_write: req.timeout_write, + deadline, method: req.method.clone(), proxy: req.proxy.clone(), #[cfg(feature = "tls")] @@ -154,6 +165,7 @@ pub(crate) fn connect( let body_bytes_sent = body::send_body(body, unit.is_chunked, &mut stream)?; // start reading the response to process cookies and redirects. + let mut stream = stream::DeadlineStream::new(stream, unit.deadline); let mut resp = Response::from_read(&mut stream); if let Some(err) = resp.synthetic_error() { @@ -208,8 +220,8 @@ pub(crate) fn connect( } // since it is not a redirect, or we're not following redirects, - // give away the incoming stream to the response object - crate::response::set_stream(&mut resp, unit.url.to_string(), Some(unit), stream); + // give away the incoming stream to the response object. + crate::response::set_stream(&mut resp, unit.url.to_string(), Some(unit), stream.into()); // release the response Ok(resp)