diff --git a/monoio/src/driver/legacy/iocp/mod.rs b/monoio/src/driver/legacy/iocp/mod.rs index 3aad2057..afec11ca 100644 --- a/monoio/src/driver/legacy/iocp/mod.rs +++ b/monoio/src/driver/legacy/iocp/mod.rs @@ -128,7 +128,8 @@ impl Poller { token: mio::Token, interests: mio::Interest, ) -> std::io::Result<()> { - if state.inner.is_none() { + let mut state_inner = state.inner.lock().unwrap(); + if state_inner.inner.is_none() { let flags = interests_to_afd_flags(interests); let inner = { @@ -143,9 +144,9 @@ impl Poller { self.queue_state(inner.clone()); unsafe { self.update_sockets_events_if_polling()? }; - state.inner = Some(inner); - state.token = token; - state.interest = interests; + state_inner.inner = Some(inner); + state_inner.token = token; + state_inner.interest = interests; Ok(()) } else { @@ -155,37 +156,31 @@ impl Poller { pub fn reregister( &self, - state: &mut SocketState, + state: Pin>>, token: mio::Token, interests: mio::Interest, ) -> std::io::Result<()> { - if let Some(inner) = state.inner.as_mut() { - { - let event = Event { - flags: interests_to_afd_flags(interests), - data: token.0 as u64, - }; - - inner.lock().unwrap().set_event(event); - } - - state.token = token; - state.interest = interests; + { + let event = Event { + flags: interests_to_afd_flags(interests), + data: token.0 as u64, + }; - self.queue_state(inner.clone()); - unsafe { self.update_sockets_events_if_polling() } - } else { - Err(std::io::ErrorKind::NotFound.into()) + state.lock().unwrap().set_event(event); } + + self.queue_state(state.clone()); + unsafe { self.update_sockets_events_if_polling() } } pub fn deregister(&mut self, state: &mut SocketState) -> std::io::Result<()> { - if let Some(inner) = state.inner.as_mut() { + let mut state_inner = state.inner.lock().unwrap(); + if let Some(inner) = state_inner.inner.as_mut() { { let mut sock_state = inner.lock().unwrap(); sock_state.mark_delete(); } - state.inner = None; + state_inner.inner = None; Ok(()) } else { Err(std::io::ErrorKind::NotFound.into()) diff --git a/monoio/src/driver/legacy/iocp/state.rs b/monoio/src/driver/legacy/iocp/state.rs index a550eb6e..a1401bfb 100644 --- a/monoio/src/driver/legacy/iocp/state.rs +++ b/monoio/src/driver/legacy/iocp/state.rs @@ -25,20 +25,27 @@ pub enum SockPollStatus { } #[derive(Debug)] -pub struct SocketState { - pub socket: RawSocket, +pub struct SocketStateInner { pub inner: Option>>>, pub token: mio::Token, pub interest: mio::Interest, } +#[derive(Debug)] +pub struct SocketState { + pub socket: RawSocket, + pub inner: Arc>, +} + impl SocketState { pub fn new(socket: RawSocket) -> Self { Self { socket, - inner: None, - token: mio::Token(0), - interest: mio::Interest::READABLE, + inner: Arc::new(Mutex::new(SocketStateInner { + inner: None, + token: mio::Token(0), + interest: mio::Interest::READABLE, + })), } } } diff --git a/monoio/src/driver/legacy/mod.rs b/monoio/src/driver/legacy/mod.rs index 9d0f796e..532c2f5f 100644 --- a/monoio/src/driver/legacy/mod.rs +++ b/monoio/src/driver/legacy/mod.rs @@ -182,7 +182,7 @@ impl LegacyDriver { interest: mio::Interest, ) -> io::Result { let inner = unsafe { &mut *this.get() }; - let io = ScheduledIo::default(); + let io = ScheduledIo::new(state.inner.clone()); let token = inner.io_dispatch.insert(io); match inner.poll.register(state, mio::Token(token), interest) { @@ -303,6 +303,23 @@ impl LegacyInner { flags: 0, }), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + #[cfg(windows)] + { + if let Some((sock_state, token, interest)) = { + let socket_state_lock = ref_mut.state.lock().unwrap(); + socket_state_lock.inner.clone().map(|inner| { + (inner, socket_state_lock.token, socket_state_lock.interest) + }) + } { + if let Err(e) = inner.poll.reregister(sock_state, token, interest) { + return Poll::Ready(CompletionMeta { + result: Err(e), + flags: 0, + }); + } + } + } + ref_mut.clear_readiness(direction.mask()); ref_mut.set_waker(cx, direction); Poll::Pending diff --git a/monoio/src/driver/scheduled_io.rs b/monoio/src/driver/scheduled_io.rs index d164a1d3..f966909a 100644 --- a/monoio/src/driver/scheduled_io.rs +++ b/monoio/src/driver/scheduled_io.rs @@ -9,8 +9,12 @@ pub(crate) struct ScheduledIo { reader: Option, /// Waker used for AsyncWrite. writer: Option, + + #[cfg(windows)] + pub state: std::sync::Arc>, } +#[cfg(not(windows))] impl Default for ScheduledIo { #[inline] fn default() -> Self { @@ -19,11 +23,17 @@ impl Default for ScheduledIo { } impl ScheduledIo { - pub(crate) const fn new() -> Self { + pub(crate) const fn new( + #[cfg(windows)] state: std::sync::Arc< + std::sync::Mutex, + >, + ) -> Self { Self { readiness: Ready::EMPTY, reader: None, writer: None, + #[cfg(windows)] + state, } }