Skip to content

Commit

Permalink
types: add RecvMsgOut helper structure (#192)
Browse files Browse the repository at this point in the history
This introduces a `RecvMsgOut` structure which allows consumers to
safely parse the buffered result from a multishot `IORING_OP_RECVMSG`
completion event.

This also adds a test entry which exercises `IORING_RECV_MULTISHOT` on
`IORING_OP_RECVMSG`, as well as `RecvMsgOut` parsing logic.
  • Loading branch information
lucab authored Feb 16, 2023
1 parent cb3e4c0 commit 6785e66
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 0 deletions.
1 change: 1 addition & 0 deletions io-uring-test/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ fn test<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
#[cfg(not(feature = "ci"))]
tests::net::test_tcp_recv_multi(&mut ring, &test)?;
tests::net::test_socket(&mut ring, &test)?;
tests::net::test_udp_recvmsg_multishot(&mut ring, &test)?;

// queue
tests::poll::test_eventfd_poll(&mut ring, &test)?;
Expand Down
144 changes: 144 additions & 0 deletions io-uring-test/src/tests/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,3 +1235,147 @@ pub fn test_socket<S: squeue::EntryMarker, C: cqueue::EntryMarker>(

Ok(())
}

pub fn test_udp_recvmsg_multishot<S: squeue::EntryMarker, C: cqueue::EntryMarker>(
ring: &mut IoUring<S, C>,
test: &Test,
) -> anyhow::Result<()> {
// Multishot recvmsg was introduced in 6.0, like `SendZc`.
// We cannot probe for the former, so we check for the latter as a proxy instead.
require!(
test;
test.probe.is_supported(opcode::RecvMsg::CODE);
test.probe.is_supported(opcode::ProvideBuffers::CODE);
test.probe.is_supported(opcode::SendZc::CODE);
);

println!("test udp_recvmsg_multishot");

let (socket_slot, socket_addr) = {
// `:0` means "pick up a random available port number", which should
// help avoiding test flakes if a static port is already in use.
let server_sock = std::net::UdpSocket::bind("127.0.0.1:0")?;
ring.submitter()
.register_files(&[server_sock.as_raw_fd()])?;
let addr = server_sock.local_addr().unwrap();
(io_uring::types::Fixed(0), addr)
};

// Provide 2 buffers in buffer group `33`, at index 0 and 1.
// Each one is 512 bytes large.
const BUF_GROUP: u16 = 33;
const SIZE: usize = 512;
let mut buffers = [[0u8; SIZE]; 2];
for (index, buf) in buffers.iter_mut().enumerate() {
let provide_bufs_e = io_uring::opcode::ProvideBuffers::new(
buf.as_mut_ptr(),
SIZE as i32,
1,
BUF_GROUP,
index as u16,
)
.build()
.user_data(11)
.into();
unsafe { ring.submission().push(&provide_bufs_e)? };
ring.submitter().submit_and_wait(1)?;
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 1);
assert_eq!(cqes[0].user_data(), 11);
assert_eq!(cqes[0].result(), 0);
assert_eq!(cqes[0].flags(), 0);
}

// This structure is actually only used for input arguments to the kernel
// (and only name length and control length are actually relevant).
let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() };
msghdr.msg_namelen = 32;
msghdr.msg_controllen = 0;

// TODO(lucab): make this more ergonomic to use.
const IORING_RECV_MULTISHOT: u16 = 2;

let recvmsg_e = io_uring::opcode::RecvMsg::new(socket_slot, &mut msghdr as *mut _)
.ioprio(IORING_RECV_MULTISHOT)
.buf_group(BUF_GROUP)
.build()
.flags(io_uring::squeue::Flags::BUFFER_SELECT)
.user_data(77)
.into();
unsafe { ring.submission().push(&recvmsg_e)? };
ring.submitter().submit().unwrap();

let client_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into();
let client_addr = client_socket
.local_addr()
.unwrap()
.as_socket_ipv4()
.unwrap();
client_socket
.send_to("testfoo".as_bytes(), &socket_addr.into())
.unwrap();
client_socket
.send_to("testbarbar".as_bytes(), &socket_addr.into())
.unwrap();

// Check the completion events for the two UDP messages, plus a trailing
// CQE signaling that we ran out of buffers.
ring.submitter().submit_and_wait(3).unwrap();
let cqes: Vec<io_uring::cqueue::Entry> = ring.completion().map(Into::into).collect();
assert_eq!(cqes.len(), 3);
assert_eq!(cqes[0].user_data(), 77);
assert!(cqes[0].result() > 0);
assert!(io_uring::cqueue::more(cqes[0].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[0].flags()), Some(0));
assert!(cqes[0].flags() != 0);
assert_eq!(cqes[1].user_data(), 77);
assert!(cqes[1].result() > 0);
assert!(io_uring::cqueue::more(cqes[1].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[1].flags()), Some(1));
assert!(cqes[1].flags() != 0);
assert_eq!(cqes[2].user_data(), 77);
assert_eq!(cqes[2].result(), -libc::ENOBUFS);
assert!(!io_uring::cqueue::more(cqes[2].flags()));
assert_eq!(io_uring::cqueue::buffer_select(cqes[2].flags()), None);
assert_eq!(cqes[2].flags(), 0);

let msg0 = types::RecvMsgOut::parse(buffers[0].as_slice(), &msghdr).unwrap();
assert!(!msg0.is_payload_truncated());
assert_eq!(msg0.payload_data(), b"testfoo".as_slice());
assert!(!msg0.is_control_data_truncated());
assert_eq!(msg0.control_data(), &[]);
assert!(!msg0.is_name_data_truncated());
let (_, addr) = unsafe {
socket2::SockAddr::init(|storage, len| {
*len = msg0.name_data().len() as u32;
std::ptr::copy_nonoverlapping(msg0.name_data().as_ptr() as _, storage, 1);
Ok(())
})
}
.unwrap();
let addr = addr.as_socket_ipv4().unwrap();
assert_eq!(addr.ip(), client_addr.ip());
assert_eq!(addr.port(), client_addr.port());

let msg1 = types::RecvMsgOut::parse(buffers[1].as_slice(), &msghdr).unwrap();
assert!(!msg1.is_payload_truncated());
assert_eq!(msg1.payload_data(), b"testbarbar".as_slice());
assert!(!msg1.is_control_data_truncated());
assert_eq!(msg1.control_data(), &[]);
assert!(!msg1.is_name_data_truncated());
let (_, addr) = unsafe {
socket2::SockAddr::init(|storage, len| {
*len = msg1.name_data().len() as u32;
std::ptr::copy_nonoverlapping(msg1.name_data().as_ptr() as _, storage, 1);
Ok(())
})
}
.unwrap();
let addr = addr.as_socket_ipv4().unwrap();
assert_eq!(addr.ip(), client_addr.ip());
assert_eq!(addr.port(), client_addr.port());

ring.submitter().unregister_files().unwrap();

Ok(())
}
145 changes: 145 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,148 @@ impl DestinationSlot {
self.dest.get()
}
}

/// Helper structure for parsing the result of a multishot [`opcode::RecvMsg`](crate::opcode::RecvMsg).
#[derive(Debug)]
pub struct RecvMsgOut<'buf> {
header: sys::io_uring_recvmsg_out,
/// The fixed length of the name field, in bytes.
///
/// If the incoming name data is larger than this, it gets truncated to this.
/// If it is smaller, it gets 0-padded to fill the whole field. In either case,
/// this fixed amount of space is reserved in the result buffer.
msghdr_name_len: usize,
/// The fixed length of the control field, in bytes.
///
/// This follows the same semantics as the field above, but for control data.
msghdr_control_len: usize,
name_data: &'buf [u8],
control_data: &'buf [u8],
payload_data: &'buf [u8],
}

impl<'buf> RecvMsgOut<'buf> {
const DATA_START: usize = std::mem::size_of::<sys::io_uring_recvmsg_out>();

/// Parse the data buffered upon completion of a `RecvMsg` multishot operation.
///
/// `buffer` is the whole buffer previously provided to the ring, while `msghdr`
/// is the same content provided as input to the corresponding SQE
/// (only `msg_namelen` and `msg_controllen` fields are relevant).
pub fn parse(buffer: &'buf [u8], msghdr: &libc::msghdr) -> Result<Self, ()> {
if buffer.len() < std::mem::size_of::<sys::io_uring_recvmsg_out>() {
return Err(());
}
// SAFETY: buffer (minimum) length is checked here above.
let header: sys::io_uring_recvmsg_out =
unsafe { std::ptr::read_unaligned(buffer.as_ptr() as _) };

let msghdr_name_len = msghdr.msg_namelen as _;
let msghdr_control_len = msghdr.msg_controllen as _;

// Check total length upfront, so that further logic here
// below can safely use unchecked/saturating math.
let length_overflow = Some(Self::DATA_START)
.and_then(|acc| acc.checked_add(msghdr_name_len))
.and_then(|acc| acc.checked_add(msghdr_control_len))
.and_then(|acc| acc.checked_add(header.payloadlen as usize))
.map(|total_len| total_len > buffer.len())
.unwrap_or(true);
if length_overflow {
return Err(());
}

let (name_data, control_start) = {
let name_start = Self::DATA_START;
let name_size = usize::min(header.namelen as usize, msghdr_name_len);
let name_data_end = name_start.saturating_add(name_size);
let name_data = &buffer[name_start..name_data_end];
let name_field_end = name_start.saturating_add(msghdr_name_len);
(name_data, name_field_end)
};
let (control_data, payload_start) = {
let control_size = usize::min(header.controllen as usize, msghdr_control_len);
let control_data_end = control_start.saturating_add(control_size);
let control_data = &buffer[control_start..control_data_end];
let control_field_end = control_start.saturating_add(msghdr_control_len);
(control_data, control_field_end)
};
let payload_data = {
let payload_data_end = payload_start.saturating_add(header.payloadlen as usize);
&buffer[payload_start..payload_data_end]
};

Ok(Self {
header,
msghdr_name_len,
msghdr_control_len,
name_data,
control_data,
payload_data,
})
}

/// Return the length of the incoming `name` data.
///
/// This may be larger than the size of the content returned by
/// `name_data()`, if the kernel could not fit all the incoming
/// data in the provided buffer size. In that case, name data in
/// the result buffer gets truncated.
pub fn incoming_name_len(&self) -> u32 {
self.header.namelen
}

/// Return whether the incoming name data was larger than the provided limit/buffer.
///
/// When `true`, data returned by `name_data()` is truncated and
/// incomplete.
pub fn is_name_data_truncated(&self) -> bool {
self.header.namelen as usize > self.msghdr_name_len
}

/// Message control data, with the same semantics as `msghdr.msg_control`.
pub fn name_data(&self) -> &[u8] {
self.name_data
}

/// Return the length of the incoming `control` data.
///
/// This may be larger than the size of the content returned by
/// `control_data()`, if the kernel could not fit all the incoming
/// data in the provided buffer size. In that case, control data in
/// the result buffer gets truncated.
pub fn incoming_control_len(&self) -> u32 {
self.header.controllen
}

/// Return whether the incoming control data was larger than the provided limit/buffer.
///
/// When `true`, data returned by `control_data()` is truncated and
/// incomplete.
pub fn is_control_data_truncated(&self) -> bool {
self.header.controllen as usize > self.msghdr_control_len
}

/// Message control data, with the same semantics as `msghdr.msg_control`.
pub fn control_data(&self) -> &[u8] {
self.control_data
}

/// Return whether the incoming payload was larger than the provided limit/buffer.
///
/// When `true`, data returned by `payload_data()` is truncated and
/// incomplete.
pub fn is_payload_truncated(&self) -> bool {
self.header.flags & (libc::MSG_TRUNC as u32) != 0
}

/// Message payload, as buffered by the kernel.
pub fn payload_data(&self) -> &[u8] {
self.payload_data
}

/// Message flags, with the same semantics as `msghdr.msg_flags`.
pub fn flags(&self) -> u32 {
self.header.flags
}
}

0 comments on commit 6785e66

Please sign in to comment.