Skip to content

Commit

Permalink
fs: future proof File (#2930)
Browse files Browse the repository at this point in the history
Changes inherent methods to take `&self` instead of `&mut self`. This
brings the API in line with `std`.

This patch is implemented by using a `tokio::sync::Mutex` to guard the
internal `File` state. This is not an ideal implementation strategy
doesn't make a big impact compared to having to dispatch operations to a
background thread followed by a blocking syscall.

In the future, the implementation can be improved as we explore async
file-system APIs provided by the operating-system (iocp / io_uring).

Closes #2927
  • Loading branch information
carllerche authored Oct 9, 2020
1 parent ee59734 commit afe5352
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 72 deletions.
164 changes: 96 additions & 68 deletions tokio/src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use self::State::*;
use crate::fs::{asyncify, sys};
use crate::io::blocking::Buf;
use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
use crate::sync::Mutex;

use std::fmt;
use std::fs::{Metadata, Permissions};
Expand Down Expand Up @@ -80,6 +81,10 @@ use std::task::Poll::*;
/// ```
pub struct File {
std: Arc<sys::File>,
inner: Mutex<Inner>,
}

struct Inner {
state: State,

/// Errors from writes/flushes are returned in write/flush calls. If a write
Expand Down Expand Up @@ -199,9 +204,11 @@ impl File {
pub fn from_std(std: sys::File) -> File {
File {
std: Arc::new(std),
state: State::Idle(Some(Buf::with_capacity(0))),
last_write_err: None,
pos: 0,
inner: Mutex::new(Inner {
state: State::Idle(Some(Buf::with_capacity(0))),
last_write_err: None,
pos: 0,
}),
}
}

Expand All @@ -228,8 +235,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn sync_all(&mut self) -> io::Result<()> {
self.complete_inflight().await;
pub async fn sync_all(&self) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let std = self.std.clone();
asyncify(move || std.sync_all()).await
Expand Down Expand Up @@ -262,8 +270,9 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn sync_data(&mut self) -> io::Result<()> {
self.complete_inflight().await;
pub async fn sync_data(&self) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let std = self.std.clone();
asyncify(move || std.sync_data()).await
Expand Down Expand Up @@ -299,10 +308,11 @@ impl File {
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn set_len(&mut self, size: u64) -> io::Result<()> {
self.complete_inflight().await;
pub async fn set_len(&self, size: u64) -> io::Result<()> {
let mut inner = self.inner.lock().await;
inner.complete_inflight().await;

let mut buf = match self.state {
let mut buf = match inner.state {
Idle(ref mut buf_cell) => buf_cell.take().unwrap(),
_ => unreachable!(),
};
Expand All @@ -315,7 +325,7 @@ impl File {

let std = self.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| std.set_len(size))
} else {
Expand All @@ -327,16 +337,16 @@ impl File {
(Operation::Seek(res), buf)
}));

let (op, buf) = match self.state {
let (op, buf) = match inner.state {
Idle(_) => unreachable!(),
Busy(ref mut rx) => rx.await?,
};

self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Seek(res) => res.map(|pos| {
self.pos = pos;
inner.pos = pos;
}),
_ => unreachable!(),
}
Expand Down Expand Up @@ -402,7 +412,7 @@ impl File {
/// # }
/// ```
pub async fn into_std(mut self) -> sys::File {
self.complete_inflight().await;
self.inner.get_mut().complete_inflight().await;
Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed")
}

Expand Down Expand Up @@ -469,24 +479,19 @@ impl File {
let std = self.std.clone();
asyncify(move || std.set_permissions(perm)).await
}

async fn complete_inflight(&mut self) {
use crate::future::poll_fn;

if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
self.last_write_err = Some(e.kind());
}
}
}

impl AsyncRead for File {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
dst: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let me = self.get_mut();
let inner = me.inner.get_mut();

loop {
match self.state {
match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

Expand All @@ -497,9 +502,9 @@ impl AsyncRead for File {
}

buf.ensure_capacity_for(dst);
let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = buf.read_from(&mut &*std);
(Operation::Read(res), buf)
}));
Expand All @@ -510,30 +515,30 @@ impl AsyncRead for File {
match op {
Operation::Read(Ok(_)) => {
buf.copy_to(dst);
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
return Ready(Ok(()));
}
Operation::Read(Err(e)) => {
assert!(buf.is_empty());

self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
return Ready(Err(e));
}
Operation::Write(Ok(_)) => {
assert!(buf.is_empty());
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
continue;
}
Operation::Write(Err(e)) => {
assert!(self.last_write_err.is_none());
self.last_write_err = Some(e.kind());
self.state = Idle(Some(buf));
assert!(inner.last_write_err.is_none());
inner.last_write_err = Some(e.kind());
inner.state = Idle(Some(buf));
}
Operation::Seek(result) => {
assert!(buf.is_empty());
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));
if let Ok(pos) = result {
self.pos = pos;
inner.pos = pos;
}
continue;
}
Expand All @@ -545,9 +550,12 @@ impl AsyncRead for File {
}

impl AsyncSeek for File {
fn start_seek(mut self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> {
fn start_seek(self: Pin<&mut Self>, mut pos: SeekFrom) -> io::Result<()> {
let me = self.get_mut();
let inner = me.inner.get_mut();

loop {
match self.state {
match inner.state {
Busy(_) => panic!("must wait for poll_complete before calling start_seek"),
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();
Expand All @@ -561,9 +569,9 @@ impl AsyncSeek for File {
}
}

let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = (&*std).seek(pos);
(Operation::Seek(res), buf)
}));
Expand All @@ -574,23 +582,25 @@ impl AsyncSeek for File {
}

fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let inner = self.inner.get_mut();

loop {
match self.state {
Idle(_) => return Poll::Ready(Ok(self.pos)),
match inner.state {
Idle(_) => return Poll::Ready(Ok(inner.pos)),
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Read(_) => {}
Operation::Write(Err(e)) => {
assert!(self.last_write_err.is_none());
self.last_write_err = Some(e.kind());
assert!(inner.last_write_err.is_none());
inner.last_write_err = Some(e.kind());
}
Operation::Write(_) => {}
Operation::Seek(res) => {
if let Ok(pos) = res {
self.pos = pos;
inner.pos = pos;
}
return Ready(res);
}
Expand All @@ -603,16 +613,19 @@ impl AsyncSeek for File {

impl AsyncWrite for File {
fn poll_write(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
src: &[u8],
) -> Poll<io::Result<usize>> {
if let Some(e) = self.last_write_err.take() {
let me = self.get_mut();
let inner = me.inner.get_mut();

if let Some(e) = inner.last_write_err.take() {
return Ready(Err(e.into()));
}

loop {
match self.state {
match inner.state {
Idle(ref mut buf_cell) => {
let mut buf = buf_cell.take().unwrap();

Expand All @@ -623,9 +636,9 @@ impl AsyncWrite for File {
};

let n = buf.copy_from(src);
let std = self.std.clone();
let std = me.std.clone();

self.state = Busy(sys::run(move || {
inner.state = Busy(sys::run(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
Expand All @@ -639,7 +652,7 @@ impl AsyncWrite for File {
}
Busy(ref mut rx) => {
let (op, buf) = ready!(Pin::new(rx).poll(cx))?;
self.state = Idle(Some(buf));
inner.state = Idle(Some(buf));

match op {
Operation::Read(_) => {
Expand All @@ -665,23 +678,8 @@ impl AsyncWrite for File {
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if let Some(e) = self.last_write_err.take() {
return Ready(Err(e.into()));
}

let (op, buf) = match self.state {
Idle(_) => return Ready(Ok(())),
Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
};

// The buffer is not used here
self.state = Idle(Some(buf));

match op {
Operation::Read(_) => Ready(Ok(())),
Operation::Write(res) => Ready(res),
Operation::Seek(_) => Ready(Ok(())),
}
let inner = self.inner.get_mut();
inner.poll_flush(cx)
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Expand Down Expand Up @@ -731,3 +729,33 @@ impl std::os::windows::io::FromRawHandle for File {
sys::File::from_raw_handle(handle).into()
}
}

impl Inner {
async fn complete_inflight(&mut self) {
use crate::future::poll_fn;

if let Err(e) = poll_fn(|cx| Pin::new(&mut *self).poll_flush(cx)).await {
self.last_write_err = Some(e.kind());
}
}

fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
if let Some(e) = self.last_write_err.take() {
return Ready(Err(e.into()));
}

let (op, buf) = match self.state {
Idle(_) => return Ready(Ok(())),
Busy(ref mut rx) => ready!(Pin::new(rx).poll(cx))?,
};

// The buffer is not used here
self.state = Idle(Some(buf));

match op {
Operation::Read(_) => Ready(Ok(())),
Operation::Write(res) => Ready(res),
Operation::Seek(_) => Ready(Ok(())),
}
}
}
1 change: 1 addition & 0 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]
//! # Implementation Details
//!
//! The semaphore is implemented using an intrusive linked list of waiters. An
Expand Down
9 changes: 8 additions & 1 deletion tokio/src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,14 @@ cfg_sync! {
}

cfg_not_sync! {
#[cfg(any(feature = "fs", feature = "signal", all(unix, feature = "process")))]
pub(crate) mod batch_semaphore;

cfg_fs! {
mod mutex;
pub(crate) use mutex::Mutex;
}

mod notify;
pub(crate) use notify::Notify;

Expand All @@ -472,7 +480,6 @@ cfg_not_sync! {

cfg_signal_internal! {
pub(crate) mod mpsc;
pub(crate) mod batch_semaphore;
}
}

Expand Down
2 changes: 2 additions & 0 deletions tokio/src/sync/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))]

use crate::sync::batch_semaphore as semaphore;

use std::cell::UnsafeCell;
Expand Down
Loading

0 comments on commit afe5352

Please sign in to comment.