From 5e9aad76a370328e3de46c180d64512621ce3890 Mon Sep 17 00:00:00 2001 From: Lu Kai Date: Sun, 17 Aug 2025 16:24:14 +0800 Subject: [PATCH 1/2] Add file handle support for overlapped I/O --- Cargo.toml | 11 +- src/iocp/afd.rs | 168 +-------- src/iocp/mod.rs | 319 ++++++++++++++++- src/iocp/ntdll.rs | 160 +++++++++ src/iocp/port.rs | 220 +++++++++++- src/os/iocp.rs | 526 ++++++++++++++++++++++++++++ tests/windows_overlapped.rs | 678 ++++++++++++++++++++++++++++++++++++ 7 files changed, 1912 insertions(+), 170 deletions(-) create mode 100644 src/iocp/ntdll.rs create mode 100644 tests/windows_overlapped.rs diff --git a/Cargo.toml b/Cargo.toml index e99607c..c2611ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ name = "polling" version = "3.11.0" authors = ["Stjepan Glavina ", "John Nunley "] edition = "2021" -rust-version = "1.71" +rust-version = "1.77" description = "Portable interface to epoll, kqueue, event ports, and IOCP" license = "Apache-2.0 OR MIT" repository = "https://github.com/smol-rs/polling" @@ -18,7 +18,10 @@ exclude = ["/.*"] rustdoc-args = ["--cfg", "docsrs"] [lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(polling_test_poll_backend)', 'cfg(polling_test_epoll_pipe)'] } +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(polling_test_poll_backend)', + 'cfg(polling_test_epoll_pipe)', +] } [dependencies] cfg-if = "1" @@ -65,3 +68,7 @@ libc = "0.2" [target.'cfg(all(unix, not(target_os="vita")))'.dev-dependencies] signal-hook = "0.3.17" + +[target.'cfg(windows)'.dev-dependencies] +tempfile = "3.7" +windows-sys = { version = "0.60", features = ["Win32_System_Pipes"] } diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs index f7c7c3f..9096e3b 100644 --- a/src/iocp/afd.rs +++ b/src/iocp/afd.rs @@ -1,30 +1,30 @@ //! Safe wrapper around \Device\Afd +use crate::iocp::ntdll::NtdllImports; +use crate::iocp::port::FileOverlapped; + use super::port::{Completion, CompletionHandle}; use std::cell::UnsafeCell; use std::fmt; use std::io; use std::marker::{PhantomData, PhantomPinned}; -use std::mem::{self, size_of, transmute, MaybeUninit}; +use std::mem::{self, size_of, MaybeUninit}; use std::ops; use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket}; use std::pin::Pin; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::OnceLock; use windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES; use windows_sys::Wdk::Storage::FileSystem::FILE_OPEN; use windows_sys::Win32::Foundation::{ - CloseHandle, HANDLE, HMODULE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, - UNICODE_STRING, + CloseHandle, HANDLE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, UNICODE_STRING, }; use windows_sys::Win32::Networking::WinSock::{ WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR, }; use windows_sys::Win32::Storage::FileSystem::{FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE}; -use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress}; use windows_sys::Win32::System::IO::IO_STATUS_BLOCK; #[derive(Default)] @@ -170,153 +170,6 @@ pub(super) trait HasAfdInfo { fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell>; } -macro_rules! define_ntdll_import { - ( - $( - $(#[$attr:meta])* - fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty; - )* - ) => { - /// Imported functions from ntdll.dll. - #[allow(non_snake_case)] - pub(super) struct NtdllImports { - $( - $(#[$attr])* - $name: unsafe extern "system" fn($($arg_ty),*) -> $ret, - )* - } - - #[allow(non_snake_case)] - impl NtdllImports { - unsafe fn load(ntdll: HMODULE) -> io::Result { - $( - #[allow(clippy::missing_transmute_annotations)] - let $name = { - const NAME: &str = concat!(stringify!($name), "\0"); - let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _); - - let addr = match addr { - Some(addr) => addr, - None => { - #[cfg(feature = "tracing")] - tracing::error!("Failed to load ntdll function {}", NAME); - return Err(io::Error::last_os_error()); - }, - }; - - transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr) - }; - )* - - Ok(Self { - $( - $name, - )* - }) - } - - $( - $(#[$attr])* - unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret { - (self.$name)($($arg),*) - } - )* - } - }; -} - -define_ntdll_import! { - /// Cancels an ongoing I/O operation. - fn NtCancelIoFileEx( - FileHandle: HANDLE, - IoRequestToCancel: *mut IO_STATUS_BLOCK, - IoStatusBlock: *mut IO_STATUS_BLOCK - ) -> NTSTATUS; - - /// Opens or creates a file handle. - #[allow(clippy::too_many_arguments)] - fn NtCreateFile( - FileHandle: *mut HANDLE, - DesiredAccess: u32, - ObjectAttributes: *mut OBJECT_ATTRIBUTES, - IoStatusBlock: *mut IO_STATUS_BLOCK, - AllocationSize: *mut i64, - FileAttributes: u32, - ShareAccess: u32, - CreateDisposition: u32, - CreateOptions: u32, - EaBuffer: *mut (), - EaLength: u32 - ) -> NTSTATUS; - - /// Runs an I/O control on a file handle. - /// - /// Practically equivalent to `ioctl`. - #[allow(clippy::too_many_arguments)] - fn NtDeviceIoControlFile( - FileHandle: HANDLE, - Event: HANDLE, - ApcRoutine: *mut (), - ApcContext: *mut (), - IoStatusBlock: *mut IO_STATUS_BLOCK, - IoControlCode: u32, - InputBuffer: *mut (), - InputBufferLength: u32, - OutputBuffer: *mut (), - OutputBufferLength: u32 - ) -> NTSTATUS; - - /// Converts `NTSTATUS` to a DOS error code. - fn RtlNtStatusToDosError( - Status: NTSTATUS - ) -> u32; -} - -impl NtdllImports { - fn get() -> io::Result<&'static Self> { - macro_rules! s { - ($e:expr) => {{ - $e as u16 - }}; - } - - // ntdll.dll - static NTDLL_NAME: &[u16] = &[ - s!('n'), - s!('t'), - s!('d'), - s!('l'), - s!('l'), - s!('.'), - s!('d'), - s!('l'), - s!('l'), - s!('\0'), - ]; - static NTDLL_IMPORTS: OnceLock> = OnceLock::new(); - - NTDLL_IMPORTS - .get_or_init(|| unsafe { - let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _); - - if ntdll.is_null() { - #[cfg(feature = "tracing")] - tracing::error!("Failed to load ntdll.dll"); - return Err(io::Error::last_os_error()); - } - - NtdllImports::load(ntdll) - }) - .as_ref() - .map_err(|e| io::Error::from(e.kind())) - } - - pub(super) fn force_load() -> io::Result<()> { - Self::get()?; - Ok(()) - } -} - /// The handle to the AFD device. pub(super) struct Afd { /// The handle to the AFD device. @@ -614,6 +467,17 @@ unsafe impl Completion for IoStatusBlock { } } +impl FileOverlapped for IoStatusBlock { + fn file_read_offset() -> usize { + T::file_read_offset() + std::mem::offset_of!(IoStatusBlock, data) + } + + fn file_write_offset() -> usize { + let data_offset = std::mem::offset_of!(IoStatusBlock, data); + T::file_write_offset() + data_offset + } +} + /// Get the base socket associated with a socket. pub(super) fn base_socket(sock: RawSocket) -> io::Result { // First, try the SIO_BASE_HANDLE ioctl. diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index facbe06..a579ac3 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -26,6 +26,7 @@ //! AFD-based strategy for polling. mod afd; +pub(crate) mod ntdll; mod port; use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock}; @@ -36,7 +37,10 @@ use windows_sys::Win32::System::Threading::{ RegisterWaitForSingleObject, UnregisterWait, INFINITE, WT_EXECUTELONGFUNCTION, WT_EXECUTEONLYONCE, }; +use windows_sys::Win32::System::IO::{OVERLAPPED, OVERLAPPED_ENTRY}; +use crate::iocp::port::{FileCompletionHandle, FileOverlapped}; +use crate::os::iocp::{IocpFilePacket, Overlapped, OverlappedInner}; use crate::{Event, PollMode}; use concurrent_queue::ConcurrentQueue; @@ -45,8 +49,6 @@ use pin_project_lite::pin_project; use std::cell::UnsafeCell; use std::collections::hash_map::{Entry, HashMap}; use std::ffi::c_void; -use std::fmt; -use std::io; use std::marker::PhantomPinned; use std::mem::{forget, MaybeUninit}; use std::os::windows::io::{ @@ -54,8 +56,9 @@ use std::os::windows::io::{ }; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock, Weak}; +use std::sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock, Weak}; use std::time::{Duration, Instant}; +use std::{fmt, io}; /// Macro to lock and ignore lock poisoning. macro_rules! lock { @@ -91,6 +94,9 @@ pub(super) struct Poller { /// The state of the waitable handles registered with this poller. waitables: RwLock>, + /// The state of the waitable handles registered with this poller. + files: RwLock>, + /// Sockets with pending updates. /// /// This list contains packets with sockets that need to have their AFD state adjusted by @@ -118,7 +124,7 @@ impl Poller { /// Creates a new poller. pub(super) fn new() -> io::Result { // Make sure AFD is able to be used. - if let Err(e) = afd::NtdllImports::force_load() { + if let Err(e) = ntdll::NtdllImports::force_load() { return Err(io::Error::new( io::ErrorKind::Unsupported, AfdError::new("failed to initialize unstable Windows functions", e), @@ -143,6 +149,7 @@ impl Poller { afd: Mutex::new(vec![]), sources: RwLock::new(HashMap::new()), waitables: RwLock::new(HashMap::new()), + files: RwLock::new(HashMap::new()), pending_updates: ConcurrentQueue::bounded(1024), polling: AtomicBool::new(false), notifier: Arc::pin( @@ -427,6 +434,91 @@ impl Poller { source.begin_delete() } + /// Add a file to the poller. + /// File handle work on PollMode::Edge mode. The IOCP continue to poll the events unitl + /// the file is closed. The caller must use the overlapped pointer return in IocpFilePacket + /// as overlapped paramter for I/O operation. The Packet do not need to increase Arc count because + /// the call can trigger events through I/O operation without update intrest events as long as the + /// file handle has been registered with the IOCP. So the Packet lifetime is ended with calling [`remove_file`]. + /// Any I/O operation using return overlapped pointer return in IocpFilePacket is undefined behavior. + /// + /// IocpFilePacket will return both read and write overlapped pointer no matter what intrest events are. + /// The caller need to use the correct overlapped pointer for I/O operation. Such as: the read overlapped + /// pointer can be used for read operations, and the write overlapped pointer can be used for write operations. + pub(super) fn add_file( + &self, + handle: RawHandle, + interest: Event, + ) -> io::Result { + #[cfg(feature = "tracing")] + tracing::trace!( + "add_file: handle={:?}, file={:p}, ev={:?}", + self.port, + handle, + interest + ); + + // We only support edge events. + // Create a new packet. + let handle_state = { + let state = FileState { handle, interest }; + + Arc::pin(IoStatusBlock::from(PacketInner::File { + read: UnsafeCell::new(OverlappedInner::::new(file_read_overlapped_done)), + write: UnsafeCell::new(OverlappedInner::::new(file_write_overlapped_done)), + handle: Mutex::new(state), + })) + }; + + // Keep track of the source in the poller. + { + let mut sources = lock!(self.files.write()); + + match sources.entry(handle) { + Entry::Vacant(v) => { + v.insert(Pin::>::clone(&handle_state)); + } + + Entry::Occupied(_) => { + return Err(io::Error::from(io::ErrorKind::AlreadyExists)); + } + } + } + + let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() { + PacketInnerProj::File { + read, + write, + handle, + } => (read.get(), write.get(), handle), + _ => unreachable!("PacketInner should always be File here"), + }; + + let file_state = lock!(file_handle.lock()); + // Register the file handle with the I/O completion port. + self.port + .register(&*file_state, true, port::CompletionKeyType::File)?; + + let iocp_packet = unsafe { IocpFilePacket::new((*read).as_ptr(), (*write).as_ptr()) }; + Ok(iocp_packet) + } + + /// Remove a file from the poller. + pub(super) fn remove_file(&self, handle: RawHandle) -> io::Result<()> { + #[cfg(feature = "tracing")] + tracing::trace!("remove: handle={:?}, file={:p}", self.port, handle); + + // Get a reference to the source. + let mut sources = lock!(self.files.write()); + match sources.remove(&handle) { + Some(_) => Ok(()), + None => { + // If the source has already been removed, then we can just return. + Err(io::Error::from(io::ErrorKind::NotFound)) + } + } + } + /// Wait for events. pub(super) fn wait_deadline( &self, @@ -476,10 +568,17 @@ impl Poller { // Process all of the events. for entry in events.completions.drain(..) { - let packet = entry.into_packet(); + let result = if entry.is_file_completion() { + let bytes_transferred = entry.bytes_transferred(); + let (packet, polling_status) = entry.into_file_packet(); + packet.feed_file_event(polling_status, bytes_transferred) + } else { + let packet = entry.into_packet(); + packet.feed_event(self) + }; // Feed the event into the packet. - match packet.feed_event(self)? { + match result? { FeedEventResult::NoEvent => {} FeedEventResult::Event(event) => { events.packets.push(event); @@ -599,7 +698,8 @@ impl Poller { let afd = Arc::new(Afd::new()?); // Register the AFD instance with the I/O completion port. - self.port.register(&*afd, true)?; + self.port + .register(&*afd, true, port::CompletionKeyType::Socket)?; // Insert a weak pointer to the AFD instance into the list for other sockets. afd_handles.push(Arc::downgrade(&afd)); @@ -735,6 +835,37 @@ impl CompletionPacket { /// It needs to be pinned, since it contains data that is expected by IOCP not to be moved. type Packet = Pin>; type PacketUnwrapped = IoStatusBlock; +/// A wrapper around the Overlapped structure for file I/O operation result +#[derive(Debug)] +#[repr(transparent)] +pub struct FileOverlappedWrapper(Overlapped); + +impl FileOverlappedWrapper { + /// Wrapping to [`Overlapped::from_overlapped_ptr`] + /// + /// # Safety + /// + /// The caller must ensure that the pointer is valid and points to an + /// `Overlapped` structure. + pub unsafe fn from_overlapped_ptr(overlapped_ptr: *mut OVERLAPPED) -> *mut Self { + Overlapped::::from_overlapped_ptr(overlapped_ptr) as *mut _ + } + + /// Wrapping to [`Overlapped::get_bytes_transferred`] + pub fn get_bytes_transferred(&self) -> u32 { + self.0.get_bytes_transferred() + } + + /// Wrapping to [`Overlapped::get_result`] + pub fn get_result(&self) -> io::Result { + self.0.get_result() + } + + /// Wrapping to [`Overlapped::zeroed`] + pub fn zeroed(&mut self) { + self.0.zeroed(); + } +} pin_project! { /// The inner type of the packet. @@ -756,6 +887,18 @@ pin_project! { handle: Mutex }, + /// A packet for a File handle. + File { + // read update this overlapped structure. + #[pin] + read: UnsafeCell>, + + // write update this overlapped structure. + #[pin] + write: UnsafeCell>, + handle: Mutex + }, + /// A custom event sent by the user. Custom { event: Event, @@ -782,6 +925,16 @@ impl fmt::Debug for PacketInner { Self::Waitable { handle } => { f.debug_struct("Waitable").field("handle", handle).finish() } + Self::File { + handle, + read, + write, + } => f + .debug_struct("File") + .field("file", handle) + .field("read", &format_args!("{:p}", read as *const _)) + .field("write", &format_args!("{:p}", write as *const _)) + .finish(), } } } @@ -795,6 +948,49 @@ impl HasAfdInfo for PacketInner { } } +/// Only caculate offset once +static FILE_OVERLAPPED_OFFSET: OnceLock<(usize, usize)> = OnceLock::new(); + +impl FileOverlapped for PacketInner { + fn file_read_offset() -> usize { + PacketInner::file_overlapped_offset().0 + } + + fn file_write_offset() -> usize { + PacketInner::file_overlapped_offset().1 + } +} + +impl PacketInner { + /// Calculate the offset of read and write overlapped in PacketInner::File + fn file_overlapped_offset() -> &'static (usize, usize) { + FILE_OVERLAPPED_OFFSET.get_or_init(|| { + let state = FileState { + handle: std::ptr::null_mut(), + interest: Event::none(0), + }; + + let packet = &PacketInner::File { + read: UnsafeCell::new(OverlappedInner::::new(file_read_overlapped_done)), + write: UnsafeCell::new(OverlappedInner::::new(file_write_overlapped_done)), + handle: Mutex::new(state), + }; + + let base = packet as *const _; + let (read, write) = match packet { + PacketInner::File { read, write, .. } => (read, write), + _ => unreachable!(), + }; + let read_ptr = read as *const _; + let write_ptr = write as *const _; + ( + unsafe { (read_ptr as *const u8).offset_from(base as *const _) as usize }, + unsafe { (write_ptr as *const u8).offset_from(base as *const _) as usize }, + ) + }) + } +} + impl PacketUnwrapped { /// Set the new events that this socket is waiting on. /// @@ -889,7 +1085,7 @@ impl PacketUnwrapped { return Ok(()); } - _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid socket state")), + _ => return Err(io::Error::other("invalid socket state")), }; // If we are waiting on a delete, just return, dropping the packet. @@ -980,6 +1176,7 @@ impl PacketUnwrapped { return Ok(FeedEventResult::Event(event)); } + _ => unreachable!("Should not be called on a file packet"), }; let mut socket_state = lock!(socket.lock()); @@ -1067,6 +1264,52 @@ impl PacketUnwrapped { Ok(return_value) } + fn feed_file_event( + self: Pin>, + status: FileCompletionStatus, + bytes_transferred: u32, + ) -> io::Result { + let inner = self.as_ref().data().project_ref(); + + let (handle, read, write) = match inner { + PacketInnerProj::File { + handle, + read, + write, + } => (handle, read, write), + _ => unreachable!("Should not be called on a non-file packet"), + }; + + let file_state = lock!(handle.lock()); + let mut event = Event::none(file_state.interest.key); + if status.is_read() { + unsafe { + (*read.get()).set_bytes_transferred(bytes_transferred); + } + event.readable = true; + } + + if status.is_write() { + unsafe { + (*write.get()).set_bytes_transferred(bytes_transferred); + } + event.writable = true; + } + + event.readable &= file_state.interest.readable; + event.writable &= file_state.interest.writable; + + // If this event doesn't have anything that interests us, don't return or + // update the oneshot state. + let return_value = if event.readable || event.writable { + FeedEventResult::Event(event) + } else { + FeedEventResult::NoEvent + }; + + Ok(return_value) + } + /// Begin deleting this socket. fn begin_delete(self: Pin>) -> io::Result<()> { // If we aren't already being deleted, start deleting. @@ -1111,6 +1354,22 @@ impl PacketUnwrapped { } } +/// Callback convert read overlapped pointer to Packet +unsafe fn file_read_overlapped_done(entry: &OVERLAPPED_ENTRY) -> (Packet, FileCompletionStatus) { + ( + Packet::file_read_done(entry), + FILE_STATUS_POLLING_FLAG_READ.into(), + ) +} + +/// Callback convert write overlapped pointer Packet +unsafe fn file_write_overlapped_done(entry: &OVERLAPPED_ENTRY) -> (Packet, FileCompletionStatus) { + ( + Packet::file_write_done(entry), + FILE_STATUS_POLLING_FLAG_WRITE.into(), + ) +} + /// Per-socket state. #[derive(Debug)] struct SocketState { @@ -1193,6 +1452,50 @@ impl WaitableStatus { } } +#[derive(Debug)] +struct FileState { + /// The handle that this state is for. + handle: RawHandle, + + /// The event that this handle will report. + interest: Event, +} + +impl AsRawHandle for FileState { + fn as_raw_handle(&self) -> RawHandle { + self.handle as _ + } +} + +const FILE_STATUS_POLLING_FLAG_READ: u32 = 1 << 0; // 0001 +const FILE_STATUS_POLLING_FLAG_WRITE: u32 = 1 << 1; // 0010 + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(transparent)] +pub(crate) struct FileCompletionStatus(u32); + +impl FileCompletionStatus { + pub(crate) const fn is_read(&self) -> bool { + self.0 & FILE_STATUS_POLLING_FLAG_READ != 0 + } + + pub(crate) const fn is_write(&self) -> bool { + self.0 & FILE_STATUS_POLLING_FLAG_WRITE != 0 + } +} + +impl From for u32 { + fn from(value: FileCompletionStatus) -> Self { + value.0 + } +} + +impl From for FileCompletionStatus { + fn from(value: u32) -> Self { + FileCompletionStatus(value) + } +} + /// The result of calling `feed_event`. #[derive(Debug)] enum FeedEventResult { diff --git a/src/iocp/ntdll.rs b/src/iocp/ntdll.rs new file mode 100644 index 0000000..645e81f --- /dev/null +++ b/src/iocp/ntdll.rs @@ -0,0 +1,160 @@ +//! ntdll library bindings +use std::{io, mem::transmute, sync::OnceLock}; + +use windows_sys::{ + Wdk::Foundation::OBJECT_ATTRIBUTES, + Win32::{ + Foundation::{HANDLE, HMODULE, NTSTATUS}, + System::{ + LibraryLoader::{GetModuleHandleW, GetProcAddress}, + IO::IO_STATUS_BLOCK, + }, + }, +}; + +macro_rules! define_ntdll_import { + ( + $( + $(#[$attr:meta])* + fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty; + )* + ) => { + /// Imported functions from ntdll.dll. + #[allow(non_snake_case)] + pub(crate) struct NtdllImports { + $( + $(#[$attr])* + pub(super) $name: unsafe extern "system" fn($($arg_ty),*) -> $ret, + )* + } + + #[allow(non_snake_case)] + impl NtdllImports { + unsafe fn load(ntdll: HMODULE) -> io::Result { + $( + #[allow(clippy::missing_transmute_annotations)] + let $name = { + const NAME: &str = concat!(stringify!($name), "\0"); + let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _); + + let addr = match addr { + Some(addr) => addr, + None => { + #[cfg(feature = "tracing")] + tracing::error!("Failed to load ntdll function {}", NAME); + return Err(io::Error::last_os_error()); + }, + }; + + transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr) + }; + )* + + Ok(Self { + $( + $name, + )* + }) + } + + $( + $(#[$attr])* + pub(crate) unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret { + (self.$name)($($arg),*) + } + )* + } + }; +} + +define_ntdll_import! { + /// Cancels an ongoing I/O operation. + fn NtCancelIoFileEx( + FileHandle: HANDLE, + IoRequestToCancel: *mut IO_STATUS_BLOCK, + IoStatusBlock: *mut IO_STATUS_BLOCK + ) -> NTSTATUS; + + /// Opens or creates a file handle. + #[allow(clippy::too_many_arguments)] + fn NtCreateFile( + FileHandle: *mut HANDLE, + DesiredAccess: u32, + ObjectAttributes: *mut OBJECT_ATTRIBUTES, + IoStatusBlock: *mut IO_STATUS_BLOCK, + AllocationSize: *mut i64, + FileAttributes: u32, + ShareAccess: u32, + CreateDisposition: u32, + CreateOptions: u32, + EaBuffer: *mut (), + EaLength: u32 + ) -> NTSTATUS; + + /// Runs an I/O control on a file handle. + /// + /// Practically equivalent to `ioctl`. + #[allow(clippy::too_many_arguments)] + fn NtDeviceIoControlFile( + FileHandle: HANDLE, + Event: HANDLE, + ApcRoutine: *mut (), + ApcContext: *mut (), + IoStatusBlock: *mut IO_STATUS_BLOCK, + IoControlCode: u32, + InputBuffer: *mut (), + InputBufferLength: u32, + OutputBuffer: *mut (), + OutputBufferLength: u32 + ) -> NTSTATUS; + + /// Converts `NTSTATUS` to a DOS error code. + fn RtlNtStatusToDosError( + Status: NTSTATUS + ) -> u32; +} + +impl NtdllImports { + pub(crate) fn get() -> io::Result<&'static Self> { + macro_rules! s { + ($e:expr) => {{ + $e as u16 + }}; + } + + // ntdll.dll + static NTDLL_NAME: &[u16] = &[ + s!('n'), + s!('t'), + s!('d'), + s!('l'), + s!('l'), + s!('.'), + s!('d'), + s!('l'), + s!('l'), + s!('\0'), + ]; + static NTDLL_IMPORTS: OnceLock> = OnceLock::new(); + + NTDLL_IMPORTS + .get_or_init(|| unsafe { + let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _); + + if ntdll.is_null() { + #[cfg(feature = "tracing")] + tracing::error!("Failed to load ntdll.dll"); + return Err(io::Error::last_os_error()); + } + + NtdllImports::load(ntdll) + }) + .as_ref() + .map_err(|e| io::Error::from(e.kind())) + } + + pub(super) fn force_load() -> io::Result<()> { + Self::get()?; + Ok(()) + } +} diff --git a/src/iocp/port.rs b/src/iocp/port.rs index 6d9b8be..2b489cb 100644 --- a/src/iocp/port.rs +++ b/src/iocp/port.rs @@ -1,5 +1,8 @@ //! A safe wrapper around the Windows I/O API. +use crate::iocp::FileCompletionStatus; +use crate::os::iocp::OverlappedInner; + use super::dur2timeout; use std::fmt; @@ -10,6 +13,7 @@ use std::ops::Deref; use std::os::windows::io::{AsRawHandle, RawHandle}; use std::pin::Pin; use std::ptr; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use std::time::Duration; @@ -65,6 +69,32 @@ pub(super) unsafe trait CompletionHandle: Deref + Sized { fn as_ptr(&self) -> *mut OVERLAPPED; } +/// Offset that a file read/write overlapped position to the begining of the whole 'IoStatusBlock' block. +/// +/// # Safety +/// +/// The whole 'IoStatusBlock' block must include file read/write `Overlapped` struct +pub(super) trait FileOverlapped { + /// Get the offset of the file read overlapped structure to the whole 'IoStatusBlock' block + fn file_read_offset() -> usize; + + /// Get the offset of the file write overlapped structure to the whole 'IoStatusBlock' block + fn file_write_offset() -> usize; +} + +/// File completion overlapped pointer convert to 'IoStatusBlock' +/// +/// # Safety +/// +/// The completion overlaped pointer must be valid as part of 'IoStatusBlock' +pub(super) unsafe trait FileCompletionHandle { + /// file read overlapped pointer convert to 'IoStatusBlock' + fn file_read_done(entry: &OVERLAPPED_ENTRY) -> Self; + + /// file write overlapped pointer convert to 'IoStatusBlock' + fn file_write_done(entry: &OVERLAPPED_ENTRY) -> Self; +} + unsafe impl CompletionHandle for Pin<&T> { type Completion = T; @@ -105,6 +135,49 @@ unsafe impl CompletionHandle for Pin> { } } +unsafe impl FileCompletionHandle for Pin<&T> { + fn file_read_done(entry: &OVERLAPPED_ENTRY) -> Self { + let overlapped_ptr = entry.lpOverlapped; + let offset = T::file_read_offset(); + unsafe { Pin::new_unchecked(&*((overlapped_ptr as *mut u8).sub(offset) as *const T)) } + } + + fn file_write_done(entry: &OVERLAPPED_ENTRY) -> Self { + let overlapped_ptr = entry.lpOverlapped; + let offset = T::file_write_offset(); + unsafe { Pin::new_unchecked(&*((overlapped_ptr as *mut u8).sub(offset) as *const T)) } + } +} + +unsafe impl FileCompletionHandle for Pin> { + fn file_read_done(entry: &OVERLAPPED_ENTRY) -> Self { + let overlapped_ptr = entry.lpOverlapped; + let offset = T::file_read_offset(); + // File completion does not clone the Packet when add the file handle to IOCP + // So need to clone the Packet to avoid the owner ship lost + unsafe { + let inner = Arc::from_raw((overlapped_ptr as *const u8).sub(offset) as *const T); + assert!(Arc::strong_count(&inner) >= 1, "File has been removed, but still use FileOverlappedWrapper return from add_file function"); + + let new_one = Pin::new_unchecked(Arc::clone(&inner)); + let _ = Arc::into_raw(inner); // Prevent Arc from being dropped + new_one + } + } + + fn file_write_done(entry: &OVERLAPPED_ENTRY) -> Self { + let overlapped_ptr = entry.lpOverlapped; + let offset = T::file_write_offset(); + unsafe { + let inner = Arc::from_raw((overlapped_ptr as *const u8).sub(offset) as *const T); + assert!(Arc::strong_count(&inner) >= 1, "File has been removed, but still use FileOverlappedWrapper return from add_file function"); + + let new_one = Pin::new_unchecked(Arc::clone(&inner)); + let _ = Arc::into_raw(inner); // Prevent Arc from being dropped + new_one + } + } +} /// A handle to the I/O completion port. pub(super) struct IoCompletionPort { /// The underlying handle. @@ -144,7 +217,7 @@ impl fmt::Debug for IoCompletionPort { } } -impl IoCompletionPort { +impl IoCompletionPort { /// Create a new I/O completion port. pub(super) fn new(threads: usize) -> io::Result { let handle = unsafe { @@ -171,11 +244,13 @@ impl IoCompletionPort { &self, handle: &impl AsRawHandle, // TODO change to AsHandle skip_set_event_on_handle: bool, + kind: CompletionKeyType, ) -> io::Result<()> { let handle = handle.as_raw_handle(); - let result = - unsafe { CreateIoCompletionPort(handle as _, self.handle, handle as usize, 0) }; + let result = unsafe { + CreateIoCompletionPort(handle as _, self.handle, CompletionKey::new(kind).into(), 0) + }; if result.is_null() { return Err(io::Error::last_os_error()); @@ -257,7 +332,7 @@ impl IoCompletionPort { /// An `OVERLAPPED_ENTRY` resulting from an I/O completion port. #[repr(transparent)] -pub(super) struct OverlappedEntry { +pub(super) struct OverlappedEntry { /// The underlying entry. entry: OVERLAPPED_ENTRY, @@ -265,13 +340,13 @@ pub(super) struct OverlappedEntry { _marker: PhantomData, } -impl fmt::Debug for OverlappedEntry { +impl fmt::Debug for OverlappedEntry { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("OverlappedEntry { .. }") } } -impl OverlappedEntry { +impl OverlappedEntry { /// Convert into the completion packet. pub(super) fn into_packet(self) -> T { let packet = unsafe { self.packet() }; @@ -279,6 +354,33 @@ impl OverlappedEntry { packet } + /// Get the number of bytes transferred by the I/O operation. + pub(super) fn bytes_transferred(&self) -> u32 { + self.entry.dwNumberOfBytesTransferred + } + + /// Check if this entry is a file completion packet. + pub(super) fn is_file_completion(&self) -> bool { + CompletionKey::from(self.entry.lpCompletionKey).is_file() + } + + /// Convert into the completion packet through file overlapped pointer which is not the beginning address + /// of the packet. + /// + /// # Safety + /// + /// This function should only be called once, since it moves + /// out the `T` from the `OVERLAPPED_ENTRY`. + pub(super) fn into_file_packet(self) -> (T, FileCompletionStatus) { + assert!( + self.is_file_completion(), + "This is not a file completion packet" + ); + let (packet, status) = unsafe { OverlappedInner::::from_entry(&self.entry) }; + std::mem::forget(self); + (packet, status) + } + /// Get the packet reference that this entry refers to. /// /// # Safety @@ -292,8 +394,110 @@ impl OverlappedEntry { } } -impl Drop for OverlappedEntry { +impl Drop for OverlappedEntry { fn drop(&mut self) { - drop(unsafe { self.packet() }); + // File packet do not need to Arc::Clone to add or remove from the poller + // So we can safely drop it without decrease the reference count. + if !self.is_file_completion() { + drop(unsafe { self.packet() }); + } + } +} + +/// The type of completion key used to differentiate between different types of completion keys. +/// The completion key type determines how to convert raw address to packet block. +/// [`OverlappedEntry::into_packet`]: create::iocp::OverlappedEntry::into_packet +/// [`OverlappedEntry::into_file_packet`]: create::iocp::OverlappedEntry::into_file_packet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum CompletionKeyType { + Socket, + File, +} + +/// This is used to differentiate between different types of completion keys. +/// Completion key has not to be unique per handle for IOCP. The `CompletionKey` is used to +/// identify the type of completion key, it assign one key for per handle. But it does not +/// guarantee uniqueness across different handles when the token is overflowed and wrapped back +/// to low value which may be used by existing handle. +/// It is used to differentiate between different types of completion keys. +#[repr(transparent)] +pub(super) struct CompletionKey(usize); + +static NEXT_DEFAULT_TOKEN: AtomicUsize = AtomicUsize::new(1); // 0 reserved for default iocp packet +static NEXT_FILE_TOKEN: AtomicUsize = AtomicUsize::new(1usize << (usize::BITS - 1)); // Initialize with high bit set + +impl CompletionKey { + const HIGH_BIT: usize = 1usize << (usize::BITS - 1); // 0x8000_0000_0000_0000 on 64-bit + const COUNTER_MASK: usize = !Self::HIGH_BIT; // 0x7FFF_FFFF_FFFF_FFFF on 64-bit + pub(super) fn new(kind: CompletionKeyType) -> Self { + match kind { + CompletionKeyType::File => { + // For file tokens, increment from HIGH_BIT base + // If it would overflow past HIGH_BIT | COUNTER_MASK, wrap back to HIGH_BIT + let token = loop { + let current = NEXT_FILE_TOKEN.load(std::sync::atomic::Ordering::Relaxed); + let next = if current == (Self::HIGH_BIT | Self::COUNTER_MASK) { + Self::HIGH_BIT // Wrap back to HIGH_BIT (first file token) + } else { + current + 1 + }; + + match NEXT_FILE_TOKEN.compare_exchange_weak( + current, + next, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) { + Ok(_) => break current, + Err(_) => continue, // Retry if another thread modified it + } + }; + + Self(token) + } + _ => { + // For default tokens, we need to ensure the counter never exceeds COUNTER_MASK + // If it would overflow, wrap back to 0 + let counter = loop { + let current = NEXT_DEFAULT_TOKEN.load(std::sync::atomic::Ordering::Relaxed); + let next = if current >= Self::COUNTER_MASK { + 1 + } else { + current + 1 + }; + + match NEXT_DEFAULT_TOKEN.compare_exchange_weak( + current, + next, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) { + Ok(_) => break current, + Err(_) => continue, // Retry if another thread modified it + } + }; + + // Keep highest bit as 0, counter is already safe + let token = counter; + Self(token) + } + } + } + + /// Check if this completion key is for a File type + pub(super) fn is_file(&self) -> bool { + (self.0 & Self::HIGH_BIT) != 0 + } +} + +impl From for usize { + fn from(key: CompletionKey) -> Self { + key.0 + } +} + +impl From for CompletionKey { + fn from(token: usize) -> Self { + Self(token) } } diff --git a/src/os/iocp.rs b/src/os/iocp.rs index 3370118..661b826 100644 --- a/src/os/iocp.rs +++ b/src/os/iocp.rs @@ -1,13 +1,21 @@ //! Functionality that is only available for IOCP-based platforms. +use windows_sys::Win32::Foundation as wf; +use windows_sys::Win32::System::IO::{OVERLAPPED, OVERLAPPED_ENTRY}; + +use crate::iocp::ntdll::NtdllImports; +use crate::iocp::FileCompletionStatus; +pub use crate::iocp::FileOverlappedWrapper; pub use crate::sys::CompletionPacket; use super::__private::PollerSealed; use crate::{Event, PollMode, Poller}; +use std::cell::UnsafeCell; use std::io; use std::os::windows::io::{AsRawHandle, RawHandle}; use std::os::windows::prelude::{AsHandle, BorrowedHandle}; +use std::ptr::NonNull; /// Extension trait for the [`Poller`] type that provides functionality specific to IOCP-based /// platforms. @@ -251,3 +259,521 @@ pub trait AsWaitable: AsHandle { } impl AsWaitable for T {} + +/// Overlapped structure owned by the poller and returned to the caller when calling [`add_file`]. +/// The caller must use this structure to get read overlapped ptr or write overlapped ptr as parameter +/// in ReadFile/WriteFile APIs. Otherwise, the behavior is undefined. +/// +/// The overlapped ptr can be safely converted to 'FileOverlappedWrapper' to check result. +/// +/// [`add_file`]: crate::os::iocp::PollerIocpFileExt::add_file +#[derive(Debug, Clone, Copy)] +pub struct IocpFilePacket { + /// read pointer to the overlapped structure + read: NonNull, + /// write pointer to the overlapped structure + write: NonNull, +} + +impl IocpFilePacket { + /// Create a new `IocpFilePacket` with the given `OVERLAPPED` pointer. + pub(crate) fn new(read: *mut OVERLAPPED, write: *mut OVERLAPPED) -> Self { + Self { + read: NonNull::new(read).unwrap(), + write: NonNull::new(write).unwrap(), + } + } + + /// Get the raw read overlapped pointer to the `OVERLAPPED` structure. + pub fn read_ptr(&self) -> *mut OVERLAPPED { + self.read.as_ptr() + } + + /// Get the raw write overlapped pointer to the `OVERLAPPED` structure. + pub fn write_ptr(&self) -> *mut OVERLAPPED { + self.write.as_ptr() + } +} + +/// Extension trait for the [`Poller`] type that provides file specific functionality to IOCP-based +/// platforms. +/// +/// [`Poller`]: crate::Poller +pub trait PollerIocpFileExt: PollerSealed { + /// Add a file handle to this poller. + /// + /// File handle can be used in file read/write operation such as: [`ReadFile`], [`WriteFile`], etc API. + /// Those APIs have LPOVERLAPPED as parameter which will be reutrned in IOCP port polling. + /// + /// [`ReadFile`]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile + /// [`WriteFile`]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-writefile + /// + /// File handle could be used with various types of handles, including: + /// - **File Handles** + /// - Regular files - Created with CreateFile + /// - Directories - For writing directory entries (limited scenarios) + /// - Physical disks and volumes - Raw disk access + /// - **Communication Handles** + /// - Named pipes - Both client and server sides + /// - Anonymous pipes - For inter-process communication + /// - Mailslots - For one-to-many communication + /// - **Device Handles** + /// - Serial ports (COM ports) + /// - Parallel ports (LPT ports) + /// - Console output + /// - Tape drives + /// - CD-ROM/DVD drives (for raw access) + /// - **Network Handles** + /// - Sockets - Though send() is typically preferred for sockets + /// - **Special Handles** + /// - Memory-mapped files - When accessed as file handles + /// - Virtual files - Some virtual file systems + /// + /// The returned [`IocpFilePacket`] provide read/write overlapped pointer used in ReadFile/WriteFile as overlapped pointer parameter. + /// Once the read/write operation is called with this function readed overlapped pointer, the poller will emit the `interest` event + /// when the operation is completed. The overlapped pointer can be safely converted to [`FileOverlappedWrapper`] to check result. + /// + /// File handle work on PollMode::Edge mode. The IOCP continue to poll the events unitl + /// the file is closed. The caller must use the overlapped pointer return in IocpFilePacket + /// as overlapped paramter for I/O operation. The Packet do not need to increase Arc count because + /// the call can trigger events through I/O operation without update intrest events as long as the + /// file handle has been registered with the IOCP. So the Packet lifetime is ended with calling [`remove_file`]. + /// Any I/O operation using returned overlapped pointer in IocpFilePacket is undefined behavior. + /// + /// IocpFilePacket will return both read and write overlapped pointer no matter what intrest events are. + /// The caller need to use the correct overlapped pointer for I/O operation. Such as: the read overlapped + /// pointer can be used for read operations, and the write overlapped pointer can be used for write operations. + /// + /// # Safety + /// + /// The added handle must not be dropped before it is deleted. + /// The returned [`IocpFilePacket`] must not be used after [`remove_file`] is called. + /// + /// [`remove_file`]: crate::os::iocp::PollerIocpFileExt::remove_file + /// + /// # Examples + /// + /// ```no_run + /// use polling::os::iocp::{FileOverlappedWrapper, Overlapped, PollerIocpFileExt}; + /// use polling::{Event, Events, Poller}; + /// use windows_sys::Win32::System::IO::OVERLAPPED; + /// + /// use std::ffi::OsStr; + /// use std::fs::OpenOptions; + /// use std::io; + /// use std::os::windows::ffi::OsStrExt; + /// use std::os::windows::fs::OpenOptionsExt; + /// use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; + /// use std::time::Duration; + /// + /// use windows_sys::Win32::{ + /// Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps, System::IO as wio, + /// }; + /// + /// fn new_named_pipe>(addr: A) -> io::Result { + /// let fname = addr + /// .as_ref() + /// .encode_wide() + /// .chain(Some(0)) + /// .collect::>(); + /// let handle = unsafe { + /// let raw_handle = wps::CreateNamedPipeW( + /// fname.as_ptr(), + /// wfs::PIPE_ACCESS_DUPLEX | wfs::FILE_FLAG_OVERLAPPED, + /// wps::PIPE_TYPE_BYTE | wps::PIPE_READMODE_BYTE | wps::PIPE_WAIT, + /// 1, + /// 4096, + /// 4096, + /// 0, + /// std::ptr::null_mut(), + /// ); + /// + /// if raw_handle == wf::INVALID_HANDLE_VALUE { + /// return Err(io::Error::last_os_error()); + /// } + /// + /// OwnedHandle::from_raw_handle(raw_handle as _) + /// }; + /// + /// Ok(handle) + /// } + /// + /// fn server() -> (OwnedHandle, String) { + /// let num: u64 = fastrand::u64(..); + /// let name = format!(r"\\.\pipe\my-pipe-{}", num); + /// let pipe = new_named_pipe(&name).unwrap(); + /// (pipe, name) + /// } + /// + /// fn client(name: &str) -> io::Result { + /// let mut opts = OpenOptions::new(); + /// opts.read(true) + /// .write(true) + /// .custom_flags(wfs::FILE_FLAG_OVERLAPPED); + /// let file = opts.open(name)?; + /// unsafe { Ok(OwnedHandle::from_raw_handle(file.into_raw_handle())) } + /// } + /// + /// fn pipe() -> (OwnedHandle, OwnedHandle) { + /// let (pipe, name) = server(); + /// (pipe, client(&name).unwrap()) + /// } + /// + /// fn write_then_read() { + /// unsafe { + /// let (server, client) = pipe(); + /// let poller = Poller::new().unwrap(); + /// let mut events = Events::new(); + /// + /// let server_overlapped = unsafe { + /// poller + /// .add_file(&server, Event::new(1, true, false)) + /// .unwrap() + /// }; + /// + /// let client_overlapped = poller.add_file(&client, Event::new(2, true, true)).unwrap(); + /// + /// let mut written = 0u32; + /// let ret = wfs::WriteFile( + /// client.as_raw_handle(), + /// b"1234" as *const u8, + /// 4, + /// (&mut written) as *mut u32, + /// client_overlapped.write_ptr(), + /// ); + /// + /// assert!(ret == wf::TRUE && written == 4); + /// + /// loop { + /// poller.wait(&mut events, None).unwrap(); + /// let events = events.iter().collect::>(); + /// if let Some(event) = events.iter().find(|e| e.key == 2) { + /// if event.writable { + /// break; + /// } + /// } + /// } + /// + /// events.clear(); + /// let mut buf = [0u8; 10]; + /// + /// let mut read = 0u32; + /// let ret = wfs::ReadFile( + /// server.as_raw_handle(), + /// &mut buf as *mut u8, + /// 10, + /// (&mut read) as *mut u32, + /// server_overlapped.read_ptr(), + /// ); + /// + /// let event_len = poller + /// .wait(&mut events, Some(Duration::from_millis(10))) + /// .unwrap(); + /// assert_eq!(event_len, 1); + /// + /// let events = events.iter().collect::>(); + /// events.iter().for_each(|e| { + /// if e.key == 2 { + /// assert_eq!(e.writable, true); + /// } + /// }); + /// + /// assert!(ret == wf::TRUE && read == 4); + /// assert_eq!(&buf[..4], b"1234"); + /// + /// poller.remove_file(&server).unwrap(); + /// poller.remove_file(&client).unwrap(); + /// drop(server); + /// drop(client); + /// } + /// } + /// ``` + unsafe fn add_file(&self, file: &impl AsRawHandle, event: Event) -> io::Result; + + /// Remove a file handle from this poller. + /// + /// This function can be used to remove a file handle from the poller. The handle must + /// have been previously added to the poller using [`add_file`]. + /// + /// [`add_file`]: Self::add_file + /// + /// # Examples + /// + /// ```no_run + /// use polling::os::iocp::{FileOverlappedWrapper, Overlapped, PollerIocpFileExt}; + /// use polling::{Event, Events, Poller}; + /// use windows_sys::Win32::System::IO::OVERLAPPED; + /// + /// use std::ffi::OsStr; + /// use std::fs::OpenOptions; + /// use std::io; + /// use std::os::windows::ffi::OsStrExt; + /// use std::os::windows::fs::OpenOptionsExt; + /// use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; + /// use std::time::Duration; + /// + /// use windows_sys::Win32::{ + /// Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps, System::IO as wio, + /// }; + /// + /// // Create a poller. + /// let poller = Poller::new().unwrap(); + /// let mut events = Events::new(); + /// println!("Create a temp file"); + /// // Open a file for writing. + /// let dir = tempfile::tempdir().unwrap(); + /// let file_path = dir.path().join("test.txt"); + /// let fname = file_path + /// .as_os_str() + /// .encode_wide() + /// .chain(Some(0)) + /// .collect::>(); + /// let file_handle = unsafe { + /// let raw_handle = wfs::CreateFileW( + /// fname.as_ptr(), + /// wf::GENERIC_WRITE | wf::GENERIC_READ, + /// 0, + /// std::ptr::null_mut(), + /// wfs::CREATE_ALWAYS, + /// wfs::FILE_FLAG_OVERLAPPED, + /// std::ptr::null_mut(), + /// ); + /// if raw_handle == wf::INVALID_HANDLE_VALUE { + /// panic!("CreateFileW failed: {}", io::Error::last_os_error()); + /// } + /// OwnedHandle::from_raw_handle(raw_handle as _) + /// }; + /// println!("file handle: {:?}", file_handle); + /// let overlapped = unsafe { + /// poller + /// .add_file(&file_handle, Event::new(1, true, true)) + /// .unwrap() + /// }; + /// + /// // Repeatedly write to the pipe. + /// let input_text = "Now is the time for all good men to come to the aid of their party"; + /// let mut len = input_text.len(); + /// while len > 0 { + /// // Begin the write. + /// let ptr = overlapped.write_ptr(); + /// unsafe { + /// let ret = wfs::WriteFile( + /// file_handle.as_raw_handle() as _, + /// input_text.as_ptr() as _, + /// len as _, + /// std::ptr::null_mut(), + /// ptr, + /// ); + /// println!("WriteFile returned: {}, len: {}, ptr: {:p}", ret, len, ptr); + /// if ret == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING { + /// panic!("WriteFile failed: {}", io::Error::last_os_error()); + /// } + /// } + /// // Wait for the overlapped operation to complete. + /// 'waiter: loop { + /// events.clear(); + /// println!("Starting wait..."); + /// poller.wait(&mut events, None).unwrap(); + /// println!("Got events"); + /// for event in events.iter() { + /// if event.writable && event.key == 1 { + /// break 'waiter; + /// } + /// } + /// } + /// // Decrement the length by the number of bytes written. + /// let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; + /// wrapper.get_result().map_or_else( + /// |e| { + /// match e.kind() { + /// io::ErrorKind::WouldBlock => { + /// // The operation is still pending, we can ignore this error. + /// println!("WriteFile is still pending, continuing..."); + /// } + /// _ => panic!("WriteFile failed: {}", e), + /// } + /// }, + /// |ret| { + /// if (!ret) { + /// println!("The file handle maybe closed"); + /// } + /// else { + /// let bytes_written = wrapper.get_bytes_transferred(); + /// println!("Bytes written: {}", bytes_written); + /// len -= bytes_written as usize; + /// } + /// }, + /// ); + /// } + /// poller.remove_file(&file_handle).unwrap(); + /// ``` + fn remove_file(&self, file: &impl AsRawHandle) -> io::Result<()>; +} + +impl PollerIocpFileExt for Poller { + unsafe fn add_file(&self, file: &impl AsRawHandle, event: Event) -> io::Result { + self.poller.add_file(file.as_raw_handle(), event) + } + + fn remove_file(&self, file: &impl AsRawHandle) -> io::Result<()> { + self.poller.remove_file(file.as_raw_handle()) + } +} + +/// Overlapped structure is part data block of [`IoStatusBlock`] owned by poller. +/// It is same as Overlapped, but used by poller internal to update status of I/O operations. +/// +/// [`IoStatusBlock`]: crate::iocp::IoStatusBlock +#[repr(C)] +pub(crate) struct OverlappedInner { + /// OVERLAPPED structure used for I/O operation + inner: UnsafeCell, + /// bytes transferred when iocp event happens. + bytes_transferred: u32, + /// Callback function used to covert to the whole [`IoStatusBlock`] block + callback: unsafe fn(&OVERLAPPED_ENTRY) -> (T, FileCompletionStatus), +} + +impl OverlappedInner { + pub(crate) fn new(callback: unsafe fn(&OVERLAPPED_ENTRY) -> (T, FileCompletionStatus)) -> Self { + Self { + inner: UnsafeCell::new(OVERLAPPED::default()), + bytes_transferred: 0, + callback, + } + } + + /// Convert from OVERLAPPED_ENTRY.lpOverlapped back to OverlappedInner + /// + /// # Safety + /// + /// The overlapped_ptr must point to the `inner` field of a valid OverlappedInner instance + pub(crate) unsafe fn from_overlapped_ptr(overlapped_ptr: *mut OVERLAPPED) -> *mut Self { + // Calculate offset of 'inner' field within OverlappedInner + let offset = std::mem::offset_of!(OverlappedInner, inner); + + // Get pointer to the containing Overlapped struct + (overlapped_ptr as *mut u8).sub(offset) as *mut OverlappedInner + } + + /// Convert and call the callback + /// + /// # Safety + /// + /// The entry.lpOverlapped must point to the `inner` field of a valid OverlappedInner instance + pub(crate) unsafe fn from_entry(entry: &OVERLAPPED_ENTRY) -> (T, FileCompletionStatus) { + let overlapped_ptr = Self::from_overlapped_ptr(entry.lpOverlapped); + let overlapped_ref = &*overlapped_ptr; + (overlapped_ref.callback)(entry) + } + + /// Get a raw pointer to the OVERLAPPED structure + pub(crate) fn as_ptr(&self) -> *mut OVERLAPPED { + self.inner.get() + } + + /// Set the number of bytes transferred by the I/O operation + pub(crate) fn set_bytes_transferred(&mut self, bytes: u32) { + self.bytes_transferred = bytes; + } +} + +impl Drop for OverlappedInner { + fn drop(&mut self) { + // Safety: The OVERLAPPED structure belongs to the Packet. It will be released with Packet + } +} + +/// [`IocpFilePacket`] read/write pointer can safety convert to this structure to check I/O operation +/// results. [`FileOverlappedWrapper`] is alise for access convinence. +/// # Examples +/// +/// ```no_run +/// use polling::os::iocp::{FileOverlappedWrapper, IocpFilePacket}; +/// use std::io; +/// +/// # fn example(overlapped: IocpFilePacket, mut len: usize) { +/// let ptr = overlapped.write_ptr(); +/// let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; +/// println!("bytes transferred: {}", wrapper.get_bytes_transferred()); +/// wrapper.get_result().map_or_else( +/// |e| { +/// match e.kind() { +/// io::ErrorKind::WouldBlock => { +/// // The operation is still pending, we can ignore this error. +/// println!("WriteFile is still pending, continuing..."); +/// } +/// _ => panic!("WriteFile failed: {}", e), +/// } +/// }, +/// |ret| { +/// if (!ret) { +/// println!("The file handle maybe closed"); +/// } +/// else { +/// let bytes_written = wrapper.get_bytes_transferred(); +/// println!("Bytes written: {}", bytes_written); +/// len -= bytes_written as usize; +/// } +/// }, +/// ); +/// # } +/// ``` +/// [`FileOverlappedWrapper`]: crate::iocp::FileOverlappedWrapper +#[derive(Debug)] +#[repr(C)] +pub struct Overlapped { + inner: UnsafeCell, + bytes_transferred: u32, + callback: unsafe fn(&OVERLAPPED_ENTRY) -> (T, FileCompletionStatus), +} + +impl Overlapped { + /// Convert from OVERLAPPED_ENTRY.lpOverlapped back to Overlapped + /// + /// # Safety + /// + /// The overlapped_ptr must point to the `inner` field of a valid Overlapped instance + pub unsafe fn from_overlapped_ptr(overlapped_ptr: *mut OVERLAPPED) -> *mut Self { + // Calculate offset of 'inner' field within Overlapped + let offset = std::mem::offset_of!(Overlapped, inner); + + // Get pointer to the containing Overlapped struct + (overlapped_ptr as *mut u8).sub(offset) as *mut Overlapped + } + + /// Get number of bytes transferred by the I/O operation + pub fn get_bytes_transferred(&self) -> u32 { + self.bytes_transferred + } + + /// Get the result of the I/O operation. It returns: + /// - Ok(true) if the operation was successful + /// - Ok(false) if there was no data which may means the handle has been closed + /// - Err(io::ErrorKind::WouldBlock) if the operation is still pending + /// - Err(io::Error) for any other error + pub fn get_result(&self) -> io::Result { + let nt_status = unsafe { (*self.inner.get()).Internal }; + let ntdll = NtdllImports::get()?; + let os_error_code = unsafe { ntdll.RtlNtStatusToDosError(nt_status as _) }; + + match os_error_code { + wf::ERROR_SUCCESS => Ok(true), + wf::ERROR_NO_DATA => Ok(false), + wf::ERROR_IO_PENDING => Err(io::ErrorKind::WouldBlock.into()), + error => Err(io::Error::from_raw_os_error(error as _)), + } + } + + /// Clear the state of the I/O operation before take the next I/O operation + pub fn zeroed(&mut self) { + *self.inner.get_mut() = OVERLAPPED::default(); + self.bytes_transferred = 0; + } +} + +impl Drop for Overlapped { + fn drop(&mut self) { + // Safety: The type is not the owner of the struct which is owned by the Packet + } +} diff --git a/tests/windows_overlapped.rs b/tests/windows_overlapped.rs new file mode 100644 index 0000000..945b42f --- /dev/null +++ b/tests/windows_overlapped.rs @@ -0,0 +1,678 @@ +//! Take advantage of overlapped I/O on Windows using CompletionPacket. + +#![cfg(windows)] + +use polling::os::iocp::{FileOverlappedWrapper, PollerIocpFileExt}; +use polling::{Event, Events, Poller}; +use windows_sys::Win32::System::IO::OVERLAPPED; + +use std::ffi::OsStr; +use std::fs::OpenOptions; +use std::io; +use std::os::windows::ffi::OsStrExt; +use std::os::windows::fs::OpenOptionsExt; +use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; +use std::time::Duration; + +use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps}; + +#[test] +fn win32_file_io() { + // Create a poller. + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + println!("Create a temp file"); + // Open a file for writing. + let dir = tempfile::tempdir().unwrap(); + let file_path = dir.path().join("test.txt"); + let fname = file_path + .as_os_str() + .encode_wide() + .chain(Some(0)) + .collect::>(); + let file_handle = unsafe { + let raw_handle = wfs::CreateFileW( + fname.as_ptr(), + wf::GENERIC_WRITE | wf::GENERIC_READ, + 0, + std::ptr::null_mut(), + wfs::CREATE_ALWAYS, + wfs::FILE_FLAG_OVERLAPPED, + std::ptr::null_mut(), + ); + + if raw_handle == wf::INVALID_HANDLE_VALUE { + panic!("CreateFileW failed: {}", io::Error::last_os_error()); + } + + OwnedHandle::from_raw_handle(raw_handle as _) + }; + + println!("file handle: {:?}", file_handle); + let overlapped = unsafe { + poller + .add_file(&file_handle, Event::new(1, true, true)) + .unwrap() + }; + + // Repeatedly write to the pipe. + let input_text = "Now is the time for all good men to come to the aid of their party"; + let mut len = input_text.len(); + + while len > 0 { + // Begin the write. + let ptr = overlapped.write_ptr(); + unsafe { + let ret = wfs::WriteFile( + file_handle.as_raw_handle() as _, + input_text.as_ptr() as _, + len as _, + std::ptr::null_mut(), + ptr, + ); + println!("WriteFile returned: {}, len: {}, ptr: {:p}", ret, len, ptr); + if ret == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING { + // Only panic if not running under Wine + if std::env::var("WINELOADER").is_ok() + || std::env::var("WINE").is_ok() + || std::env::var("WINEPREFIX").is_ok() + { + println!("Skipping test under Wine"); + return; + } else { + panic!("WriteFile failed: {}", io::Error::last_os_error()); + } + } + } + + // Wait for the overlapped operation to complete. + 'waiter: loop { + events.clear(); + println!("Starting wait..."); + poller.wait(&mut events, None).unwrap(); + println!("Got events"); + + for event in events.iter() { + if event.writable && event.key == 1 { + break 'waiter; + } + } + } + + // Decrement the length by the number of bytes written. + let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; + wrapper.get_result().map_or_else( + |e| { + match e.kind() { + io::ErrorKind::WouldBlock => { + // The operation is still pending, we can ignore this error. + println!("WriteFile is still pending, continuing..."); + } + _ => panic!("WriteFile failed: {}", e), + } + }, + |ret| { + if !ret { + println!("The file handle maybe closed"); + } else { + let bytes_written = wrapper.get_bytes_transferred(); + println!("Bytes written: {}", bytes_written); + len -= bytes_written as usize; + } + }, + ); + } + + poller.remove_file(&file_handle).unwrap(); + // Close the file and re-open it for reading. + drop(file_handle); + println!("file handle dropped"); + + let file_handle = unsafe { + let raw_handle = wfs::CreateFileW( + fname.as_ptr(), + wf::GENERIC_READ | wf::GENERIC_WRITE, + 0, + std::ptr::null_mut(), + wfs::OPEN_EXISTING, + wfs::FILE_FLAG_OVERLAPPED, + std::ptr::null_mut(), + ); + + if raw_handle == wf::INVALID_HANDLE_VALUE { + panic!("CreateFileW failed: {}", io::Error::last_os_error()); + } + + OwnedHandle::from_raw_handle(raw_handle as _) + }; + + println!("file handle: {:?}", file_handle); + let overlapped = unsafe { + poller + .add_file(&file_handle, Event::new(1, true, true)) + .unwrap() + }; + + // Repeatedly read from the pipe. + let mut buffer = vec![0u8; 1024]; + let mut buffer_cursor = &mut *buffer; + let mut len = 1024; + let mut bytes_received = 0; + + while bytes_received < input_text.len() { + // Begin the read. + let ptr = overlapped.read_ptr(); + unsafe { + if wfs::ReadFile( + file_handle.as_raw_handle() as _, + buffer_cursor.as_mut_ptr() as _, + len as _, + std::ptr::null_mut(), + ptr, + ) == 0 + && wf::GetLastError() != wf::ERROR_IO_PENDING + { + panic!("ReadFile failed: {}", io::Error::last_os_error()); + } + } + + // Wait for the overlapped operation to complete. + 'waiter: loop { + events.clear(); + poller.wait(&mut events, None).unwrap(); + + for event in events.iter() { + if event.readable && event.key == 1 { + break 'waiter; + } + } + } + + // Increment the cursor and decrement the length by the number of bytes read. + let bytes_read = input_text.len(); + buffer_cursor = &mut buffer_cursor[bytes_read..]; + len -= bytes_read; + bytes_received += bytes_read; + } + + assert_eq!(bytes_received, input_text.len()); + assert_eq!(&buffer[..bytes_received], input_text.as_bytes()); +} + +fn new_named_pipe>(addr: A) -> io::Result { + let fname = addr + .as_ref() + .encode_wide() + .chain(Some(0)) + .collect::>(); + let handle = unsafe { + let raw_handle = wps::CreateNamedPipeW( + fname.as_ptr(), + wfs::PIPE_ACCESS_DUPLEX | wfs::FILE_FLAG_OVERLAPPED, + wps::PIPE_TYPE_BYTE | wps::PIPE_READMODE_BYTE | wps::PIPE_WAIT, + 1, + 4096, + 4096, + 0, + std::ptr::null_mut(), + ); + + if raw_handle == wf::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + OwnedHandle::from_raw_handle(raw_handle as _) + }; + + Ok(handle) +} + +unsafe fn connect_named_pipe( + handle: &impl AsRawHandle, + overlapped: *mut OVERLAPPED, +) -> io::Result<()> { + if wps::ConnectNamedPipe(handle.as_raw_handle() as _, overlapped) != 0 { + // If ConnectNamedPipe returns non-zero, the connection was successful. + return Ok(()); + } + + let err = io::Error::last_os_error(); + + match err.raw_os_error().map(|e| e as u32) { + Some(wf::ERROR_PIPE_CONNECTED) => Ok(()), + Some(wf::ERROR_NO_DATA) => Err(io::ErrorKind::WouldBlock.into()), + Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()), + _ => Err(err), + } +} + +fn server() -> (OwnedHandle, String) { + let num: u64 = fastrand::u64(..); + let name = format!(r"\\.\pipe\my-pipe-{}", num); + let pipe = new_named_pipe(&name).unwrap(); + (pipe, name) +} + +fn client(name: &str) -> io::Result { + let mut opts = OpenOptions::new(); + opts.read(true) + .write(true) + .custom_flags(wfs::FILE_FLAG_OVERLAPPED); + let file = opts.open(name)?; + unsafe { Ok(OwnedHandle::from_raw_handle(file.into_raw_handle())) } +} + +fn pipe() -> (OwnedHandle, OwnedHandle) { + let (pipe, name) = server(); + (pipe, client(&name).unwrap()) +} + +// Test client create success if server create named pipe first. +// Client can write data to pipe without server call ConnectNamedPipe first. +// Client return NotFound error if clinet create before server create named pipe. +// Poller will not receive event if client and server create before register file. +// Poller will also not receive event if server create and add to poller before client create named pipe. +#[test] +fn writable_after_register() { + { + let name = format!(r"\\.\pipe\my-pipe-{}", fastrand::u64(..)); + let client = client(&name); + assert_eq!(client.err().unwrap().kind(), io::ErrorKind::NotFound); + + let (server, client) = pipe(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let _server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + let _client_overlapped = unsafe { + poller + .add_file(&client, Event::new(2, false, true)) + .unwrap() + }; + + poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert!(events.is_empty()); + + poller.remove_file(&server).unwrap(); + poller.remove_file(&client).unwrap(); + drop(server); + drop(client); + } + + // Poller will receive event if server add to poller before client create file + let (server, name) = server(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let _server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + let client = client(&name); + poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + + assert!(events.is_empty()); + + poller.remove_file(&server).unwrap(); + drop(server); + drop(client); +} + +// Client can write data to pipe without server call ConnectNamedPipe first +// if server create named pipe first. Poller will receive write event when client +// write data to pipe. The Polling mode is EDGE, the write event will be cleared. +#[test] +fn write_then_read() { + let (server, client) = pipe(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + let client_overlapped = unsafe { poller.add_file(&client, Event::new(2, true, true)).unwrap() }; + + unsafe { + let mut written = 0u32; + let ret = wfs::WriteFile( + client.as_raw_handle(), + b"1234" as *const u8, + 4, + (&mut written) as *mut u32, + client_overlapped.write_ptr(), + ); + + assert!(ret == wf::TRUE && written == 4); + + loop { + poller.wait(&mut events, None).unwrap(); + let events = events.iter().collect::>(); + if let Some(event) = events.iter().find(|e| e.key == 2) { + if event.writable { + break; + } + } + } + + events.clear(); + let mut buf = [0u8; 10]; + + let mut read = 0u32; + let ret = wfs::ReadFile( + server.as_raw_handle(), + &mut buf as *mut u8, + 10, + (&mut read) as *mut u32, + server_overlapped.read_ptr(), + ); + + let event_len = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(event_len, 1); + + let events = events.iter().collect::>(); + events.iter().for_each(|e| { + if e.key == 2 { + assert!(e.writable); + } + }); + + assert!(ret == wf::TRUE && read == 4); + assert_eq!(&buf[..4], b"1234"); + } + + poller.remove_file(&server).unwrap(); + poller.remove_file(&client).unwrap(); + drop(server); + drop(client); +} + +// Poller will receive read event if server call ConnectNamedPipe after add to poller before +// client create named pipe. +#[test] +fn connect_before_client() { + let (server, name) = server(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + assert_eq!(events.iter().count(), 0); + + unsafe { + let ret = connect_named_pipe(&server, server_overlapped.read_ptr()); + assert_eq!(ret.err().unwrap().kind(), io::ErrorKind::WouldBlock); + + let client = client(&name).unwrap(); + let _client_overlapped = poller.add_file(&client, Event::new(2, true, true)).unwrap(); + + loop { + let event_num = poller.wait(&mut events, None).unwrap(); + assert_eq!(event_num, 1); + let e = events.iter().collect::>(); + events.clear(); + if let Some(event) = e.iter().find(|e| e.key == 1) { + if event.readable { + let overlapped_wrapper = + &*FileOverlappedWrapper::from_overlapped_ptr(server_overlapped.read_ptr()); + assert_eq!(overlapped_wrapper.get_bytes_transferred(), 0); + assert!(overlapped_wrapper.get_result().is_ok()); + break; + } + } + } + + poller.remove_file(&server).unwrap(); + poller.remove_file(&client).unwrap(); + drop(server); + drop(client); + } +} + +// Server can not write data to pipe after client disconnected and return ERROR_NO_DATA error. +// Poller will not receive write event if server try to write data to pipe after client disconnected +#[test] +fn write_disconnected() { + let (server, client) = pipe(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + let _client_overlapped = unsafe { + poller + .add_file(&client, Event::new(2, false, true)) + .unwrap() + }; + + drop(client); + + poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert!(events.iter().count() == 0); + + unsafe { + let mut written = 0u32; + let ret = wfs::WriteFile( + server.as_raw_handle(), + b"1234" as *const u8, + 1, + (&mut written) as *mut u32, + server_overlapped.write_ptr(), + ); + + let e = io::Error::last_os_error(); + + assert_eq!(ret, wf::FALSE); + assert_eq!(written, 0); + assert_eq!(e.raw_os_error(), Some(wf::ERROR_NO_DATA as i32)); + + // according testing, it return ERROR_NO_DATA. the server cannot write even one byte + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(num_event, 0); + } +} + +// Poller will receive write event if client write data to pipe before drop. +// Server can read the data written by client after client drop and Poller +// can receive the read event of server. +#[test] +fn write_then_drop() { + let (server, client) = pipe(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + let client_overlapped = unsafe { + poller + .add_file(&client, Event::new(2, false, true)) + .unwrap() + }; + + unsafe { + let mut written = 0u32; + let ret = wfs::WriteFile( + client.as_raw_handle(), + b"1234" as *const u8, + 4, + (&mut written) as *mut u32, + client_overlapped.write_ptr(), + ); + + assert!(ret == wf::TRUE && written == 4); + } + + drop(client); + + // Poller will receive write event if client write data to pipe before drop. + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + + assert_eq!(num_event, 1); + + unsafe { + let events = events.iter().collect::>(); + assert_eq!(events[0].key, 2); + assert!(events[0].writable); + assert!(!events[0].readable); + let overlapped_wrapper = + &*FileOverlappedWrapper::from_overlapped_ptr(client_overlapped.write_ptr()); + assert_eq!(overlapped_wrapper.get_bytes_transferred(), 4); + assert!(overlapped_wrapper.get_result().unwrap()); + } + + events.clear(); + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + + assert_eq!(num_event, 0); + + unsafe { + let mut buf = [0u8; 10]; + + let mut read = 0u32; + let ret = wfs::ReadFile( + server.as_raw_handle(), + &mut buf as *mut u8, + 10, + (&mut read) as *mut u32, + server_overlapped.read_ptr(), + ); + + assert_eq!(ret, wf::TRUE); + assert_eq!(read, 4); + + // Still receive read event even ReadFile return true. + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(num_event, 1); + assert_eq!(&buf[..4], b"1234"); + } + + drop(server); +} + +// Server can not be connected by the second client. +// Server return error when ReadFile with ERROR_BROKEN_PIPE which client has been closed. +#[test] +fn connect_twice() { + unsafe { + let (server, name) = server(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = poller + .add_file(&server, Event::new(1, true, false)) + .unwrap(); + + poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + assert_eq!(events.iter().count(), 0); + + let ret = connect_named_pipe(&server, server_overlapped.read_ptr()); + assert_eq!(ret.err().unwrap().kind(), io::ErrorKind::WouldBlock); + + let c1 = client(&name).unwrap(); + let _c1_overlapped = poller.add_file(&c1, Event::new(2, true, true)).unwrap(); + drop(c1); + + poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + let ret_events = events.iter().collect::>(); + assert_eq!(ret_events.len(), 1); + assert_eq!(ret_events[0].key, 1); + assert!(ret_events[0].readable); + + events.clear(); + + let mut buf = [0u8; 10]; + + let mut read = 0u32; + // Can not read, should close server pipe. + let ret = wfs::ReadFile( + server.as_raw_handle(), + &mut buf as *mut u8, + 10, + (&mut read) as *mut u32, + server_overlapped.read_ptr(), + ); + + let e = io::Error::last_os_error(); + + assert_eq!(ret, wf::FALSE); + assert_eq!(read, 0); + assert_eq!(e.raw_os_error(), Some(wf::ERROR_BROKEN_PIPE as i32)); + + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(num_event, 0); + + let c2 = client(&name); + assert_eq!( + c2.err().unwrap().raw_os_error(), + Some(wf::ERROR_PIPE_BUSY as i32) + ); + } +} + +#[test] +fn remove_file_before_add_file() { + let (server, _) = server(); + let poller = Poller::new().unwrap(); + + assert_eq!( + poller.remove_file(&server).unwrap_err().kind(), + io::ErrorKind::NotFound, + ); +} + +#[test] +fn add_file_different_poll() { + let (server, _) = server(); + let poller1 = Poller::new().unwrap(); + let poller2 = Poller::new().unwrap(); + + unsafe { + let _ = poller1 + .add_file(&server, Event::new(1, true, true)) + .unwrap(); + + let ret = poller2.add_file(&server, Event::new(2, true, true)); + assert!(ret.is_err()); + } +} From f7d8dade2c0660496932acdc2da9ef304644d58a Mon Sep 17 00:00:00 2001 From: Lu Kai Date: Fri, 19 Sep 2025 21:43:39 +0800 Subject: [PATCH 2/2] Solve FilePacket lifetime maybe longer than poller issue --- Cargo.toml | 2 +- src/iocp/afd.rs | 5 +- src/iocp/mod.rs | 272 ++++++++++++---- src/iocp/port.rs | 74 +++-- src/os/iocp.rs | 602 +++++++++++++++++++++++------------- tests/windows_overlapped.rs | 474 ++++++++++++++-------------- 6 files changed, 891 insertions(+), 538 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c2611ab..496f83e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ features = [ "Win32_System_LibraryLoader", "Win32_System_Threading", "Win32_System_WindowsProgramming", + "Win32_System_Pipes", ] [target.'cfg(target_os = "hermit")'.dependencies.hermit-abi] @@ -71,4 +72,3 @@ signal-hook = "0.3.17" [target.'cfg(windows)'.dev-dependencies] tempfile = "3.7" -windows-sys = { version = "0.60", features = ["Win32_System_Pipes"] } diff --git a/src/iocp/afd.rs b/src/iocp/afd.rs index 9096e3b..93a01e1 100644 --- a/src/iocp/afd.rs +++ b/src/iocp/afd.rs @@ -468,13 +468,14 @@ unsafe impl Completion for IoStatusBlock { } impl FileOverlapped for IoStatusBlock { + #[inline] fn file_read_offset() -> usize { T::file_read_offset() + std::mem::offset_of!(IoStatusBlock, data) } + #[inline] fn file_write_offset() -> usize { - let data_offset = std::mem::offset_of!(IoStatusBlock, data); - T::file_write_offset() + data_offset + T::file_write_offset() + std::mem::offset_of!(IoStatusBlock, data) } } diff --git a/src/iocp/mod.rs b/src/iocp/mod.rs index a579ac3..10eef14 100644 --- a/src/iocp/mod.rs +++ b/src/iocp/mod.rs @@ -50,7 +50,7 @@ use std::cell::UnsafeCell; use std::collections::hash_map::{Entry, HashMap}; use std::ffi::c_void; use std::marker::PhantomPinned; -use std::mem::{forget, MaybeUninit}; +use std::mem::{forget, ManuallyDrop, MaybeUninit}; use std::os::windows::io::{ AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, RawHandle, RawSocket, }; @@ -94,7 +94,7 @@ pub(super) struct Poller { /// The state of the waitable handles registered with this poller. waitables: RwLock>, - /// The state of the waitable handles registered with this poller. + /// The state of the overlapped files registered with this poller. files: RwLock>, /// Sockets with pending updates. @@ -435,16 +435,31 @@ impl Poller { } /// Add a file to the poller. + /// /// File handle work on PollMode::Edge mode. The IOCP continue to poll the events unitl /// the file is closed. The caller must use the overlapped pointer return in IocpFilePacket - /// as overlapped paramter for I/O operation. The Packet do not need to increase Arc count because - /// the call can trigger events through I/O operation without update intrest events as long as the - /// file handle has been registered with the IOCP. So the Packet lifetime is ended with calling [`remove_file`]. - /// Any I/O operation using return overlapped pointer return in IocpFilePacket is undefined behavior. + /// as overlapped paramter for I/O operation. The Packet need to increase Arc count every time the I/O operation + /// is performed success (return TRUE or FALSE with ERROR_IO_PENDING in last error), otherwise the Arc count do + /// not need to increase if I/O operation fail because [`IocpFilePacket`] can exist after the poller is dropped. + /// And I/O operation still be valid with the overlapped pointer after the poller is dropped. [`FileOverlappedConverter`] + /// can help to manage the Arc count to avoid memory leak. + /// + /// Normally, the caller use I/O helper function like [`read_file_overlapped`], [`write_file_overlapped`] or + /// [`connect_named_pipe_overlapped`] to perform I/O operation to avoid the complexity of managing the Arc count. + /// + /// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped + /// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped + /// [`connect_named_pipe_overlapped`]: crate::os::iocp::connect_named_pipe_overlapped + /// + /// The call can trigger events through I/O operation without update intrest events as long as the + /// file handle has been registered with the IOCP. The Packet lifetime is ended with conditions: [`remove_file`] + /// is called, I/O operation is polled, and [`IocpFilePacket`] is dropped. + /// + /// IocpFilePacket will return both read and write overlapped pointer through [`FileOverlappedConverter::as_ptr()`] + /// no matter what intrest events are. /// - /// IocpFilePacket will return both read and write overlapped pointer no matter what intrest events are. - /// The caller need to use the correct overlapped pointer for I/O operation. Such as: the read overlapped - /// pointer can be used for read operations, and the write overlapped pointer can be used for write operations. + /// The caller need to use the correct overlapped converter for I/O operation. Such as: the read overlapped + /// converter can be used for read operations, and the write overlapped converter can be used for write operations. pub(super) fn add_file( &self, handle: RawHandle, @@ -485,24 +500,57 @@ impl Poller { } } - let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() { - PacketInnerProj::File { - read, - write, - handle, - } => (read.get(), write.get(), handle), - _ => unreachable!("PacketInner should always be File here"), - }; + let read_ptr; + let write_ptr; + { + let (read, write, file_handle) = match handle_state.as_ref().data().project_ref() { + PacketInnerProj::File { + read, + write, + handle, + } => (read.get(), write.get(), handle), + _ => unreachable!("PacketInner should always be File here"), + }; - let file_state = lock!(file_handle.lock()); - // Register the file handle with the I/O completion port. - self.port - .register(&*file_state, true, port::CompletionKeyType::File)?; + let file_state = lock!(file_handle.lock()); + // Register the file handle with the I/O completion port. + self.port + .register(&*file_state, true, port::CompletionKeyType::File)?; + + read_ptr = unsafe { (*read).as_ptr() }; + write_ptr = unsafe { (*write).as_ptr() }; + } - let iocp_packet = unsafe { IocpFilePacket::new((*read).as_ptr(), (*write).as_ptr()) }; + let iocp_packet = + unsafe { IocpFilePacket::new(read_ptr, write_ptr, PacketWrapper(handle_state)) }; Ok(iocp_packet) } + pub(super) fn modify_file(&self, handle: RawHandle, interest: Event) -> io::Result<()> { + #[cfg(feature = "tracing")] + tracing::trace!( + "modify_file: handle={:?}, file={:p}, ev={:?}", + self.port, + handle, + interest + ); + + // Get a reference to the source. + let source = { + let sources = lock!(self.files.read()); + + sources + .get(&handle) + .cloned() + .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))? + }; + + // Set the new event. + source.as_ref().set_events(interest, PollMode::Edge); + + Ok(()) + } + /// Remove a file from the poller. pub(super) fn remove_file(&self, handle: RawHandle) -> io::Result<()> { #[cfg(feature = "tracing")] @@ -835,7 +883,8 @@ impl CompletionPacket { /// It needs to be pinned, since it contains data that is expected by IOCP not to be moved. type Packet = Pin>; type PacketUnwrapped = IoStatusBlock; -/// A wrapper around the Overlapped structure for file I/O operation result + +/// A wrapper around the `Overlapped` structure for file I/O operation result #[derive(Debug)] #[repr(transparent)] pub struct FileOverlappedWrapper(Overlapped); @@ -867,6 +916,110 @@ impl FileOverlappedWrapper { } } +/// The converter is used to safely reference count the Packet owned by the poller +/// when overlapped I/O operation is called successfully (the operation return TRUE or ERROR_IO_PENDING). +/// +/// If the I/O operation return FALSE with last error not ERROR_IO_PENDING, the caller must call +/// [`reclaim`] to reclaim the Packet reference count. Otherwise the Packet will be leaked. +/// +/// Normally the caller should use helper function [`read_file_overlapped`] or [`write_file_overlapped`] +/// to do the I/O operation. The helper function will call `reclaim` automatically when I/O operation failed. +/// +/// [`reclaim`]: FileOverlappedConverter::reclaim +/// [`read_file_overlapped`]: crate::os::iocp::read_file_overlapped +/// [`write_file_overlapped`]: crate::os::iocp::write_file_overlapped +/// +/// # Examples +/// +/// ```no_run +/// use polling::os::iocp::FileOverlappedConverter; +/// use std::{io, os::windows::io::RawHandle}; +/// use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wsf}; +/// fn read_file( +/// handle: RawHandle, +/// buf: &mut [u8], +/// mut overlapped: FileOverlappedConverter, +/// ) -> io::Result { +/// let mut read = 0u32; +/// // Safety: syscall +/// if unsafe { +/// wsf::ReadFile( +/// handle, +/// buf.as_mut_ptr(), +/// buf.len() as u32, +/// &mut read as *mut _, +/// overlapped +/// .as_ptr() +/// .expect("The overlapped pointer may have been used for I/O operation"), +/// ) +/// } != wf::FALSE +/// { +/// return Ok(read as usize); +/// } +/// +/// let err = io::Error::last_os_error(); +/// let err: io::Result = err +/// .raw_os_error() +/// .map(|e| match (e as u32) { +/// wf::ERROR_IO_PENDING => Err(io::ErrorKind::WouldBlock.into()), +/// _ => Err(err), +/// }) +/// .unwrap(); +/// match err { +/// Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e), +/// Err(e) => { +/// overlapped.reclaim(); // reclaim the Packet reference count +/// Err(e) +/// } +/// _ => unreachable!(), +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct FileOverlappedConverter { + ptr: *mut OVERLAPPED, + owner: Option, + drop: Option>, +} + +impl FileOverlappedConverter { + pub(crate) fn new(ptr: *mut OVERLAPPED, packet: PacketWrapper) -> Self { + Self { + ptr, + owner: Some(packet), + drop: None, + } + } + + /// Get the raw pointer. The caller must ensure the pointer is used for overlapped I/O operation. + pub fn as_ptr(&mut self) -> Option<*mut OVERLAPPED> { + if let Some(packet) = self.owner.take() { + self.drop = Some(ManuallyDrop::new(packet)); + } + Some(self.ptr) + } + + /// Reclaim the Packet reference count when I/O operation failed. + pub fn reclaim(&mut self) { + if let Some(drop) = self.drop.take() { + self.owner = Some(ManuallyDrop::into_inner(drop)); + } + } +} + +#[derive(Debug, Clone)] +#[repr(transparent)] +pub(crate) struct PacketWrapper(Packet); + +impl PacketWrapper { + #[doc(hidden)] + pub fn test_ref_count(&self) -> usize { + // Safety: the object is Arc and will not be moved + let inner = unsafe { &*(&self.0 as *const Packet as *const Arc) }; + Arc::strong_count(inner) + } +} + pin_project! { /// The inner type of the packet. #[project_ref = PacketInnerProj] @@ -1022,6 +1175,13 @@ impl PacketUnwrapped { // Update if there is no ongoing wait. handle.status.is_idle() } + PacketInnerProj::File { handle, .. } => { + let mut handle = lock!(handle.lock()); + + // Set the new interest. + handle.interest = interest; + false + } _ => true, } } @@ -1269,44 +1429,46 @@ impl PacketUnwrapped { status: FileCompletionStatus, bytes_transferred: u32, ) -> io::Result { - let inner = self.as_ref().data().project_ref(); - - let (handle, read, write) = match inner { - PacketInnerProj::File { - handle, - read, - write, - } => (handle, read, write), - _ => unreachable!("Should not be called on a non-file packet"), - }; + let return_value; + { + let inner = self.as_ref().data().project_ref(); + + let (handle, read, write) = match inner { + PacketInnerProj::File { + handle, + read, + write, + } => (handle, read, write), + _ => unreachable!("Should not be called on a non-file packet"), + }; - let file_state = lock!(handle.lock()); - let mut event = Event::none(file_state.interest.key); - if status.is_read() { - unsafe { - (*read.get()).set_bytes_transferred(bytes_transferred); + let file_state = lock!(handle.lock()); + let mut event = Event::none(file_state.interest.key); + if status.is_read() { + unsafe { + (*read.get()).set_bytes_transferred(bytes_transferred); + } + event.readable = true; } - event.readable = true; - } - if status.is_write() { - unsafe { - (*write.get()).set_bytes_transferred(bytes_transferred); + if status.is_write() { + unsafe { + (*write.get()).set_bytes_transferred(bytes_transferred); + } + event.writable = true; } - event.writable = true; - } - event.readable &= file_state.interest.readable; - event.writable &= file_state.interest.writable; - - // If this event doesn't have anything that interests us, don't return or - // update the oneshot state. - let return_value = if event.readable || event.writable { - FeedEventResult::Event(event) - } else { - FeedEventResult::NoEvent - }; + event.readable &= file_state.interest.readable; + event.writable &= file_state.interest.writable; + // If this event doesn't have anything that interests us, don't return or + // update the oneshot state. + return_value = if event.readable || event.writable { + FeedEventResult::Event(event) + } else { + FeedEventResult::NoEvent + }; + } Ok(return_value) } diff --git a/src/iocp/port.rs b/src/iocp/port.rs index 2b489cb..517696b 100644 --- a/src/iocp/port.rs +++ b/src/iocp/port.rs @@ -159,9 +159,9 @@ unsafe impl FileCompletionHandle for Pin> { let inner = Arc::from_raw((overlapped_ptr as *const u8).sub(offset) as *const T); assert!(Arc::strong_count(&inner) >= 1, "File has been removed, but still use FileOverlappedWrapper return from add_file function"); - let new_one = Pin::new_unchecked(Arc::clone(&inner)); - let _ = Arc::into_raw(inner); // Prevent Arc from being dropped - new_one + // Do not need to clone new one for file packet because it will be cloned for every successful I/O operation + // see [`IocpFilePacket::read_overlapped`] + Pin::new_unchecked(inner) } } @@ -172,9 +172,9 @@ unsafe impl FileCompletionHandle for Pin> { let inner = Arc::from_raw((overlapped_ptr as *const u8).sub(offset) as *const T); assert!(Arc::strong_count(&inner) >= 1, "File has been removed, but still use FileOverlappedWrapper return from add_file function"); - let new_one = Pin::new_unchecked(Arc::clone(&inner)); - let _ = Arc::into_raw(inner); // Prevent Arc from being dropped - new_one + // Do not need to clone new one for file packet because it will be cloned for every successful I/O operation + // see [`IocpFilePacket::write_overlapped`] + Pin::new_unchecked(inner) } } } @@ -183,6 +183,9 @@ pub(super) struct IoCompletionPort { /// The underlying handle. handle: HANDLE, + /// The completion key generator. + key_gen: CompletionKeyGenerator, + /// We own the status block. _marker: PhantomData, } @@ -234,6 +237,7 @@ impl IoCompletionPort { } else { Ok(Self { handle, + key_gen: Default::default(), _marker: PhantomData, }) } @@ -249,7 +253,12 @@ impl IoCompletionPort { let handle = handle.as_raw_handle(); let result = unsafe { - CreateIoCompletionPort(handle as _, self.handle, CompletionKey::new(kind).into(), 0) + CreateIoCompletionPort( + handle as _, + self.handle, + CompletionKey::new(kind, &self.key_gen).into(), + 0, + ) }; if result.is_null() { @@ -421,28 +430,40 @@ pub(super) enum CompletionKeyType { /// to low value which may be used by existing handle. /// It is used to differentiate between different types of completion keys. #[repr(transparent)] -pub(super) struct CompletionKey(usize); +struct CompletionKey(usize); -static NEXT_DEFAULT_TOKEN: AtomicUsize = AtomicUsize::new(1); // 0 reserved for default iocp packet -static NEXT_FILE_TOKEN: AtomicUsize = AtomicUsize::new(1usize << (usize::BITS - 1)); // Initialize with high bit set +#[derive(Debug)] +struct CompletionKeyGenerator { + next_default_key: AtomicUsize, + next_file_key: AtomicUsize, +} + +impl Default for CompletionKeyGenerator { + fn default() -> Self { + Self { + next_default_key: AtomicUsize::new(1), // 0 reserved for default iocp packet + next_file_key: AtomicUsize::new(1usize << (usize::BITS - 1)), // Initialize with high bit set + } + } +} impl CompletionKey { const HIGH_BIT: usize = 1usize << (usize::BITS - 1); // 0x8000_0000_0000_0000 on 64-bit const COUNTER_MASK: usize = !Self::HIGH_BIT; // 0x7FFF_FFFF_FFFF_FFFF on 64-bit - pub(super) fn new(kind: CompletionKeyType) -> Self { + pub(super) fn new(kind: CompletionKeyType, gen: &CompletionKeyGenerator) -> Self { match kind { CompletionKeyType::File => { - // For file tokens, increment from HIGH_BIT base + // For file key, increment from HIGH_BIT base // If it would overflow past HIGH_BIT | COUNTER_MASK, wrap back to HIGH_BIT - let token = loop { - let current = NEXT_FILE_TOKEN.load(std::sync::atomic::Ordering::Relaxed); + let key = loop { + let current = gen.next_file_key.load(std::sync::atomic::Ordering::Relaxed); let next = if current == (Self::HIGH_BIT | Self::COUNTER_MASK) { - Self::HIGH_BIT // Wrap back to HIGH_BIT (first file token) + Self::HIGH_BIT // Wrap back to HIGH_BIT (first file key) } else { current + 1 }; - match NEXT_FILE_TOKEN.compare_exchange_weak( + match gen.next_file_key.compare_exchange_weak( current, next, std::sync::atomic::Ordering::Relaxed, @@ -453,20 +474,22 @@ impl CompletionKey { } }; - Self(token) + Self(key) } _ => { - // For default tokens, we need to ensure the counter never exceeds COUNTER_MASK + // For default keys, we need to ensure the counter never exceeds COUNTER_MASK // If it would overflow, wrap back to 0 - let counter = loop { - let current = NEXT_DEFAULT_TOKEN.load(std::sync::atomic::Ordering::Relaxed); + let key = loop { + let current = gen + .next_default_key + .load(std::sync::atomic::Ordering::Relaxed); let next = if current >= Self::COUNTER_MASK { 1 } else { current + 1 }; - match NEXT_DEFAULT_TOKEN.compare_exchange_weak( + match gen.next_default_key.compare_exchange_weak( current, next, std::sync::atomic::Ordering::Relaxed, @@ -477,9 +500,8 @@ impl CompletionKey { } }; - // Keep highest bit as 0, counter is already safe - let token = counter; - Self(token) + // Keep highest bit as 0, key is already safed by above logic + Self(key) } } } @@ -497,7 +519,7 @@ impl From for usize { } impl From for CompletionKey { - fn from(token: usize) -> Self { - Self(token) + fn from(key: usize) -> Self { + Self(key) } } diff --git a/src/os/iocp.rs b/src/os/iocp.rs index 661b826..430c43f 100644 --- a/src/os/iocp.rs +++ b/src/os/iocp.rs @@ -1,11 +1,11 @@ //! Functionality that is only available for IOCP-based platforms. -use windows_sys::Win32::Foundation as wf; use windows_sys::Win32::System::IO::{OVERLAPPED, OVERLAPPED_ENTRY}; +use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wsf, System::Pipes as wsp}; use crate::iocp::ntdll::NtdllImports; -use crate::iocp::FileCompletionStatus; -pub use crate::iocp::FileOverlappedWrapper; +use crate::iocp::{FileCompletionStatus, PacketWrapper}; +pub use crate::iocp::{FileOverlappedConverter, FileOverlappedWrapper}; pub use crate::sys::CompletionPacket; use super::__private::PollerSealed; @@ -261,40 +261,172 @@ pub trait AsWaitable: AsHandle { impl AsWaitable for T {} /// Overlapped structure owned by the poller and returned to the caller when calling [`add_file`]. -/// The caller must use this structure to get read overlapped ptr or write overlapped ptr as parameter -/// in ReadFile/WriteFile APIs. Otherwise, the behavior is undefined. -/// -/// The overlapped ptr can be safely converted to 'FileOverlappedWrapper' to check result. +/// The caller must use this structure to get read overlapped converter or write overlapped converter +/// as parameter in [`read_file_overlapped`] or [`write_file_overlapped`] methods which help avoid memory leak +/// instead of using raw Windows ReadFile/WriteFile APIs. Otherwise, the behavior is undefined. /// /// [`add_file`]: crate::os::iocp::PollerIocpFileExt::add_file -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] pub struct IocpFilePacket { /// read pointer to the overlapped structure read: NonNull, /// write pointer to the overlapped structure write: NonNull, + packet: PacketWrapper, } impl IocpFilePacket { /// Create a new `IocpFilePacket` with the given `OVERLAPPED` pointer. - pub(crate) fn new(read: *mut OVERLAPPED, write: *mut OVERLAPPED) -> Self { + pub(crate) fn new( + read: *mut OVERLAPPED, + write: *mut OVERLAPPED, + packet: PacketWrapper, + ) -> Self { Self { read: NonNull::new(read).unwrap(), write: NonNull::new(write).unwrap(), + packet, } } - /// Get the raw read overlapped pointer to the `OVERLAPPED` structure. - pub fn read_ptr(&self) -> *mut OVERLAPPED { - self.read.as_ptr() + /// Get the raw read overlapped wrapper to the `OVERLAPPED` structure. + pub fn read_complete(&self) -> *mut FileOverlappedWrapper { + unsafe { FileOverlappedWrapper::from_overlapped_ptr(self.read.as_ptr()) } + } + + /// Get the raw write overlapped wrapper to the `OVERLAPPED` structure. + pub fn write_complete(&self) -> *mut FileOverlappedWrapper { + unsafe { FileOverlappedWrapper::from_overlapped_ptr(self.write.as_ptr()) } + } + + /// Get the read overlapped converter which can be used in read_file_overlapped method. + pub fn read_overlapped(&self) -> FileOverlappedConverter { + FileOverlappedConverter::new(self.read.as_ptr(), self.packet.clone()) + } + + /// Get the write overlapped converter which can be used in write_file_overlapped method. + pub fn write_overlapped(&self) -> FileOverlappedConverter { + FileOverlappedConverter::new(self.write.as_ptr(), self.packet.clone()) + } + + /// Get the reference count of the internal packet for testing purpose. + #[doc(hidden)] + pub fn test_ref_count(&self) -> usize { + self.packet.test_ref_count() + } +} + +/// Helper function to perform a file operation with an overlapped converter to avoid memory leak. +pub fn file_op_overlapped(mut overlapped: FileOverlappedConverter, f: F) -> io::Result +where + F: FnOnce(&mut FileOverlappedConverter) -> io::Result, +{ + let ret = f(&mut overlapped); + match ret { + Ok(size) => Ok(size), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => Err(e), + Err(e) => { + overlapped.reclaim(); + Err(e) + } } +} + +/// Wrapper for ConnectNamedPipe API with an overlapped converter to avoid memory leak. +pub fn connect_named_pipe_overlapped( + handle: &impl AsHandle, + overlapped: FileOverlappedConverter, +) -> io::Result<()> { + let ret = file_op_overlapped(overlapped, |overlapped| { + let ret = unsafe { + wsp::ConnectNamedPipe( + handle.as_handle().as_raw_handle() as _, + overlapped + .as_ptr() + .expect("The overlaped pointer may have been used for I/O operation"), + ) + }; + if ret != wf::FALSE { + Ok(0) + } else { + let err = io::Error::last_os_error(); + match err.raw_os_error().map(|e| e as u32) { + Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()), + _ => Err(err), + } + } + }); - /// Get the raw write overlapped pointer to the `OVERLAPPED` structure. - pub fn write_ptr(&self) -> *mut OVERLAPPED { - self.write.as_ptr() + match ret { + Ok(_) => Ok(()), + Err(e) => Err(e), } } +/// Wrapper for ReadFile API with an overlapped converter to avoid memory leak. +pub fn read_file_overlapped( + handle: &impl AsHandle, + buf: &mut [u8], + overlapped: FileOverlappedConverter, +) -> io::Result { + file_op_overlapped(overlapped, |overlapped| { + let mut read = 0u32; + // Safety: syscall + if unsafe { + wsf::ReadFile( + handle.as_handle().as_raw_handle() as _, + buf.as_mut_ptr() as *mut _, + buf.len() as u32, + &mut read as *mut _, + overlapped + .as_ptr() + .expect("The overlapped pointer may have been used for I/O operation"), + ) + } != wf::FALSE + { + return Ok(read as usize); + } + + let err = io::Error::last_os_error(); + match err.raw_os_error().map(|e| e as u32) { + Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()), + _ => Err(err), + } + }) +} + +/// Wrapper for WriteFile API with an overlapped converter to avoid memory leak. +pub fn write_file_overlapped( + handle: &impl AsHandle, + buf: &[u8], + overlapped: FileOverlappedConverter, +) -> io::Result { + file_op_overlapped(overlapped, |overlapped| { + let mut write = 0u32; + // Safety: syscall + if unsafe { + wsf::WriteFile( + handle.as_handle().as_raw_handle() as _, + buf.as_ptr(), + buf.len() as u32, + &mut write as *mut _, + overlapped + .as_ptr() + .expect("The overlapped pointer may have been used for I/O operation"), + ) + } != wf::FALSE + { + return Ok(write as usize); + } + + let err = io::Error::last_os_error(); + match err.raw_os_error().map(|e| e as u32) { + Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()), + _ => Err(err), + } + }) +} + /// Extension trait for the [`Poller`] type that provides file specific functionality to IOCP-based /// platforms. /// @@ -354,48 +486,45 @@ pub trait PollerIocpFileExt: PollerSealed { /// # Examples /// /// ```no_run - /// use polling::os::iocp::{FileOverlappedWrapper, Overlapped, PollerIocpFileExt}; + /// use polling::os::iocp::{read_file_overlapped, write_file_overlapped, PollerIocpFileExt}; /// use polling::{Event, Events, Poller}; - /// use windows_sys::Win32::System::IO::OVERLAPPED; /// /// use std::ffi::OsStr; /// use std::fs::OpenOptions; /// use std::io; /// use std::os::windows::ffi::OsStrExt; /// use std::os::windows::fs::OpenOptionsExt; - /// use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; + /// use std::os::windows::io::{FromRawHandle, IntoRawHandle, OwnedHandle}; /// use std::time::Duration; /// - /// use windows_sys::Win32::{ - /// Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps, System::IO as wio, - /// }; + /// use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps}; /// /// fn new_named_pipe>(addr: A) -> io::Result { - /// let fname = addr - /// .as_ref() - /// .encode_wide() - /// .chain(Some(0)) - /// .collect::>(); - /// let handle = unsafe { - /// let raw_handle = wps::CreateNamedPipeW( - /// fname.as_ptr(), - /// wfs::PIPE_ACCESS_DUPLEX | wfs::FILE_FLAG_OVERLAPPED, - /// wps::PIPE_TYPE_BYTE | wps::PIPE_READMODE_BYTE | wps::PIPE_WAIT, - /// 1, - /// 4096, - /// 4096, - /// 0, - /// std::ptr::null_mut(), - /// ); - /// - /// if raw_handle == wf::INVALID_HANDLE_VALUE { - /// return Err(io::Error::last_os_error()); - /// } - /// - /// OwnedHandle::from_raw_handle(raw_handle as _) - /// }; - /// - /// Ok(handle) + /// let fname = addr + /// .as_ref() + /// .encode_wide() + /// .chain(Some(0)) + /// .collect::>(); + /// let handle = unsafe { + /// let raw_handle = wps::CreateNamedPipeW( + /// fname.as_ptr(), + /// wfs::PIPE_ACCESS_DUPLEX | wfs::FILE_FLAG_OVERLAPPED, + /// wps::PIPE_TYPE_BYTE | wps::PIPE_READMODE_BYTE | wps::PIPE_WAIT, + /// 1, + /// 4096, + /// 4096, + /// 0, + /// std::ptr::null_mut(), + /// ); + /// + /// if raw_handle == wf::INVALID_HANDLE_VALUE { + /// return Err(io::Error::last_os_error()); + /// } + /// + /// OwnedHandle::from_raw_handle(raw_handle as _) + /// }; + /// + /// Ok(handle) /// } /// /// fn server() -> (OwnedHandle, String) { @@ -433,52 +562,33 @@ pub trait PollerIocpFileExt: PollerSealed { /// /// let client_overlapped = poller.add_file(&client, Event::new(2, true, true)).unwrap(); /// - /// let mut written = 0u32; - /// let ret = wfs::WriteFile( - /// client.as_raw_handle(), - /// b"1234" as *const u8, - /// 4, - /// (&mut written) as *mut u32, - /// client_overlapped.write_ptr(), - /// ); + /// let ret = write_file_overlapped(&client, b"1234", client_overlapped.write_overlapped()); /// - /// assert!(ret == wf::TRUE && written == 4); + /// assert_eq!(ret.unwrap(), 4); /// - /// loop { - /// poller.wait(&mut events, None).unwrap(); - /// let events = events.iter().collect::>(); - /// if let Some(event) = events.iter().find(|e| e.key == 2) { - /// if event.writable { - /// break; - /// } - /// } - /// } + /// poller.wait(&mut events, None).unwrap(); + /// + /// let w_events = events.iter().collect::>(); + /// assert_eq!(w_events.len(), 1); + /// assert_eq!(w_events[0].key, 2); + /// assert!(w_events[0].writable); /// /// events.clear(); /// let mut buf = [0u8; 10]; /// /// let mut read = 0u32; - /// let ret = wfs::ReadFile( - /// server.as_raw_handle(), - /// &mut buf as *mut u8, - /// 10, - /// (&mut read) as *mut u32, - /// server_overlapped.read_ptr(), - /// ); + /// let ret = read_file_overlapped(&server, &mut buf, server_overlapped.read_overlapped()); /// /// let event_len = poller /// .wait(&mut events, Some(Duration::from_millis(10))) /// .unwrap(); /// assert_eq!(event_len, 1); /// - /// let events = events.iter().collect::>(); - /// events.iter().for_each(|e| { - /// if e.key == 2 { - /// assert_eq!(e.writable, true); - /// } - /// }); - /// - /// assert!(ret == wf::TRUE && read == 4); + /// let r_events = events.iter().collect::>(); + /// assert_eq!(r_events.len(), 1); + /// assert_eq!(r_events[0].key, 1); + /// assert!(r_events[0].readable); + /// assert_eq!(ret.unwrap(), 4); /// assert_eq!(&buf[..4], b"1234"); /// /// poller.remove_file(&server).unwrap(); @@ -488,135 +598,209 @@ pub trait PollerIocpFileExt: PollerSealed { /// } /// } /// ``` - unsafe fn add_file(&self, file: &impl AsRawHandle, event: Event) -> io::Result; + unsafe fn add_file( + &self, + file: impl AsRawFileHandle, + event: Event, + ) -> io::Result; - /// Remove a file handle from this poller. + /// Modifies the interest in a file handle. /// - /// This function can be used to remove a file handle from the poller. The handle must - /// have been previously added to the poller using [`add_file`]. + /// This method has the same behavior as [`add_file()`][`Poller::add_file()`] except it modifies the + /// interest of a previously added file handle. The `file` parameter must impl AsFileHandle trait + /// to ensure the handle is not closed before remove_file is called. /// - /// [`add_file`]: Self::add_file + /// To use this method with a file handle, you must first add it using + /// [`add_file()`][`Poller::add_file()`]. /// /// # Examples /// /// ```no_run - /// use polling::os::iocp::{FileOverlappedWrapper, Overlapped, PollerIocpFileExt}; - /// use polling::{Event, Events, Poller}; - /// use windows_sys::Win32::System::IO::OVERLAPPED; - /// - /// use std::ffi::OsStr; + /// use polling::os::iocp::PollerIocpFileExt; + /// use polling::{Event, Poller}; /// use std::fs::OpenOptions; /// use std::io; - /// use std::os::windows::ffi::OsStrExt; /// use std::os::windows::fs::OpenOptionsExt; - /// use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; - /// use std::time::Duration; + /// use std::os::windows::io::{FromRawHandle, IntoRawHandle, OwnedHandle}; + /// use windows_sys::Win32::Storage::FileSystem as wfs; /// - /// use windows_sys::Win32::{ - /// Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps, System::IO as wio, - /// }; + /// fn client(name: &str) -> io::Result { + /// let mut opts = OpenOptions::new(); + /// opts.read(true) + /// .write(true) + /// .custom_flags(wfs::FILE_FLAG_OVERLAPPED); + /// let file = opts.open(name)?; + /// unsafe { Ok(OwnedHandle::from_raw_handle(file.into_raw_handle())) } + /// } /// - /// // Create a poller. - /// let poller = Poller::new().unwrap(); - /// let mut events = Events::new(); - /// println!("Create a temp file"); - /// // Open a file for writing. - /// let dir = tempfile::tempdir().unwrap(); - /// let file_path = dir.path().join("test.txt"); - /// let fname = file_path - /// .as_os_str() - /// .encode_wide() - /// .chain(Some(0)) - /// .collect::>(); - /// let file_handle = unsafe { - /// let raw_handle = wfs::CreateFileW( - /// fname.as_ptr(), - /// wf::GENERIC_WRITE | wf::GENERIC_READ, - /// 0, - /// std::ptr::null_mut(), - /// wfs::CREATE_ALWAYS, - /// wfs::FILE_FLAG_OVERLAPPED, - /// std::ptr::null_mut(), - /// ); - /// if raw_handle == wf::INVALID_HANDLE_VALUE { - /// panic!("CreateFileW failed: {}", io::Error::last_os_error()); - /// } - /// OwnedHandle::from_raw_handle(raw_handle as _) - /// }; - /// println!("file handle: {:?}", file_handle); - /// let overlapped = unsafe { - /// poller - /// .add_file(&file_handle, Event::new(1, true, true)) - /// .unwrap() + /// # fn main() -> io::Result<()> { + /// static PIPE_NUM: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0); + /// let num = PIPE_NUM.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + /// let name = format!(r"\\.\pipe\my-pipe-{}", num); + /// let client = client(&name).unwrap(); + /// let poller = Poller::new().unwrap(); + /// let key = 2; + /// let _client_overlapped = unsafe { poller.add_file(&client, Event::none(key)).unwrap() }; + /// poller.modify_file(&client, Event::writable(key))?; + /// poller.remove_file(&client) + /// # } + /// ``` + fn modify_file(&self, handle: impl AsFileHandle, interest: Event) -> io::Result<()>; + + /// Remove a file handle from this poller. + /// + /// This function can be used to remove a file handle from the poller. The handle must + /// have been previously added to the poller using [`add_file`]. The `file` parameter must impl + /// AsFileHandle trait to ensure the handle is not closed before remove_file is called. + /// + /// + /// [`add_file`]: Self::add_file + /// + /// # Examples + /// + /// ```no_run + /// use polling::os::iocp::{ + /// write_file_overlapped, PollerIocpFileExt, /// }; + /// use polling::{Event, Events, Poller}; /// - /// // Repeatedly write to the pipe. - /// let input_text = "Now is the time for all good men to come to the aid of their party"; - /// let mut len = input_text.len(); - /// while len > 0 { - /// // Begin the write. - /// let ptr = overlapped.write_ptr(); - /// unsafe { - /// let ret = wfs::WriteFile( - /// file_handle.as_raw_handle() as _, - /// input_text.as_ptr() as _, - /// len as _, + /// use std::io; + /// use std::os::windows::ffi::OsStrExt; + /// use std::os::windows::io::{FromRawHandle, OwnedHandle}; + /// + /// use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs}; + /// + /// # fn main() { + /// // Create a poller. + /// let poller = Poller::new().unwrap(); + /// let mut events = Events::new(); + /// println!("Create a temp file"); + /// // Open a file for writing. + /// let dir = tempfile::tempdir().unwrap(); + /// let file_path = dir.path().join("test.txt"); + /// let fname = file_path + /// .as_os_str() + /// .encode_wide() + /// .chain(Some(0)) + /// .collect::>(); + /// let file_handle = unsafe { + /// let raw_handle = wfs::CreateFileW( + /// fname.as_ptr(), + /// wf::GENERIC_WRITE | wf::GENERIC_READ, + /// 0, + /// std::ptr::null_mut(), + /// wfs::CREATE_ALWAYS, + /// wfs::FILE_FLAG_OVERLAPPED, /// std::ptr::null_mut(), - /// ptr, /// ); - /// println!("WriteFile returned: {}, len: {}, ptr: {:p}", ret, len, ptr); - /// if ret == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING { - /// panic!("WriteFile failed: {}", io::Error::last_os_error()); + /// if raw_handle == wf::INVALID_HANDLE_VALUE { + /// panic!("CreateFileW failed: {}", io::Error::last_os_error()); /// } - /// } - /// // Wait for the overlapped operation to complete. - /// 'waiter: loop { - /// events.clear(); - /// println!("Starting wait..."); - /// poller.wait(&mut events, None).unwrap(); - /// println!("Got events"); - /// for event in events.iter() { - /// if event.writable && event.key == 1 { - /// break 'waiter; + /// OwnedHandle::from_raw_handle(raw_handle as _) + /// }; + /// println!("file handle: {:?}", file_handle); + /// let overlapped = unsafe { + /// poller + /// .add_file(&file_handle, Event::new(1, true, true)) + /// .unwrap() + /// }; + /// + /// // Repeatedly write to the pipe. + /// let input_text = "Now is the time for all good men to come to the aid of their party"; + /// let mut len = input_text.len(); + /// while len > 0 { + /// // Begin the write. + /// let ret = write_file_overlapped(&file_handle, b"1234", overlapped.write_overlapped()); + /// let _ = ret.map_err(|e| { + /// if e.kind() != io::ErrorKind::WouldBlock { + /// panic!("WriteFile failed: {}", e); /// } - /// } - /// } - /// // Decrement the length by the number of bytes written. - /// let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; - /// wrapper.get_result().map_or_else( - /// |e| { - /// match e.kind() { - /// io::ErrorKind::WouldBlock => { - /// // The operation is still pending, we can ignore this error. - /// println!("WriteFile is still pending, continuing..."); + /// }); + /// // Wait for the overlapped operation to complete. + /// 'waiter: loop { + /// events.clear(); + /// println!("Starting wait..."); + /// poller.wait(&mut events, None).unwrap(); + /// println!("Got events"); + /// for event in events.iter() { + /// if event.writable && event.key == 1 { + /// break 'waiter; /// } - /// _ => panic!("WriteFile failed: {}", e), - /// } - /// }, - /// |ret| { - /// if (!ret) { - /// println!("The file handle maybe closed"); - /// } - /// else { - /// let bytes_written = wrapper.get_bytes_transferred(); - /// println!("Bytes written: {}", bytes_written); - /// len -= bytes_written as usize; /// } - /// }, - /// ); - /// } - /// poller.remove_file(&file_handle).unwrap(); + /// } + /// // Decrement the length by the number of bytes written. + /// let wrapper = unsafe { &*overlapped.write_complete() }; + /// wrapper.get_result().map_or_else( + /// |e| { + /// match e.kind() { + /// io::ErrorKind::WouldBlock => { + /// // The operation is still pending, we can ignore this error. + /// println!("WriteFile is still pending, continuing..."); + /// } + /// _ => panic!("WriteFile failed: {}", e), + /// } + /// }, + /// |ret| { + /// if (!ret) { + /// println!("The file handle maybe closed"); + /// } else { + /// let bytes_written = wrapper.get_bytes_transferred(); + /// println!("Bytes written: {}", bytes_written); + /// len -= bytes_written as usize; + /// } + /// }, + /// ); + /// } + /// poller.remove_file(file_handle).unwrap(); + /// # } /// ``` - fn remove_file(&self, file: &impl AsRawHandle) -> io::Result<()>; + fn remove_file(&self, file: impl AsFileHandle) -> io::Result<()>; +} + +/// A type that represents a raw file handle. +pub trait AsRawFileHandle { + /// Returns the raw handle of this file. + fn as_raw_handle(&self) -> RawHandle; +} + +impl AsRawFileHandle for RawHandle { + fn as_raw_handle(&self) -> RawHandle { + *self + } } +impl AsRawFileHandle for &T { + fn as_raw_handle(&self) -> RawHandle { + AsRawHandle::as_raw_handle(*self) + } +} + +/// A type that represents a file handle. +pub trait AsFileHandle: AsHandle { + /// Returns the raw handle of this file. + fn as_file(&self) -> BorrowedHandle<'_> { + self.as_handle() + } +} + +impl AsFileHandle for T {} + impl PollerIocpFileExt for Poller { - unsafe fn add_file(&self, file: &impl AsRawHandle, event: Event) -> io::Result { + unsafe fn add_file( + &self, + file: impl AsRawFileHandle, + event: Event, + ) -> io::Result { self.poller.add_file(file.as_raw_handle(), event) } - fn remove_file(&self, file: &impl AsRawHandle) -> io::Result<()> { - self.poller.remove_file(file.as_raw_handle()) + fn modify_file(&self, handle: impl AsFileHandle, interest: Event) -> io::Result<()> { + self.poller + .modify_file(handle.as_file().as_raw_handle(), interest) + } + + fn remove_file(&self, file: impl AsFileHandle) -> io::Result<()> { + self.poller.remove_file(file.as_file().as_raw_handle()) } } @@ -686,38 +870,37 @@ impl Drop for OverlappedInner { /// [`IocpFilePacket`] read/write pointer can safety convert to this structure to check I/O operation /// results. [`FileOverlappedWrapper`] is alise for access convinence. +/// /// # Examples /// /// ```no_run -/// use polling::os::iocp::{FileOverlappedWrapper, IocpFilePacket}; +/// use polling::os::iocp::IocpFilePacket; /// use std::io; /// -/// # fn example(overlapped: IocpFilePacket, mut len: usize) { -/// let ptr = overlapped.write_ptr(); -/// let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; -/// println!("bytes transferred: {}", wrapper.get_bytes_transferred()); -/// wrapper.get_result().map_or_else( -/// |e| { -/// match e.kind() { -/// io::ErrorKind::WouldBlock => { -/// // The operation is still pending, we can ignore this error. -/// println!("WriteFile is still pending, continuing..."); +/// fn write_all(overlapped: IocpFilePacket, mut len: usize) { +/// let wrapper = unsafe { &*overlapped.write_complete() }; +/// println!("bytes transferred: {}", wrapper.get_bytes_transferred()); +/// wrapper.get_result().map_or_else( +/// |e| { +/// match e.kind() { +/// io::ErrorKind::WouldBlock => { +/// // The operation is still pending, we can ignore this error. +/// println!("WriteFile is still pending, continuing..."); +/// } +/// _ => panic!("WriteFile failed: {}", e), +/// } +/// }, +/// |ret| { +/// if !ret { +/// println!("The file handle maybe closed"); +/// } else { +/// let bytes_written = wrapper.get_bytes_transferred(); +/// println!("Bytes written: {}", bytes_written); +/// len -= bytes_written as usize; /// } -/// _ => panic!("WriteFile failed: {}", e), -/// } -/// }, -/// |ret| { -/// if (!ret) { -/// println!("The file handle maybe closed"); -/// } -/// else { -/// let bytes_written = wrapper.get_bytes_transferred(); -/// println!("Bytes written: {}", bytes_written); -/// len -= bytes_written as usize; -/// } -/// }, -/// ); -/// # } +/// }, +/// ); +/// } /// ``` /// [`FileOverlappedWrapper`]: crate::iocp::FileOverlappedWrapper #[derive(Debug)] @@ -729,11 +912,12 @@ pub struct Overlapped { } impl Overlapped { - /// Convert from OVERLAPPED_ENTRY.lpOverlapped back to Overlapped + /// Convert from OVERLAPPED_ENTRY.lpOverlapped back to `Overlapped` /// /// # Safety /// - /// The overlapped_ptr must point to the `inner` field of a valid Overlapped instance + /// The overlapped_ptr must point to the `inner` field of a valid `Overlapped` instance + /// Normally, the call should be made through [`IocpFilePacket::read_complete`] or [`IocpFilePacket::write_complete`] pub unsafe fn from_overlapped_ptr(overlapped_ptr: *mut OVERLAPPED) -> *mut Self { // Calculate offset of 'inner' field within Overlapped let offset = std::mem::offset_of!(Overlapped, inner); diff --git a/tests/windows_overlapped.rs b/tests/windows_overlapped.rs index 945b42f..a170bf1 100644 --- a/tests/windows_overlapped.rs +++ b/tests/windows_overlapped.rs @@ -2,16 +2,17 @@ #![cfg(windows)] -use polling::os::iocp::{FileOverlappedWrapper, PollerIocpFileExt}; +use polling::os::iocp::{ + connect_named_pipe_overlapped, read_file_overlapped, write_file_overlapped, PollerIocpFileExt, +}; use polling::{Event, Events, Poller}; -use windows_sys::Win32::System::IO::OVERLAPPED; use std::ffi::OsStr; use std::fs::OpenOptions; use std::io; use std::os::windows::ffi::OsStrExt; use std::os::windows::fs::OpenOptionsExt; -use std::os::windows::io::{AsRawHandle, FromRawHandle, IntoRawHandle, OwnedHandle}; +use std::os::windows::io::{FromRawHandle, IntoRawHandle, OwnedHandle}; use std::time::Duration; use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs, System::Pipes as wps}; @@ -22,7 +23,6 @@ fn win32_file_io() { let poller = Poller::new().unwrap(); let mut events = Events::new(); - println!("Create a temp file"); // Open a file for writing. let dir = tempfile::tempdir().unwrap(); let file_path = dir.path().join("test.txt"); @@ -49,7 +49,6 @@ fn win32_file_io() { OwnedHandle::from_raw_handle(raw_handle as _) }; - println!("file handle: {:?}", file_handle); let overlapped = unsafe { poller .add_file(&file_handle, Event::new(1, true, true)) @@ -58,21 +57,14 @@ fn win32_file_io() { // Repeatedly write to the pipe. let input_text = "Now is the time for all good men to come to the aid of their party"; - let mut len = input_text.len(); - - while len > 0 { - // Begin the write. - let ptr = overlapped.write_ptr(); - unsafe { - let ret = wfs::WriteFile( - file_handle.as_raw_handle() as _, - input_text.as_ptr() as _, - len as _, - std::ptr::null_mut(), - ptr, - ); - println!("WriteFile returned: {}, len: {}, ptr: {:p}", ret, len, ptr); - if ret == 0 && wf::GetLastError() != wf::ERROR_IO_PENDING { + + // Begin the write. + let ptr = overlapped.write_overlapped(); + { + let ret = write_file_overlapped(&file_handle, input_text.as_ref(), ptr); + println!("WriteFile returned: {:?}", ret); + if let Err(e) = ret { + if e.kind() != io::ErrorKind::WouldBlock { // Only panic if not running under Wine if std::env::var("WINELOADER").is_ok() || std::env::var("WINE").is_ok() @@ -85,49 +77,27 @@ fn win32_file_io() { } } } + } - // Wait for the overlapped operation to complete. - 'waiter: loop { - events.clear(); - println!("Starting wait..."); - poller.wait(&mut events, None).unwrap(); - println!("Got events"); + // Wait for the overlapped operation to complete. + events.clear(); + poller.wait(&mut events, None).unwrap(); + let w_events = events.iter().collect::>(); + assert_eq!(w_events.len(), 1); + assert_eq!(w_events[0].key, 1); + assert!(w_events[0].writable); - for event in events.iter() { - if event.writable && event.key == 1 { - break 'waiter; - } - } - } + // Check the number of bytes written. + let wrapper = unsafe { &*overlapped.write_complete() }; - // Decrement the length by the number of bytes written. - let wrapper = unsafe { &*FileOverlappedWrapper::from_overlapped_ptr(ptr) }; - wrapper.get_result().map_or_else( - |e| { - match e.kind() { - io::ErrorKind::WouldBlock => { - // The operation is still pending, we can ignore this error. - println!("WriteFile is still pending, continuing..."); - } - _ => panic!("WriteFile failed: {}", e), - } - }, - |ret| { - if !ret { - println!("The file handle maybe closed"); - } else { - let bytes_written = wrapper.get_bytes_transferred(); - println!("Bytes written: {}", bytes_written); - len -= bytes_written as usize; - } - }, - ); - } + assert!(wrapper.get_result().unwrap()); + assert_eq!(wrapper.get_bytes_transferred() as usize, input_text.len()); poller.remove_file(&file_handle).unwrap(); + + assert_eq!(overlapped.test_ref_count(), 1); // Close the file and re-open it for reading. drop(file_handle); - println!("file handle dropped"); let file_handle = unsafe { let raw_handle = wfs::CreateFileW( @@ -147,7 +117,6 @@ fn win32_file_io() { OwnedHandle::from_raw_handle(raw_handle as _) }; - println!("file handle: {:?}", file_handle); let overlapped = unsafe { poller .add_file(&file_handle, Event::new(1, true, true)) @@ -156,48 +125,33 @@ fn win32_file_io() { // Repeatedly read from the pipe. let mut buffer = vec![0u8; 1024]; - let mut buffer_cursor = &mut *buffer; - let mut len = 1024; - let mut bytes_received = 0; - - while bytes_received < input_text.len() { - // Begin the read. - let ptr = overlapped.read_ptr(); - unsafe { - if wfs::ReadFile( - file_handle.as_raw_handle() as _, - buffer_cursor.as_mut_ptr() as _, - len as _, - std::ptr::null_mut(), - ptr, - ) == 0 - && wf::GetLastError() != wf::ERROR_IO_PENDING - { - panic!("ReadFile failed: {}", io::Error::last_os_error()); - } - } + let buffer_cursor = &mut *buffer; - // Wait for the overlapped operation to complete. - 'waiter: loop { - events.clear(); - poller.wait(&mut events, None).unwrap(); + // Begin the read. + let ptr = overlapped.read_overlapped(); + let ret = read_file_overlapped(&file_handle, buffer_cursor, ptr); - for event in events.iter() { - if event.readable && event.key == 1 { - break 'waiter; - } - } + if let Err(e) = ret { + if e.kind() != io::ErrorKind::WouldBlock { + panic!("ReadFile failed: {}", io::Error::last_os_error()); } - - // Increment the cursor and decrement the length by the number of bytes read. - let bytes_read = input_text.len(); - buffer_cursor = &mut buffer_cursor[bytes_read..]; - len -= bytes_read; - bytes_received += bytes_read; } - assert_eq!(bytes_received, input_text.len()); - assert_eq!(&buffer[..bytes_received], input_text.as_bytes()); + events.clear(); + poller.wait(&mut events, None).unwrap(); + let r_events = events.iter().collect::>(); + assert_eq!(r_events.len(), 1); + assert_eq!(r_events[0].key, 1); + assert!(r_events[0].readable); + + // Check the number of bytes written. + let wrapper = unsafe { &*overlapped.read_complete() }; + + assert!(wrapper.get_result().unwrap()); + assert_eq!(wrapper.get_bytes_transferred() as usize, input_text.len()); + assert_eq!(&buffer[..input_text.len()], input_text.as_bytes()); + drop(poller); + assert_eq!(overlapped.test_ref_count(), 1); } fn new_named_pipe>(addr: A) -> io::Result { @@ -228,25 +182,6 @@ fn new_named_pipe>(addr: A) -> io::Result { Ok(handle) } -unsafe fn connect_named_pipe( - handle: &impl AsRawHandle, - overlapped: *mut OVERLAPPED, -) -> io::Result<()> { - if wps::ConnectNamedPipe(handle.as_raw_handle() as _, overlapped) != 0 { - // If ConnectNamedPipe returns non-zero, the connection was successful. - return Ok(()); - } - - let err = io::Error::last_os_error(); - - match err.raw_os_error().map(|e| e as u32) { - Some(wf::ERROR_PIPE_CONNECTED) => Ok(()), - Some(wf::ERROR_NO_DATA) => Err(io::ErrorKind::WouldBlock.into()), - Some(wf::ERROR_IO_PENDING) => Err(io::ErrorKind::WouldBlock.into()), - _ => Err(err), - } -} - fn server() -> (OwnedHandle, String) { let num: u64 = fastrand::u64(..); let name = format!(r"\\.\pipe\my-pipe-{}", num); @@ -270,7 +205,8 @@ fn pipe() -> (OwnedHandle, OwnedHandle) { // Test client create success if server create named pipe first. // Client can write data to pipe without server call ConnectNamedPipe first. -// Client return NotFound error if clinet create before server create named pipe. +// Client return NotFound error if client create before server create named pipe. +// If Client create file before server call connect_named_pipe, connect_named_pipe return ERROR_PIPE_CONNECTED // Poller will not receive event if client and server create before register file. // Poller will also not receive event if server create and add to poller before client create named pipe. #[test] @@ -284,18 +220,28 @@ fn writable_after_register() { let poller = Poller::new().unwrap(); let mut events = Events::new(); - let _server_overlapped = unsafe { + let server_overlapped = unsafe { poller .add_file(&server, Event::new(1, true, false)) .unwrap() }; - let _client_overlapped = unsafe { + let client_overlapped = unsafe { poller .add_file(&client, Event::new(2, false, true)) .unwrap() }; + // connect_named_pipe return ERROR_PIPE_CONNECTED if the pipe has been connected. + { + let ret = connect_named_pipe_overlapped(&server, server_overlapped.read_overlapped()); + assert_eq!(server_overlapped.test_ref_count(), 2); + assert_eq!( + ret.err().unwrap().raw_os_error(), + Some(wf::ERROR_PIPE_CONNECTED as i32) + ); + } + poller .wait(&mut events, Some(Duration::from_millis(10))) .unwrap(); @@ -305,6 +251,8 @@ fn writable_after_register() { poller.remove_file(&client).unwrap(); drop(server); drop(client); + assert_eq!(server_overlapped.test_ref_count(), 1); + assert_eq!(client_overlapped.test_ref_count(), 1); } // Poller will receive event if server add to poller before client create file @@ -312,7 +260,7 @@ fn writable_after_register() { let poller = Poller::new().unwrap(); let mut events = Events::new(); - let _server_overlapped = unsafe { + let server_overlapped = unsafe { poller .add_file(&server, Event::new(1, true, false)) .unwrap() @@ -324,7 +272,9 @@ fn writable_after_register() { .unwrap(); assert!(events.is_empty()); + assert_eq!(server_overlapped.test_ref_count(), 2); + drop(server_overlapped); poller.remove_file(&server).unwrap(); drop(server); drop(client); @@ -347,53 +297,37 @@ fn write_then_read() { let client_overlapped = unsafe { poller.add_file(&client, Event::new(2, true, true)).unwrap() }; - unsafe { - let mut written = 0u32; - let ret = wfs::WriteFile( - client.as_raw_handle(), - b"1234" as *const u8, - 4, - (&mut written) as *mut u32, - client_overlapped.write_ptr(), - ); + { + let ret = write_file_overlapped(&client, b"1234", client_overlapped.write_overlapped()); - assert!(ret == wf::TRUE && written == 4); + assert_eq!(ret.unwrap(), 4); + assert_eq!(client_overlapped.test_ref_count(), 3); - loop { - poller.wait(&mut events, None).unwrap(); - let events = events.iter().collect::>(); - if let Some(event) = events.iter().find(|e| e.key == 2) { - if event.writable { - break; - } - } - } + poller.wait(&mut events, None).unwrap(); + assert_eq!(client_overlapped.test_ref_count(), 2); + let w_events = events.iter().collect::>(); + assert_eq!(w_events.len(), 1); + assert_eq!(w_events[0].key, 2); + assert!(w_events[0].writable); events.clear(); let mut buf = [0u8; 10]; - let mut read = 0u32; - let ret = wfs::ReadFile( - server.as_raw_handle(), - &mut buf as *mut u8, - 10, - (&mut read) as *mut u32, - server_overlapped.read_ptr(), - ); + let ret = read_file_overlapped(&server, &mut buf, server_overlapped.read_overlapped()); + assert_eq!(server_overlapped.test_ref_count(), 3); let event_len = poller .wait(&mut events, Some(Duration::from_millis(10))) .unwrap(); assert_eq!(event_len, 1); + assert_eq!(server_overlapped.test_ref_count(), 2); - let events = events.iter().collect::>(); - events.iter().for_each(|e| { - if e.key == 2 { - assert!(e.writable); - } - }); + let r_events = events.iter().collect::>(); + assert_eq!(r_events.len(), 1); + assert_eq!(r_events[0].key, 1); + assert!(r_events[0].readable); - assert!(ret == wf::TRUE && read == 4); + assert_eq!(ret.unwrap(), 4); assert_eq!(&buf[..4], b"1234"); } @@ -403,6 +337,94 @@ fn write_then_read() { drop(client); } +// Read completion will be trigger if the pipe is closed +#[test] +fn close_before_read_complete() { + let (server, _name) = server(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { + poller + .add_file(&server, Event::new(1, true, false)) + .unwrap() + }; + + poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + assert_eq!(events.iter().count(), 0); + assert_eq!(server_overlapped.test_ref_count(), 2); + + { + let ret = connect_named_pipe_overlapped(&server, server_overlapped.read_overlapped()); + assert_eq!(ret.err().unwrap().kind(), io::ErrorKind::WouldBlock); + assert_eq!(server_overlapped.test_ref_count(), 3); + } + + poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + assert_eq!(events.iter().count(), 0); + assert_eq!(server_overlapped.test_ref_count(), 3); + + drop(server); + let event_num = poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); + assert_eq!(event_num, 1); + assert_eq!(server_overlapped.test_ref_count(), 2); + + let r_events = events.iter().collect::>(); + events.clear(); + assert_eq!(r_events.len(), 1); + assert_eq!(r_events[0].key, 1); + assert!(r_events[0].readable); + let overlapped_wrapper = unsafe { &*server_overlapped.read_complete() }; + assert_eq!(overlapped_wrapper.get_bytes_transferred(), 0); + if !(std::env::var("WINELOADER").is_ok() + || std::env::var("WINE").is_ok() + || std::env::var("WINEPREFIX").is_ok()) + { + assert_eq!( + overlapped_wrapper + .get_result() + .unwrap_err() + .raw_os_error() + .unwrap(), + wf::ERROR_BROKEN_PIPE as i32 + ); + } + drop(poller); + assert_eq!(server_overlapped.test_ref_count(), 1); +} + +// Write completion will hold ref count until write complete even if the pipe is removed from poller and closed. +// Poller is Edge mode, write events will be triggered twice. +#[test] +fn close_before_write_twice_complete() { + let (server, client) = pipe(); + let poller = Poller::new().unwrap(); + let mut events = Events::new(); + + let server_overlapped = unsafe { poller.add_file(&server, Event::new(1, true, true)).unwrap() }; + + let _client_overlapped = + unsafe { poller.add_file(&client, Event::new(2, true, true)).unwrap() }; + + let ret = write_file_overlapped(&server, b"1234", server_overlapped.write_overlapped()); + + assert_eq!(ret.unwrap(), 4); + + let ret = write_file_overlapped(&server, b"1234", server_overlapped.write_overlapped()); + + assert_eq!(ret.unwrap(), 4); + assert_eq!(server_overlapped.test_ref_count(), 4); + + assert!(poller.remove_file(&server).is_ok()); + drop(server); + + let event_num = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(event_num, 2); + assert_eq!(server_overlapped.test_ref_count(), 1); +} + // Poller will receive read event if server call ConnectNamedPipe after add to poller before // client create named pipe. #[test] @@ -421,27 +443,22 @@ fn connect_before_client() { assert_eq!(events.iter().count(), 0); unsafe { - let ret = connect_named_pipe(&server, server_overlapped.read_ptr()); + let ret = connect_named_pipe_overlapped(&server, server_overlapped.read_overlapped()); assert_eq!(ret.err().unwrap().kind(), io::ErrorKind::WouldBlock); let client = client(&name).unwrap(); let _client_overlapped = poller.add_file(&client, Event::new(2, true, true)).unwrap(); - loop { - let event_num = poller.wait(&mut events, None).unwrap(); - assert_eq!(event_num, 1); - let e = events.iter().collect::>(); - events.clear(); - if let Some(event) = e.iter().find(|e| e.key == 1) { - if event.readable { - let overlapped_wrapper = - &*FileOverlappedWrapper::from_overlapped_ptr(server_overlapped.read_ptr()); - assert_eq!(overlapped_wrapper.get_bytes_transferred(), 0); - assert!(overlapped_wrapper.get_result().is_ok()); - break; - } - } - } + let event_num = poller.wait(&mut events, None).unwrap(); + assert_eq!(event_num, 1); + let r_events = events.iter().collect::>(); + assert_eq!(r_events[0].key, 1); + assert!(r_events[0].readable); + events.clear(); + + let overlapped_wrapper = &*server_overlapped.read_complete(); + assert_eq!(overlapped_wrapper.get_bytes_transferred(), 0); + assert!(overlapped_wrapper.get_result().is_ok()); poller.remove_file(&server).unwrap(); poller.remove_file(&client).unwrap(); @@ -477,28 +494,19 @@ fn write_disconnected() { .unwrap(); assert!(events.iter().count() == 0); - unsafe { - let mut written = 0u32; - let ret = wfs::WriteFile( - server.as_raw_handle(), - b"1234" as *const u8, - 1, - (&mut written) as *mut u32, - server_overlapped.write_ptr(), - ); - - let e = io::Error::last_os_error(); + let ret = write_file_overlapped(&server, b"1234", server_overlapped.write_overlapped()); - assert_eq!(ret, wf::FALSE); - assert_eq!(written, 0); - assert_eq!(e.raw_os_error(), Some(wf::ERROR_NO_DATA as i32)); + assert_eq!( + ret.err().unwrap().raw_os_error(), + Some(wf::ERROR_NO_DATA as i32) + ); + assert_eq!(server_overlapped.test_ref_count(), 2); - // according testing, it return ERROR_NO_DATA. the server cannot write even one byte - let num_event = poller - .wait(&mut events, Some(Duration::from_millis(10))) - .unwrap(); - assert_eq!(num_event, 0); - } + // according testing, it return ERROR_NO_DATA. the server cannot write even one byte + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(num_event, 0); } // Poller will receive write event if client write data to pipe before drop. @@ -522,18 +530,10 @@ fn write_then_drop() { .unwrap() }; - unsafe { - let mut written = 0u32; - let ret = wfs::WriteFile( - client.as_raw_handle(), - b"1234" as *const u8, - 4, - (&mut written) as *mut u32, - client_overlapped.write_ptr(), - ); + let ret = write_file_overlapped(&client, b"1234", client_overlapped.write_overlapped()); - assert!(ret == wf::TRUE && written == 4); - } + assert_eq!(ret.unwrap(), 4); + assert_eq!(client_overlapped.test_ref_count(), 3); drop(client); @@ -543,17 +543,15 @@ fn write_then_drop() { .unwrap(); assert_eq!(num_event, 1); + assert_eq!(client_overlapped.test_ref_count(), 2); - unsafe { - let events = events.iter().collect::>(); - assert_eq!(events[0].key, 2); - assert!(events[0].writable); - assert!(!events[0].readable); - let overlapped_wrapper = - &*FileOverlappedWrapper::from_overlapped_ptr(client_overlapped.write_ptr()); - assert_eq!(overlapped_wrapper.get_bytes_transferred(), 4); - assert!(overlapped_wrapper.get_result().unwrap()); - } + let w_events = events.iter().collect::>(); + assert_eq!(w_events[0].key, 2); + assert!(w_events[0].writable); + assert!(!w_events[0].readable); + let overlapped_wrapper = unsafe { &*client_overlapped.write_complete() }; + assert_eq!(overlapped_wrapper.get_bytes_transferred(), 4); + assert!(overlapped_wrapper.get_result().unwrap()); events.clear(); let num_event = poller @@ -562,28 +560,19 @@ fn write_then_drop() { assert_eq!(num_event, 0); - unsafe { - let mut buf = [0u8; 10]; + let mut buf = [0u8; 10]; - let mut read = 0u32; - let ret = wfs::ReadFile( - server.as_raw_handle(), - &mut buf as *mut u8, - 10, - (&mut read) as *mut u32, - server_overlapped.read_ptr(), - ); + let ret = read_file_overlapped(&server, buf.as_mut(), server_overlapped.read_overlapped()); - assert_eq!(ret, wf::TRUE); - assert_eq!(read, 4); + assert_eq!(ret.unwrap(), 4); + assert_eq!(server_overlapped.test_ref_count(), 3); - // Still receive read event even ReadFile return true. - let num_event = poller - .wait(&mut events, Some(Duration::from_millis(10))) - .unwrap(); - assert_eq!(num_event, 1); - assert_eq!(&buf[..4], b"1234"); - } + // Still receive read event even ReadFile return true. + let num_event = poller + .wait(&mut events, Some(Duration::from_millis(10))) + .unwrap(); + assert_eq!(num_event, 1); + assert_eq!(&buf[..4], b"1234"); drop(server); } @@ -604,38 +593,33 @@ fn connect_twice() { poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); assert_eq!(events.iter().count(), 0); - let ret = connect_named_pipe(&server, server_overlapped.read_ptr()); + let ret = connect_named_pipe_overlapped(&server, server_overlapped.read_overlapped()); assert_eq!(ret.err().unwrap().kind(), io::ErrorKind::WouldBlock); + assert_eq!(server_overlapped.test_ref_count(), 3); let c1 = client(&name).unwrap(); let _c1_overlapped = poller.add_file(&c1, Event::new(2, true, true)).unwrap(); drop(c1); poller.wait(&mut events, Some(Duration::new(0, 0))).unwrap(); - let ret_events = events.iter().collect::>(); - assert_eq!(ret_events.len(), 1); - assert_eq!(ret_events[0].key, 1); - assert!(ret_events[0].readable); + assert_eq!(server_overlapped.test_ref_count(), 2); + let r_events = events.iter().collect::>(); + assert_eq!(r_events.len(), 1); + assert_eq!(r_events[0].key, 1); + assert!(r_events[0].readable); events.clear(); let mut buf = [0u8; 10]; - let mut read = 0u32; // Can not read, should close server pipe. - let ret = wfs::ReadFile( - server.as_raw_handle(), - &mut buf as *mut u8, - 10, - (&mut read) as *mut u32, - server_overlapped.read_ptr(), - ); + let ret = read_file_overlapped(&server, buf.as_mut(), server_overlapped.read_overlapped()); - let e = io::Error::last_os_error(); - - assert_eq!(ret, wf::FALSE); - assert_eq!(read, 0); - assert_eq!(e.raw_os_error(), Some(wf::ERROR_BROKEN_PIPE as i32)); + assert_eq!( + ret.err().unwrap().raw_os_error(), + Some(wf::ERROR_BROKEN_PIPE as i32) + ); + assert_eq!(server_overlapped.test_ref_count(), 2); let num_event = poller .wait(&mut events, Some(Duration::from_millis(10)))