Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name = "polling"
version = "3.11.0"
authors = ["Stjepan Glavina <[email protected]>", "John Nunley <[email protected]>"]
edition = "2021"
rust-version = "1.71"
rust-version = "1.77"
description = "Portable interface to epoll, kqueue, event ports, and IOCP"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/smol-rs/polling"
Expand All @@ -18,7 +18,10 @@ exclude = ["/.*"]
rustdoc-args = ["--cfg", "docsrs"]

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(polling_test_poll_backend)', 'cfg(polling_test_epoll_pipe)'] }
unexpected_cfgs = { level = "warn", check-cfg = [
'cfg(polling_test_poll_backend)',
'cfg(polling_test_epoll_pipe)',
] }

[dependencies]
cfg-if = "1"
Expand Down Expand Up @@ -50,6 +53,7 @@ features = [
"Win32_System_LibraryLoader",
"Win32_System_Threading",
"Win32_System_WindowsProgramming",
"Win32_System_Pipes",
]

[target.'cfg(target_os = "hermit")'.dependencies.hermit-abi]
Expand All @@ -65,3 +69,6 @@ libc = "0.2"

[target.'cfg(all(unix, not(target_os="vita")))'.dev-dependencies]
signal-hook = "0.3.17"

[target.'cfg(windows)'.dev-dependencies]
tempfile = "3.7"
169 changes: 17 additions & 152 deletions src/iocp/afd.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
//! Safe wrapper around \Device\Afd

use crate::iocp::ntdll::NtdllImports;
use crate::iocp::port::FileOverlapped;

use super::port::{Completion, CompletionHandle};

use std::cell::UnsafeCell;
use std::fmt;
use std::io;
use std::marker::{PhantomData, PhantomPinned};
use std::mem::{self, size_of, transmute, MaybeUninit};
use std::mem::{self, size_of, MaybeUninit};
use std::ops;
use std::os::windows::prelude::{AsRawHandle, RawHandle, RawSocket};
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;

use windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES;
use windows_sys::Wdk::Storage::FileSystem::FILE_OPEN;
use windows_sys::Win32::Foundation::{
CloseHandle, HANDLE, HMODULE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS,
UNICODE_STRING,
CloseHandle, HANDLE, NTSTATUS, STATUS_NOT_FOUND, STATUS_PENDING, STATUS_SUCCESS, UNICODE_STRING,
};
use windows_sys::Win32::Networking::WinSock::{
WSAIoctl, SIO_BASE_HANDLE, SIO_BSP_HANDLE_POLL, SOCKET_ERROR,
};
use windows_sys::Win32::Storage::FileSystem::{FILE_SHARE_READ, FILE_SHARE_WRITE, SYNCHRONIZE};
use windows_sys::Win32::System::LibraryLoader::{GetModuleHandleW, GetProcAddress};
use windows_sys::Win32::System::IO::IO_STATUS_BLOCK;

#[derive(Default)]
Expand Down Expand Up @@ -170,153 +170,6 @@ pub(super) trait HasAfdInfo {
fn afd_info(self: Pin<&Self>) -> Pin<&UnsafeCell<AfdPollInfo>>;
}

macro_rules! define_ntdll_import {
(
$(
$(#[$attr:meta])*
fn $name:ident($($arg:ident: $arg_ty:ty),*) -> $ret:ty;
)*
) => {
/// Imported functions from ntdll.dll.
#[allow(non_snake_case)]
pub(super) struct NtdllImports {
$(
$(#[$attr])*
$name: unsafe extern "system" fn($($arg_ty),*) -> $ret,
)*
}

#[allow(non_snake_case)]
impl NtdllImports {
unsafe fn load(ntdll: HMODULE) -> io::Result<Self> {
$(
#[allow(clippy::missing_transmute_annotations)]
let $name = {
const NAME: &str = concat!(stringify!($name), "\0");
let addr = GetProcAddress(ntdll, NAME.as_ptr() as *const _);

let addr = match addr {
Some(addr) => addr,
None => {
#[cfg(feature = "tracing")]
tracing::error!("Failed to load ntdll function {}", NAME);
return Err(io::Error::last_os_error());
},
};

transmute::<_, unsafe extern "system" fn($($arg_ty),*) -> $ret>(addr)
};
)*

Ok(Self {
$(
$name,
)*
})
}

$(
$(#[$attr])*
unsafe fn $name(&self, $($arg: $arg_ty),*) -> $ret {
(self.$name)($($arg),*)
}
)*
}
};
}

define_ntdll_import! {
/// Cancels an ongoing I/O operation.
fn NtCancelIoFileEx(
FileHandle: HANDLE,
IoRequestToCancel: *mut IO_STATUS_BLOCK,
IoStatusBlock: *mut IO_STATUS_BLOCK
) -> NTSTATUS;

/// Opens or creates a file handle.
#[allow(clippy::too_many_arguments)]
fn NtCreateFile(
FileHandle: *mut HANDLE,
DesiredAccess: u32,
ObjectAttributes: *mut OBJECT_ATTRIBUTES,
IoStatusBlock: *mut IO_STATUS_BLOCK,
AllocationSize: *mut i64,
FileAttributes: u32,
ShareAccess: u32,
CreateDisposition: u32,
CreateOptions: u32,
EaBuffer: *mut (),
EaLength: u32
) -> NTSTATUS;

/// Runs an I/O control on a file handle.
///
/// Practically equivalent to `ioctl`.
#[allow(clippy::too_many_arguments)]
fn NtDeviceIoControlFile(
FileHandle: HANDLE,
Event: HANDLE,
ApcRoutine: *mut (),
ApcContext: *mut (),
IoStatusBlock: *mut IO_STATUS_BLOCK,
IoControlCode: u32,
InputBuffer: *mut (),
InputBufferLength: u32,
OutputBuffer: *mut (),
OutputBufferLength: u32
) -> NTSTATUS;

/// Converts `NTSTATUS` to a DOS error code.
fn RtlNtStatusToDosError(
Status: NTSTATUS
) -> u32;
}

impl NtdllImports {
fn get() -> io::Result<&'static Self> {
macro_rules! s {
($e:expr) => {{
$e as u16
}};
}

// ntdll.dll
static NTDLL_NAME: &[u16] = &[
s!('n'),
s!('t'),
s!('d'),
s!('l'),
s!('l'),
s!('.'),
s!('d'),
s!('l'),
s!('l'),
s!('\0'),
];
static NTDLL_IMPORTS: OnceLock<io::Result<NtdllImports>> = OnceLock::new();

NTDLL_IMPORTS
.get_or_init(|| unsafe {
let ntdll = GetModuleHandleW(NTDLL_NAME.as_ptr() as *const _);

if ntdll.is_null() {
#[cfg(feature = "tracing")]
tracing::error!("Failed to load ntdll.dll");
return Err(io::Error::last_os_error());
}

NtdllImports::load(ntdll)
})
.as_ref()
.map_err(|e| io::Error::from(e.kind()))
}

pub(super) fn force_load() -> io::Result<()> {
Self::get()?;
Ok(())
}
}

/// The handle to the AFD device.
pub(super) struct Afd<T> {
/// The handle to the AFD device.
Expand Down Expand Up @@ -614,6 +467,18 @@ unsafe impl<T> Completion for IoStatusBlock<T> {
}
}

impl<T: FileOverlapped> FileOverlapped for IoStatusBlock<T> {
#[inline]
fn file_read_offset() -> usize {
T::file_read_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
}

#[inline]
fn file_write_offset() -> usize {
T::file_write_offset() + std::mem::offset_of!(IoStatusBlock<T>, data)
}
}

/// Get the base socket associated with a socket.
pub(super) fn base_socket(sock: RawSocket) -> io::Result<RawSocket> {
// First, try the SIO_BASE_HANDLE ioctl.
Expand Down
Loading
Loading