Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: add RecvMsgOut helper structure #192

Merged
merged 2 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions io-uring-test/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
#[cfg(not(feature = "ci"))]
tests::net::test_tcp_recv_multi(&mut ring, &test)?;
tests::net::test_socket(&mut ring, &test)?;
tests::net::test_udp_recvmsg_multishot(&mut ring, &test)?;

// queue
tests::poll::test_eventfd_poll(&mut ring, &test)?;
Expand Down
144 changes: 144 additions & 0 deletions io-uring-test/src/tests/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,3 +1235,147 @@ pub fn test_socket<S: squeue::EntryMarker, C: cqueue::EntryMarker>(

Ok(())
}

pub fn test_udp_recvmsg_multishot<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
ring: &mut IoUring<S, C>,
test: &Test,
) -> anyhow::Result<()> {
// Multishot recvmsg was introduced in 6.0, like `SendZc`.
// We cannot probe for the former, so we check for the latter as a proxy instead.
require!(
test;
test.probe.is_supported(opcode::RecvMsg::CODE);
test.probe.is_supported(opcode::ProvideBuffers::CODE);
test.probe.is_supported(opcode::SendZc::CODE);
);

println!("test udp_recvmsg_multishot");

let (socket_slot, socket_addr) = {
// `:0` means "pick up a random available port number", which should
// help avoiding test flakes if a static port is already in use.
let server_sock = std::net::UdpSocket::bind("127.0.0.1:0")?;
ring.submitter()
.register_files(&[server_sock.as_raw_fd()])?;
let addr = server_sock.local_addr().unwrap();
(io_uring::types::Fixed(0), addr)
};

// Provide 2 buffers in buffer group `33`, at index 0 and 1.
// Each one is 512 bytes large.
const BUF_GROUP: u16 = 33;
const SIZE: usize = 512;
let mut buffers = [[0u8; SIZE]; 2];
for (index, buf) in buffers.iter_mut().enumerate() {
let provide_bufs_e = io_uring::opcode::ProvideBuffers::new(
buf.as_mut_ptr(),
SIZE as i32,
1,
BUF_GROUP,
index as u16,
)
.build()
.user_data(11)
.into();
unsafe { ring.submission().push(&provide_bufs_e)? };
ring.submitter().submit_and_wait(1)?;
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 1);
assert_eq!(cqes[0].user_data(), 11);
assert_eq!(cqes[0].result(), 0);
assert_eq!(cqes[0].flags(), 0);
}

// This structure is actually only used for input arguments to the kernel
// (and only name length and control length are actually relevant).
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
msghdr.msg_namelen = 32;
msghdr.msg_controllen = 0;

// TODO(lucab): make this more ergonomic to use.
const IORING_RECV_MULTISHOT: u16 = 2;
FrankReh marked this conversation as resolved.
Show resolved Hide resolved

let recvmsg_e = io_uring::opcode::RecvMsg::new(socket_slot, &mut msghdr as *mut _)
.ioprio(IORING_RECV_MULTISHOT)
.buf_group(BUF_GROUP)
.build()
.flags(io_uring::squeue::Flags::BUFFER_SELECT)
.user_data(77)
.into();
unsafe { ring.submission().push(&recvmsg_e)? };
ring.submitter().submit().unwrap();
FrankReh marked this conversation as resolved.
Show resolved Hide resolved

let client_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();
let client_addr = client_socket
.local_addr()
.unwrap()
.as_socket_ipv4()
.unwrap();
client_socket
.send_to("testfoo".as_bytes(), &socket_addr.into())
.unwrap();
client_socket
.send_to("testbarbar".as_bytes(), &socket_addr.into())
.unwrap();

// Check the completion events for the two UDP messages, plus a trailing
// CQE signaling that we ran out of buffers.
ring.submitter().submit_and_wait(3).unwrap();
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 3);
assert_eq!(cqes[0].user_data(), 77);
assert!(cqes[0].result() > 0);
assert!(io_uring::cqueue::more(cqes[0].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[0].flags()), Some(0));
assert!(cqes[0].flags() != 0);
assert_eq!(cqes[1].user_data(), 77);
assert!(cqes[1].result() > 0);
assert!(io_uring::cqueue::more(cqes[1].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[1].flags()), Some(1));
assert!(cqes[1].flags() != 0);
assert_eq!(cqes[2].user_data(), 77);
assert_eq!(cqes[2].result(), -libc::ENOBUFS);
assert!(!io_uring::cqueue::more(cqes[2].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[2].flags()), None);
assert_eq!(cqes[2].flags(), 0);

let msg0 = types::RecvMsgOut::parse(buffers[0].as_slice(), &msghdr).unwrap();
assert!(!msg0.is_payload_truncated());
assert_eq!(msg0.payload_data(), b"testfoo".as_slice());
assert!(!msg0.is_control_data_truncated());
assert_eq!(msg0.control_data(), &[]);
assert!(!msg0.is_name_data_truncated());
let (_, addr) = unsafe {
socket2::SockAddr::init(|storage, len| {
*len = msg0.name_data().len() as u32;
std::ptr::copy_nonoverlapping(msg0.name_data().as_ptr() as _, storage, 1);
Ok(())
})
}
.unwrap();
let addr = addr.as_socket_ipv4().unwrap();
assert_eq!(addr.ip(), client_addr.ip());
assert_eq!(addr.port(), client_addr.port());

let msg1 = types::RecvMsgOut::parse(buffers[1].as_slice(), &msghdr).unwrap();
assert!(!msg1.is_payload_truncated());
assert_eq!(msg1.payload_data(), b"testbarbar".as_slice());
assert!(!msg1.is_control_data_truncated());
assert_eq!(msg1.control_data(), &[]);
assert!(!msg1.is_name_data_truncated());
let (_, addr) = unsafe {
socket2::SockAddr::init(|storage, len| {
*len = msg1.name_data().len() as u32;
std::ptr::copy_nonoverlapping(msg1.name_data().as_ptr() as _, storage, 1);
Ok(())
})
}
.unwrap();
let addr = addr.as_socket_ipv4().unwrap();
assert_eq!(addr.ip(), client_addr.ip());
assert_eq!(addr.port(), client_addr.port());

ring.submitter().unregister_files().unwrap();

Ok(())
}
145 changes: 145 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,148 @@ impl DestinationSlot {
self.dest.get()
}
}

/// Helper structure for parsing the result of a multishot [`opcode::RecvMsg`](crate::opcode::RecvMsg).
///
/// Built by `RecvMsgOut::parse` from a result buffer; the `'buf` lifetime ties
/// the name, control and payload slices to that buffer, so no data is copied
/// apart from the fixed-size header.
#[derive(Debug)]
pub struct RecvMsgOut<'buf> {
/// Fixed-size header read from the very start of the result buffer. It carries
/// the incoming name, control and payload lengths plus the message flags.
header: sys::io_uring_recvmsg_out,
/// The fixed length of the name field, in bytes.
///
/// If the incoming name data is larger than this, it gets truncated to this.
/// If it is smaller, it gets 0-padded to fill the whole field. In either case,
/// this fixed amount of space is reserved in the result buffer.
msghdr_name_len: usize,
/// The fixed length of the control field, in bytes.
///
/// This follows the same semantics as the field above, but for control data.
msghdr_control_len: usize,
/// Name bytes: the leading `min(header.namelen, msghdr_name_len)` bytes of the
/// name field, which starts right after the header.
name_data: &'buf [u8],
/// Control bytes: the leading `min(header.controllen, msghdr_control_len)` bytes
/// of the control field, which starts right after the name field.
control_data: &'buf [u8],
/// Payload bytes: `header.payloadlen` bytes starting right after the control field.
payload_data: &'buf [u8],
}

impl<'buf> RecvMsgOut<'buf> {
    /// Size in bytes of the fixed header at the start of the buffer.
    /// The name field begins at this offset.
    const DATA_START: usize = std::mem::size_of::<sys::io_uring_recvmsg_out>();

    /// Parse the data buffered upon completion of a `RecvMsg` multishot operation.
    ///
    /// `buffer` is the whole buffer previously provided to the ring, while `msghdr`
    /// is the same content provided as input to the corresponding SQE
    /// (only `msg_namelen` and `msg_controllen` fields are relevant).
    ///
    /// The buffer is laid out as: header, name field (`msg_namelen` bytes),
    /// control field (`msg_controllen` bytes), payload (`payloadlen` bytes).
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if `buffer` is too short to hold the header, or if the
    /// sum of the header size, `msg_namelen`, `msg_controllen` and the header's
    /// `payloadlen` overflows or exceeds `buffer.len()`.
    pub fn parse(buffer: &'buf [u8], msghdr: &libc::msghdr) -> Result<Self, ()> {
        if buffer.len() < Self::DATA_START {
            return Err(());
        }
        // SAFETY: `buffer` holds at least `DATA_START` bytes (checked above), which
        // is exactly the size of the header, and `read_unaligned` places no
        // alignment requirement on the source pointer.
        let header: sys::io_uring_recvmsg_out =
            unsafe { std::ptr::read_unaligned(buffer.as_ptr() as _) };

        // `as _` is used instead of `.into()` on purpose: these fields are `u32`
        // on some libc implementations and `usize` on others, and `u32 -> usize`
        // is not covered by `Into` (see rust-lang/rust#70460) even though it is
        // always lossless on Linux.
        let msghdr_name_len: usize = msghdr.msg_namelen as _;
        let msghdr_control_len: usize = msghdr.msg_controllen as _;
        let payload_len = header.payloadlen as usize;

        // Check the total length upfront with overflow-checked math, so that
        // the offset arithmetic below can neither overflow nor index out of
        // bounds.
        let total_len = Self::DATA_START
            .checked_add(msghdr_name_len)
            .and_then(|acc| acc.checked_add(msghdr_control_len))
            .and_then(|acc| acc.checked_add(payload_len))
            .ok_or(())?;
        if total_len > buffer.len() {
            return Err(());
        }

        // Each field occupies its full reserved size in the buffer, regardless
        // of how much meaningful data it holds.
        let name_start = Self::DATA_START;
        let control_start = name_start + msghdr_name_len;
        let payload_start = control_start + msghdr_control_len;

        // The kernel reports the incoming lengths, which may exceed the space
        // reserved for the field; only expose what actually fits.
        let name_size = usize::min(header.namelen as usize, msghdr_name_len);
        let control_size = usize::min(header.controllen as usize, msghdr_control_len);

        let name_data = &buffer[name_start..name_start + name_size];
        let control_data = &buffer[control_start..control_start + control_size];
        let payload_data = &buffer[payload_start..payload_start + payload_len];

        Ok(Self {
            header,
            msghdr_name_len,
            msghdr_control_len,
            name_data,
            control_data,
            payload_data,
        })
    }

    /// Return the length of the incoming `name` data.
    ///
    /// This may be larger than the size of the content returned by
    /// `name_data()`, if the kernel could not fit all the incoming
    /// data in the provided buffer size. In that case, name data in
    /// the result buffer gets truncated.
    pub fn incoming_name_len(&self) -> u32 {
        self.header.namelen
    }

    /// Return whether the incoming name data was larger than the provided limit/buffer.
    ///
    /// When `true`, data returned by `name_data()` is truncated and
    /// incomplete.
    pub fn is_name_data_truncated(&self) -> bool {
        self.header.namelen as usize > self.msghdr_name_len
    }

    /// Message name data, with the same semantics as `msghdr.msg_name`.
    pub fn name_data(&self) -> &[u8] {
        self.name_data
    }

    /// Return the length of the incoming `control` data.
    ///
    /// This may be larger than the size of the content returned by
    /// `control_data()`, if the kernel could not fit all the incoming
    /// data in the provided buffer size. In that case, control data in
    /// the result buffer gets truncated.
    pub fn incoming_control_len(&self) -> u32 {
        self.header.controllen
    }

    /// Return whether the incoming control data was larger than the provided limit/buffer.
    ///
    /// When `true`, data returned by `control_data()` is truncated and
    /// incomplete.
    pub fn is_control_data_truncated(&self) -> bool {
        self.header.controllen as usize > self.msghdr_control_len
    }

    /// Message control data, with the same semantics as `msghdr.msg_control`.
    pub fn control_data(&self) -> &[u8] {
        self.control_data
    }

    /// Return whether the incoming payload was larger than the provided limit/buffer.
    ///
    /// This is derived from the `MSG_TRUNC` bit of the header flags. When
    /// `true`, data returned by `payload_data()` is truncated and incomplete.
    pub fn is_payload_truncated(&self) -> bool {
        self.header.flags & (libc::MSG_TRUNC as u32) != 0
    }

    /// Message payload, as buffered by the kernel.
    pub fn payload_data(&self) -> &[u8] {
        self.payload_data
    }

    /// Message flags, with the same semantics as `msghdr.msg_flags`.
    pub fn flags(&self) -> u32 {
        self.header.flags
    }
}