From 6785e666869993f81414a5cfad501ef714c742e9 Mon Sep 17 00:00:00 2001 From: Luca Bruno Date: Thu, 16 Feb 2023 02:55:34 +0000 Subject: [PATCH] types: add RecvMsgOut helper structure (#192) This introduces a `RecvMsgOut` structure which allows consumers to safely parse the buffered result from a multishot `IORING_OP_RECVMSG` completion event. This also adds a test entry which exercises `IORING_RECV_MULTISHOT` on `IORING_OP_RECVMSG`, as well as `RecvMsgOut` parsing logic. --- io-uring-test/src/main.rs | 1 + io-uring-test/src/tests/net.rs | 144 ++++++++++++++++++++++++++++++++ src/types.rs | 145 +++++++++++++++++++++++++++++++++ 3 files changed, 290 insertions(+) diff --git a/io-uring-test/src/main.rs b/io-uring-test/src/main.rs index e830853f..f1eacfe3 100644 --- a/io-uring-test/src/main.rs +++ b/io-uring-test/src/main.rs @@ -127,6 +127,7 @@ fn test( #[cfg(not(feature = "ci"))] tests::net::test_tcp_recv_multi(&mut ring, &test)?; tests::net::test_socket(&mut ring, &test)?; + tests::net::test_udp_recvmsg_multishot(&mut ring, &test)?; // queue tests::poll::test_eventfd_poll(&mut ring, &test)?; diff --git a/io-uring-test/src/tests/net.rs b/io-uring-test/src/tests/net.rs index 21ff4fd8..07314fa2 100644 --- a/io-uring-test/src/tests/net.rs +++ b/io-uring-test/src/tests/net.rs @@ -1235,3 +1235,147 @@ pub fn test_socket( Ok(()) } + +pub fn test_udp_recvmsg_multishot( + ring: &mut IoUring, + test: &Test, +) -> anyhow::Result<()> { + // Multishot recvmsg was introduced in 6.0, like `SendZc`. + // We cannot probe for the former, so we check for the latter as a proxy instead. + require!( + test; + test.probe.is_supported(opcode::RecvMsg::CODE); + test.probe.is_supported(opcode::ProvideBuffers::CODE); + test.probe.is_supported(opcode::SendZc::CODE); + ); + + println!("test udp_recvmsg_multishot"); + + let (socket_slot, socket_addr) = { + // `:0` means "pick up a random available port number", which should + // help avoiding test flakes if a static port is already in use. 
+ let server_sock = std::net::UdpSocket::bind("127.0.0.1:0")?; + ring.submitter() + .register_files(&[server_sock.as_raw_fd()])?; + let addr = server_sock.local_addr().unwrap(); + (io_uring::types::Fixed(0), addr) + }; + + // Provide 2 buffers in buffer group `33`, at index 0 and 1. + // Each one is 512 bytes large. + const BUF_GROUP: u16 = 33; + const SIZE: usize = 512; + let mut buffers = [[0u8; SIZE]; 2]; + for (index, buf) in buffers.iter_mut().enumerate() { + let provide_bufs_e = io_uring::opcode::ProvideBuffers::new( + buf.as_mut_ptr(), + SIZE as i32, + 1, + BUF_GROUP, + index as u16, + ) + .build() + .user_data(11) + .into(); + unsafe { ring.submission().push(&provide_bufs_e)? }; + ring.submitter().submit_and_wait(1)?; + let cqes: Vec = ring.completion().map(Into::into).collect(); + assert_eq!(cqes.len(), 1); + assert_eq!(cqes[0].user_data(), 11); + assert_eq!(cqes[0].result(), 0); + assert_eq!(cqes[0].flags(), 0); + } + + // This structure is actually only used for input arguments to the kernel + // (and only name length and control length are actually relevant). + let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() }; + msghdr.msg_namelen = 32; + msghdr.msg_controllen = 0; + + // TODO(lucab): make this more ergonomic to use. + const IORING_RECV_MULTISHOT: u16 = 2; + + let recvmsg_e = io_uring::opcode::RecvMsg::new(socket_slot, &mut msghdr as *mut _) + .ioprio(IORING_RECV_MULTISHOT) + .buf_group(BUF_GROUP) + .build() + .flags(io_uring::squeue::Flags::BUFFER_SELECT) + .user_data(77) + .into(); + unsafe { ring.submission().push(&recvmsg_e)? 
}; + ring.submitter().submit().unwrap(); + + let client_socket: socket2::Socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap().into(); + let client_addr = client_socket + .local_addr() + .unwrap() + .as_socket_ipv4() + .unwrap(); + client_socket + .send_to("testfoo".as_bytes(), &socket_addr.into()) + .unwrap(); + client_socket + .send_to("testbarbar".as_bytes(), &socket_addr.into()) + .unwrap(); + + // Check the completion events for the two UDP messages, plus a trailing + // CQE signaling that we ran out of buffers. + ring.submitter().submit_and_wait(3).unwrap(); + let cqes: Vec = ring.completion().map(Into::into).collect(); + assert_eq!(cqes.len(), 3); + assert_eq!(cqes[0].user_data(), 77); + assert!(cqes[0].result() > 0); + assert!(io_uring::cqueue::more(cqes[0].flags())); + assert_eq!(io_uring::cqueue::buffer_select(cqes[0].flags()), Some(0)); + assert!(cqes[0].flags() != 0); + assert_eq!(cqes[1].user_data(), 77); + assert!(cqes[1].result() > 0); + assert!(io_uring::cqueue::more(cqes[1].flags())); + assert_eq!(io_uring::cqueue::buffer_select(cqes[1].flags()), Some(1)); + assert!(cqes[1].flags() != 0); + assert_eq!(cqes[2].user_data(), 77); + assert_eq!(cqes[2].result(), -libc::ENOBUFS); + assert!(!io_uring::cqueue::more(cqes[2].flags())); + assert_eq!(io_uring::cqueue::buffer_select(cqes[2].flags()), None); + assert_eq!(cqes[2].flags(), 0); + + let msg0 = types::RecvMsgOut::parse(buffers[0].as_slice(), &msghdr).unwrap(); + assert!(!msg0.is_payload_truncated()); + assert_eq!(msg0.payload_data(), b"testfoo".as_slice()); + assert!(!msg0.is_control_data_truncated()); + assert_eq!(msg0.control_data(), &[]); + assert!(!msg0.is_name_data_truncated()); + let (_, addr) = unsafe { + socket2::SockAddr::init(|storage, len| { + *len = msg0.name_data().len() as u32; + std::ptr::copy_nonoverlapping(msg0.name_data().as_ptr() as _, storage, 1); + Ok(()) + }) + } + .unwrap(); + let addr = addr.as_socket_ipv4().unwrap(); + assert_eq!(addr.ip(), client_addr.ip()); + 
assert_eq!(addr.port(), client_addr.port()); + + let msg1 = types::RecvMsgOut::parse(buffers[1].as_slice(), &msghdr).unwrap(); + assert!(!msg1.is_payload_truncated()); + assert_eq!(msg1.payload_data(), b"testbarbar".as_slice()); + assert!(!msg1.is_control_data_truncated()); + assert_eq!(msg1.control_data(), &[]); + assert!(!msg1.is_name_data_truncated()); + let (_, addr) = unsafe { + socket2::SockAddr::init(|storage, len| { + *len = msg1.name_data().len() as u32; + std::ptr::copy_nonoverlapping(msg1.name_data().as_ptr() as _, storage, 1); + Ok(()) + }) + } + .unwrap(); + let addr = addr.as_socket_ipv4().unwrap(); + assert_eq!(addr.ip(), client_addr.ip()); + assert_eq!(addr.port(), client_addr.port()); + + ring.submitter().unregister_files().unwrap(); + + Ok(()) +} diff --git a/src/types.rs b/src/types.rs index 22a850d1..76955091 100644 --- a/src/types.rs +++ b/src/types.rs @@ -330,3 +330,148 @@ impl DestinationSlot { self.dest.get() } } + +/// Helper structure for parsing the result of a multishot [`opcode::RecvMsg`](crate::opcode::RecvMsg). +#[derive(Debug)] +pub struct RecvMsgOut<'buf> { + header: sys::io_uring_recvmsg_out, + /// The fixed length of the name field, in bytes. + /// + /// If the incoming name data is larger than this, it gets truncated to this. + /// If it is smaller, it gets 0-padded to fill the whole field. In either case, + /// this fixed amount of space is reserved in the result buffer. + msghdr_name_len: usize, + /// The fixed length of the control field, in bytes. + /// + /// This follows the same semantics as the field above, but for control data. + msghdr_control_len: usize, + name_data: &'buf [u8], + control_data: &'buf [u8], + payload_data: &'buf [u8], +} + +impl<'buf> RecvMsgOut<'buf> { + const DATA_START: usize = std::mem::size_of::(); + + /// Parse the data buffered upon completion of a `RecvMsg` multishot operation. 
+ /// + /// `buffer` is the whole buffer previously provided to the ring, while `msghdr` + /// is the same content provided as input to the corresponding SQE + /// (only `msg_namelen` and `msg_controllen` fields are relevant). + pub fn parse(buffer: &'buf [u8], msghdr: &libc::msghdr) -> Result { + if buffer.len() < std::mem::size_of::() { + return Err(()); + } + // SAFETY: minimum buffer length is checked above; `read_unaligned` needs no alignment. + let header: sys::io_uring_recvmsg_out = + unsafe { std::ptr::read_unaligned(buffer.as_ptr() as _) }; + + let msghdr_name_len = msghdr.msg_namelen as _; + let msghdr_control_len = msghdr.msg_controllen as _; + + // Check total length upfront, so that further logic here + // below can safely use unchecked/saturating math. + let length_overflow = Some(Self::DATA_START) + .and_then(|acc| acc.checked_add(msghdr_name_len)) + .and_then(|acc| acc.checked_add(msghdr_control_len)) + .and_then(|acc| acc.checked_add(header.payloadlen as usize)) + .map(|total_len| total_len > buffer.len()) + .unwrap_or(true); + if length_overflow { + return Err(()); + } + + let (name_data, control_start) = { + let name_start = Self::DATA_START; + let name_size = usize::min(header.namelen as usize, msghdr_name_len); + let name_data_end = name_start.saturating_add(name_size); + let name_data = &buffer[name_start..name_data_end]; + let name_field_end = name_start.saturating_add(msghdr_name_len); + (name_data, name_field_end) + }; + let (control_data, payload_start) = { + let control_size = usize::min(header.controllen as usize, msghdr_control_len); + let control_data_end = control_start.saturating_add(control_size); + let control_data = &buffer[control_start..control_data_end]; + let control_field_end = control_start.saturating_add(msghdr_control_len); + (control_data, control_field_end) + }; + let payload_data = { + let payload_data_end = payload_start.saturating_add(header.payloadlen as usize); + &buffer[payload_start..payload_data_end] + }; + + Ok(Self { + header, 
+ msghdr_name_len, + msghdr_control_len, + name_data, + control_data, + payload_data, + }) + } + + /// Return the length of the incoming `name` data. + /// + /// This may be larger than the size of the content returned by + /// `name_data()`, if the kernel could not fit all the incoming + /// data in the provided buffer size. In that case, name data in + /// the result buffer gets truncated. + pub fn incoming_name_len(&self) -> u32 { + self.header.namelen + } + + /// Return whether the incoming name data was larger than the provided limit/buffer. + /// + /// When `true`, data returned by `name_data()` is truncated and + /// incomplete. + pub fn is_name_data_truncated(&self) -> bool { + self.header.namelen as usize > self.msghdr_name_len + } + + /// Message name data, with the same semantics as `msghdr.msg_name`. + pub fn name_data(&self) -> &[u8] { + self.name_data + } + + /// Return the length of the incoming `control` data. + /// + /// This may be larger than the size of the content returned by + /// `control_data()`, if the kernel could not fit all the incoming + /// data in the provided buffer size. In that case, control data in + /// the result buffer gets truncated. + pub fn incoming_control_len(&self) -> u32 { + self.header.controllen + } + + /// Return whether the incoming control data was larger than the provided limit/buffer. + /// + /// When `true`, data returned by `control_data()` is truncated and + /// incomplete. + pub fn is_control_data_truncated(&self) -> bool { + self.header.controllen as usize > self.msghdr_control_len + } + + /// Message control data, with the same semantics as `msghdr.msg_control`. + pub fn control_data(&self) -> &[u8] { + self.control_data + } + + /// Return whether the incoming payload was larger than the provided limit/buffer. + /// + /// When `true`, data returned by `payload_data()` is truncated and + /// incomplete. 
+ pub fn is_payload_truncated(&self) -> bool { + self.header.flags & (libc::MSG_TRUNC as u32) != 0 + } + + /// Message payload, as buffered by the kernel. + pub fn payload_data(&self) -> &[u8] { + self.payload_data + } + + /// Message flags, with the same semantics as `msghdr.msg_flags`. + pub fn flags(&self) -> u32 { + self.header.flags + } +}