From afe535283c68dec40c5fc7a810445e8c2380880f Mon Sep 17 00:00:00 2001 From: Carl Lerche Date: Fri, 9 Oct 2020 10:02:55 -0700 Subject: [PATCH] fs: future proof `File` (#2930) Changes inherent methods to take `&self` instead of `&mut self`. This brings the API in line with `std`. This patch is implemented by using a `tokio::sync::Mutex` to guard the internal `File` state. This is not an ideal implementation strategy doesn't make a big impact compared to having to dispatch operations to a background thread followed by a blocking syscall. In the future, the implementation can be improved as we explore async file-system APIs provided by the operating-system (iocp / io_uring). Closes #2927 --- tokio/src/fs/file.rs | 164 +++++++++++++++++------------- tokio/src/sync/batch_semaphore.rs | 1 + tokio/src/sync/mod.rs | 9 +- tokio/src/sync/mutex.rs | 2 + tokio/src/util/linked_list.rs | 7 +- tokio/tests/fs_file_mocked.rs | 7 +- 6 files changed, 118 insertions(+), 72 deletions(-) diff --git a/tokio/src/fs/file.rs b/tokio/src/fs/file.rs index 9556a22fe40..8b385117a3d 100644 --- a/tokio/src/fs/file.rs +++ b/tokio/src/fs/file.rs @@ -6,6 +6,7 @@ use self::State::*; use crate::fs::{asyncify, sys}; use crate::io::blocking::Buf; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; +use crate::sync::Mutex; use std::fmt; use std::fs::{Metadata, Permissions}; @@ -80,6 +81,10 @@ use std::task::Poll::*; /// ``` pub struct File { std: Arc, + inner: Mutex, +} + +struct Inner { state: State, /// Errors from writes/flushes are returned in write/flush calls. If a write @@ -199,9 +204,11 @@ impl File { pub fn from_std(std: sys::File) -> File { File { std: Arc::new(std), - state: State::Idle(Some(Buf::with_capacity(0))), - last_write_err: None, - pos: 0, + inner: Mutex::new(Inner { + state: State::Idle(Some(Buf::with_capacity(0))), + last_write_err: None, + pos: 0, + }), } } @@ -228,8 +235,9 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn sync_all(&mut self) -> io::Result<()> { - self.complete_inflight().await; + pub async fn sync_all(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; let std = self.std.clone(); asyncify(move || std.sync_all()).await @@ -262,8 +270,9 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn sync_data(&mut self) -> io::Result<()> { - self.complete_inflight().await; + pub async fn sync_data(&self) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; let std = self.std.clone(); asyncify(move || std.sync_data()).await @@ -299,10 +308,11 @@ impl File { /// /// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt - pub async fn set_len(&mut self, size: u64) -> io::Result<()> { - self.complete_inflight().await; + pub async fn set_len(&self, size: u64) -> io::Result<()> { + let mut inner = self.inner.lock().await; + inner.complete_inflight().await; - let mut buf = match self.state { + let mut buf = match inner.state { Idle(ref mut buf_cell) => buf_cell.take().unwrap(), _ => unreachable!(), }; @@ -315,7 +325,7 @@ impl File { let std = self.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -327,16 +337,16 @@ impl File { (Operation::Seek(res), buf) })); - let (op, buf) = match self.state { + let (op, buf) = match inner.state { Idle(_) => unreachable!(), Busy(ref mut rx) => rx.await?, }; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { Operation::Seek(res) => res.map(|pos| { - self.pos = pos; + inner.pos = pos; }), _ => unreachable!(), } @@ -402,7 +412,7 @@ impl File { /// # } /// ``` pub async fn into_std(mut self) -> sys::File { - self.complete_inflight().await; + self.inner.get_mut().complete_inflight().await; Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") } @@ -469,24 +479,19 @@ impl File { let std = self.std.clone(); asyncify(move || std.set_permissions(perm)).await } - - async fn complete_inflight(&mut self) { - use crate::future::poll_fn; - - if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await { - self.last_write_err = Some(e.kind()); - } - } } impl AsyncRead for File { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, dst: &mut ReadBuf<'_>, ) -> Poll> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + loop { - match self.state { + match inner.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); @@ -497,9 +502,9 @@ impl AsyncRead for File { } buf.ensure_capacity_for(dst); - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = buf.read_from(&mut &*std); (Operation::Read(res), buf) })); @@ -510,30 +515,30 @@ impl AsyncRead for File { match op { Operation::Read(Ok(_)) => { buf.copy_to(dst); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); return Ready(Ok(())); } Operation::Read(Err(e)) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); return Ready(Err(e)); } Operation::Write(Ok(_)) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); continue; } Operation::Write(Err(e)) => { - assert!(self.last_write_err.is_none()); - self.last_write_err = Some(e.kind()); - self.state = Idle(Some(buf)); + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); + inner.state = Idle(Some(buf)); } Operation::Seek(result) => { assert!(buf.is_empty()); - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); if let Ok(pos) = result { - self.pos = pos; + inner.pos = pos; } continue; } @@ -545,9 +550,12 @@ impl AsyncRead for File { } impl AsyncSeek for File { - fn start_seek(mut self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> { + fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + loop { - match self.state { + match inner.state { Busy(_) => panic!("must wait for poll_complete before calling start_seek"), Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); @@ -561,9 +569,9 @@ impl AsyncSeek for File { } } - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) })); @@ -574,23 +582,25 @@ impl AsyncSeek for File { } fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inner = self.inner.get_mut(); + loop { - match self.state { - Idle(_) => return Poll::Ready(Ok(self.pos)), + match inner.state { + Idle(_) => return Poll::Ready(Ok(inner.pos)), Busy(ref mut rx) => { let (op, buf) = ready!(Pin::new(rx).poll(cx))?; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { Operation::Read(_) => {} Operation::Write(Err(e)) => { - assert!(self.last_write_err.is_none()); - self.last_write_err = Some(e.kind()); + assert!(inner.last_write_err.is_none()); + inner.last_write_err = Some(e.kind()); } Operation::Write(_) => {} Operation::Seek(res) => { if let Ok(pos) = res { - self.pos = pos; + inner.pos = pos; } return Ready(res); } @@ -603,16 +613,19 @@ impl AsyncSeek for File { impl AsyncWrite for File { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, src: &[u8], ) -> Poll> { - if let Some(e) = self.last_write_err.take() { + let me = self.get_mut(); + let inner = me.inner.get_mut(); + + if let Some(e) = inner.last_write_err.take() { return Ready(Err(e.into())); } loop { - match self.state { + match inner.state { Idle(ref mut buf_cell) => { let mut buf = buf_cell.take().unwrap(); @@ -623,9 +636,9 @@ impl AsyncWrite for File { }; let n = buf.copy_from(src); - let std = self.std.clone(); + let std = me.std.clone(); - self.state = Busy(sys::run(move || { + inner.state = Busy(sys::run(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -639,7 +652,7 @@ impl AsyncWrite for File { } Busy(ref mut rx) => { let (op, buf) = ready!(Pin::new(rx).poll(cx))?; - self.state = Idle(Some(buf)); + inner.state = Idle(Some(buf)); match op { Operation::Read(_) => { @@ -665,23 +678,8 @@ impl AsyncWrite for File { } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(e) = self.last_write_err.take() { - return Ready(Err(e.into())); - } - - let (op, buf) = match self.state { - Idle(_) => return Ready(Ok(())), - Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?, - }; - - // The buffer is not used here - self.state = Idle(Some(buf)); - - match op { - Operation::Read(_) => Ready(Ok(())), - Operation::Write(res) => Ready(res), - Operation::Seek(_) => Ready(Ok(())), - } + let inner = self.inner.get_mut(); + inner.poll_flush(cx) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -731,3 +729,33 @@ impl std::os::windows::io::FromRawHandle for File { sys::File::from_raw_handle(handle).into() } } + +impl Inner { + async fn complete_inflight(&mut self) { + use crate::future::poll_fn; + + if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await { + self.last_write_err = Some(e.kind()); + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(e) = self.last_write_err.take() { + return Ready(Err(e.into())); + } + + let (op, buf) = match self.state { + Idle(_) => return Ready(Ok(())), + Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?, + }; + + // The buffer is not used here + self.state = Idle(Some(buf)); + + match op { + Operation::Read(_) => Ready(Ok(())), + Operation::Write(res) => Ready(res), + Operation::Seek(_) => Ready(Ok(())), + } + } +} diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index d09528beffd..0b50e4f7ad8 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -1,3 +1,4 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] //! # Implementation Details //! //! The semaphore is implemented using an intrusive linked list of waiters. An diff --git a/tokio/src/sync/mod.rs b/tokio/src/sync/mod.rs index 4919ad8e212..ed9f07a0e76 100644 --- a/tokio/src/sync/mod.rs +++ b/tokio/src/sync/mod.rs @@ -456,6 +456,14 @@ cfg_sync! { } cfg_not_sync! { + #[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))] + pub(crate) mod batch_semaphore; + + cfg_fs! { + mod mutex; + pub(crate) use mutex::Mutex; + } + mod notify; pub(crate) use notify::Notify; @@ -472,7 +480,6 @@ cfg_not_sync! { cfg_signal_internal! { pub(crate) mod mpsc; - pub(crate) mod batch_semaphore; } } diff --git a/tokio/src/sync/mutex.rs b/tokio/src/sync/mutex.rs index b2cf64d3607..21e44ca932c 100644 --- a/tokio/src/sync/mutex.rs +++ b/tokio/src/sync/mutex.rs @@ -1,3 +1,5 @@ +#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] + use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; diff --git a/tokio/src/util/linked_list.rs b/tokio/src/util/linked_list.rs index 5073855e8a8..5692743f4e5 100644 --- a/tokio/src/util/linked_list.rs +++ b/tokio/src/util/linked_list.rs @@ -181,7 +181,12 @@ impl fmt::Debug for LinkedList { } } -#[cfg(any(feature = "sync", feature = "signal", feature = "process"))] +#[cfg(any( + feature = "fs", + feature = "sync", + feature = "signal", + feature = "process" +))] impl LinkedList { pub(crate) fn last(&self) -> Option<&L::Target> { let tail = self.tail.as_ref()?; diff --git a/tokio/tests/fs_file_mocked.rs b/tokio/tests/fs_file_mocked.rs index 2e7e8b7cf48..edb74a7324e 100644 --- a/tokio/tests/fs_file_mocked.rs +++ b/tokio/tests/fs_file_mocked.rs @@ -57,6 +57,9 @@ pub(crate) mod fs { pub(crate) use crate::support::mock_pool::asyncify; } +pub(crate) mod sync { + pub(crate) use tokio::sync::Mutex; +} use fs::sys; use tokio::prelude::*; @@ -710,7 +713,7 @@ fn open_set_len_ok() { let (mock, file) = sys::File::mock(); mock.set_len(123); - let mut file = File::from_std(file); + let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); @@ -728,7 +731,7 @@ fn open_set_len_err() { let (mock, file) = sys::File::mock(); mock.set_len_err(123); - let mut file = File::from_std(file); + let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll());