Skip to content

Add support for Packet MMAP #1013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
libc = { version = "0.2.153", default-features = false, features = ["extra_traits"], optional = true }

@@ -53,7 +53,7 @@ libc = { version = "0.2.153", default-features = false, features = ["extra_trait
# Some syscalls do not have libc wrappers, such as in `io_uring`. For these,
# the libc backend uses the linux-raw-sys ABI and `libc::syscall`.
[target.'cfg(all(any(target_os = "android", target_os = "linux"), any(rustix_use_libc, miri, not(all(target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64")))))))'.dependencies]
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "ioctl", "no_std"] }
linux-raw-sys = { version = "0.6.4", default-features = false, features = ["general", "ioctl", "no_std"] }

# For the libc backend on Windows, use the Winsock API in windows-sys.
[target.'cfg(windows)'.dependencies.windows-sys]
@@ -141,7 +141,7 @@ io_uring = ["event", "fs", "net", "linux-raw-sys/io_uring"]
mount = []

# Enable `rustix::net::*`.
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/xdp"]
net = ["linux-raw-sys/net", "linux-raw-sys/netlink", "linux-raw-sys/if_ether", "linux-raw-sys/if_packet", "linux-raw-sys/xdp"]

# Enable `rustix::thread::*`.
thread = ["linux-raw-sys/prctl"]
366 changes: 366 additions & 0 deletions examples/packet/inner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,366 @@
use rustix::event::{poll, PollFd, PollFlags};
use rustix::fd::OwnedFd;
use rustix::mm::{mmap, munmap, MapFlags, ProtFlags};
use rustix::net::{
bind_link, eth,
netdevice::name_to_index,
packet::{PacketHeader2, PacketReq, PacketReqAny, PacketStatus, SocketAddrLink},
send, socket_with,
sockopt::{set_packet_rx_ring, set_packet_tx_ring, set_packet_version, PacketVersion},
AddressFamily, SendFlags, SocketFlags, SocketType,
};
use std::{cell::Cell, collections::VecDeque, env, ffi::c_void, io, ptr, slice, str};

/// An `AF_PACKET` socket together with its memory-mapped RX and TX rings.
///
/// The ring pointers live in `Cell`s so that `reader()` / `writer()` can
/// hand them out through a shared reference and get them back on drop.
#[derive(Debug)]
pub struct Socket {
/// The packet socket the rings belong to.
fd: OwnedFd,
/// Size in bytes of one ring block, as passed to `Socket::new`.
block_size: usize,
/// Number of blocks in each ring, as passed to `Socket::new`.
block_count: usize,
/// Size in bytes of one frame, as passed to `Socket::new`.
frame_size: usize,
/// Number of frames in each ring, computed in `Socket::new`.
frame_count: usize,
/// Base of the RX ring (and of the whole mapping); null while a `Reader`
/// holds it.
rx: Cell<*mut c_void>,
/// Base of the TX ring; null while a `Writer` holds it.
tx: Cell<*mut c_void>,
}

impl Socket {
    /// Opens a raw `AF_PACKET` socket bound to the interface `name` and maps
    /// its RX and TX rings into memory.
    ///
    /// Both rings use the same geometry: `block_count` blocks of `block_size`
    /// bytes each, divided into frames of `frame_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error if any dimension is zero, if
    /// `block_size` is not a multiple of `frame_size`, if the ring size
    /// overflows `usize`, or if a dimension does not fit the `u32` fields of
    /// `PacketReq`. Errors from the socket, `setsockopt`, `bind` and `mmap`
    /// calls are propagated unchanged.
    fn new(
        name: &str,
        block_size: usize,
        block_count: usize,
        frame_size: usize,
    ) -> io::Result<Self> {
        let invalid = |msg: &'static str| io::Error::new(io::ErrorKind::InvalidInput, msg);

        if block_size == 0 || block_count == 0 || frame_size == 0 {
            return Err(invalid("ring dimensions must be non-zero"));
        }
        // The iterators locate frame `i` at `i * frame_size`, and the kernel
        // expects `frame_nr == (block_size / frame_size) * block_nr`. Both
        // only agree with `ring_len / frame_size` when frames tile a block
        // exactly.
        if block_size % frame_size != 0 {
            return Err(invalid("block size must be a multiple of the frame size"));
        }

        let ring_len = block_size
            .checked_mul(block_count)
            .ok_or_else(|| invalid("ring size overflows usize"))?;
        // RX and TX rings are mapped back to back in a single mapping.
        let map_len = ring_len
            .checked_mul(2)
            .ok_or_else(|| invalid("ring size overflows usize"))?;
        let frame_count = ring_len / frame_size;

        // `PacketReq` stores every dimension as `u32`; refuse to truncate.
        let as_u32 = |value: usize| -> io::Result<u32> {
            if value <= u32::MAX as usize {
                Ok(value as u32)
            } else {
                Err(invalid("ring dimension does not fit in u32"))
            }
        };

        let fd = socket_with(
            AddressFamily::PACKET,
            SocketType::RAW,
            SocketFlags::empty(),
            None,
        )?;

        let index = name_to_index(&fd, name)?;

        set_packet_version(&fd, PacketVersion::V2)?;

        let req = PacketReqAny::V2(PacketReq {
            block_size: as_u32(block_size)?,
            block_nr: as_u32(block_count)?,
            frame_size: as_u32(frame_size)?,
            frame_nr: as_u32(frame_count)?,
        });
        set_packet_rx_ring(&fd, &req)?;
        set_packet_tx_ring(&fd, &req)?;

        let addr = SocketAddrLink::new(eth::ALL, index);
        bind_link(&fd, &addr)?;

        // SAFETY: a fresh shared mapping of the socket is requested at a
        // kernel-chosen address; no existing Rust memory is affected.
        let rx = unsafe {
            mmap(
                ptr::null_mut(),
                map_len,
                ProtFlags::READ | ProtFlags::WRITE,
                MapFlags::SHARED,
                &fd,
                0,
            )
        }?;
        // SAFETY: the mapping is `2 * ring_len` bytes long, so the TX ring
        // base at offset `ring_len` is still inside it.
        let tx = unsafe { rx.add(ring_len) };

        Ok(Self {
            fd,
            block_size,
            block_count,
            frame_size,
            frame_count,
            rx: Cell::new(rx),
            tx: Cell::new(tx),
        })
    }

    /// Returns a reader object for receiving packets.
    ///
    /// # Panics
    ///
    /// Panics if a `Reader` for this socket is already alive, because the RX
    /// ring pointer has been taken and not yet given back.
    pub fn reader(&self) -> Reader<'_> {
        assert!(!self.rx.get().is_null());
        Reader {
            socket: self,
            // Take ring pointer.
            ring: self.rx.replace(ptr::null_mut()),
        }
    }

    /// Returns a writer object for transmitting packets.
    ///
    /// # Panics
    ///
    /// Panics if a `Writer` for this socket is already alive, because the TX
    /// ring pointer has been taken and not yet given back.
    pub fn writer(&self) -> Writer<'_> {
        assert!(!self.tx.get().is_null());
        Writer {
            socket: self,
            // Take ring pointer.
            ring: self.tx.replace(ptr::null_mut()),
        }
    }

    /// Flushes the transmit buffer.
    ///
    /// Issues a `send` with an empty buffer on the socket.
    pub fn flush(&self) -> io::Result<()> {
        send(&self.fd, &[], SendFlags::empty())?;
        Ok(())
    }
}

impl Drop for Socket {
    /// Unmaps the combined RX/TX mapping; the result of `munmap` is ignored.
    fn drop(&mut self) {
        let base = self.rx.get();
        // Both ring pointers must have been handed back by their
        // `Reader` / `Writer` before the socket goes away.
        debug_assert!(!base.is_null());
        debug_assert!(!self.tx.get().is_null());

        let mapped_len = self.block_size * self.block_count * 2;
        // SAFETY: `base` and `mapped_len` describe the mapping created in
        // `Socket::new`; nothing can still borrow from it once the socket is
        // being dropped.
        unsafe {
            let _ = munmap(base, mapped_len);
        }
    }
}

/// A received frame borrowed from the RX ring.
///
/// Dropping the packet resets the frame's status to the empty flag set.
/// NOTE(review): presumably this is what hands the frame back to the kernel;
/// confirm against the Packet MMAP documentation.
#[derive(Debug)]
pub struct Packet<'r> {
    /// Header of the frame inside the mapped ring.
    header: &'r mut PacketHeader2,
}

impl<'r> Packet<'r> {
    /// Returns the frame's payload: `header.len` bytes starting at the
    /// pointer reported by `payload_rx`.
    pub fn payload(&self) -> &[u8] {
        let byte_len = self.header.len as usize;
        let data = self.header.payload_rx();
        // SAFETY (assumed): relies on `payload_rx` pointing at `len` readable
        // bytes of this frame; that contract lives in `PacketHeader2`.
        unsafe { slice::from_raw_parts(data, byte_len) }
    }
}

impl Drop for Packet<'_> {
    fn drop(&mut self) {
        let released = PacketStatus::empty();
        self.header.status = released;
    }
}

/// TODO
#[derive(Debug)]
pub struct Slot<'w> {
header: &'w mut PacketHeader2,
}

impl<'w> Slot<'w> {
pub fn write(&mut self, payload: &[u8]) {
let ptr = self.header.payload_tx();
// TODO verify length
let len = payload.len();
unsafe {
ptr.copy_from_nonoverlapping(payload.as_ptr(), len);
self.header.len = len as u32;
}
}
}

impl<'w> Drop for Slot<'w> {
fn drop(&mut self) {
self.header.status = PacketStatus::SEND_REQUEST;
}
}

/// A reader object for receiving packets.
///
/// Holds the RX ring pointer taken from the socket and gives it back on drop.
#[derive(Debug)]
pub struct Reader<'s> {
    socket: &'s Socket,
    ring: *mut c_void, // Owned
}

impl<'s> Reader<'s> {
    /// Returns an iterator over received packets.
    /// The iterator blocks until at least one packet is received.
    ///
    /// # Lifetimes
    ///
    /// - `'s`: The lifetime of the socket.
    /// - `'r`: The lifetime of the received packets.
    ///
    /// # Panics
    ///
    /// Panics if `poll` reports anything other than exactly one ready
    /// descriptor.
    pub fn wait<'r>(&'r mut self) -> io::Result<ReadIter<'s, 'r>>
    where
        's: 'r,
    {
        let interest = PollFlags::IN | PollFlags::RDNORM | PollFlags::ERR;
        let mut fds = [PollFd::new(&self.socket.fd, interest)];
        // A timeout of -1 is passed through to `poll`.
        let ready = poll(&mut fds, -1)?;
        assert_eq!(ready, 1);

        Ok(ReadIter {
            reader: self,
            index: 0,
        })
    }
}

impl Drop for Reader<'_> {
    fn drop(&mut self) {
        // Give back ring pointer.
        let ring = self.ring;
        self.socket.rx.set(ring);
    }
}

/// A writer object for transmitting packets.
///
/// Holds the TX ring pointer taken from the socket and gives it back on drop.
#[derive(Debug)]
pub struct Writer<'s> {
    socket: &'s Socket,
    ring: *mut c_void, // Owned
}

impl<'s> Writer<'s> {
    /// Returns an iterator over available slots for transmitting packets.
    /// The iterator blocks until at least one slot is available.
    ///
    /// # Lifetimes
    ///
    /// - `'s`: The lifetime of the socket.
    /// - `'w`: The lifetime of the slots.
    ///
    /// # Panics
    ///
    /// Panics if `poll` reports anything other than exactly one ready
    /// descriptor.
    pub fn wait<'w>(&'w mut self) -> io::Result<WriteIter<'s, 'w>>
    where
        's: 'w,
    {
        let interest = PollFlags::OUT | PollFlags::WRNORM | PollFlags::ERR;
        let mut fds = [PollFd::new(&self.socket.fd, interest)];
        // A timeout of -1 is passed through to `poll`.
        let ready = poll(&mut fds, -1)?;
        assert_eq!(ready, 1);

        Ok(WriteIter {
            writer: self,
            index: 0,
        })
    }
}

impl Drop for Writer<'_> {
    fn drop(&mut self) {
        // Give back ring pointer.
        let ring = self.ring;
        self.socket.tx.set(ring);
    }
}

/// An iterator over received packets.
///
/// Walks the RX ring from the first frame and yields every frame for which
/// `PacketHeader2::from_rx_ptr` returns a header.
#[derive(Debug)]
pub struct ReadIter<'s, 'r> {
    reader: &'r mut Reader<'s>,
    /// Index of the next frame to inspect.
    index: usize,
}

impl<'s, 'r> Iterator for ReadIter<'s, 'r> {
    type Item = Packet<'r>;

    fn next(&mut self) -> Option<Self::Item> {
        let socket = self.reader.socket;
        let ring = self.reader.ring;

        loop {
            if self.index >= socket.frame_count {
                return None;
            }
            let offset = self.index * socket.frame_size;
            self.index += 1;

            // SAFETY: `index < frame_count`, so the offset stays inside the
            // ring that `Socket::new` mapped.
            let frame = unsafe { ring.add(offset) };
            // SAFETY (assumed): `frame` is the start of a frame in the RX
            // ring, which is what `from_rx_ptr` is given here.
            if let Some(header) = unsafe { PacketHeader2::from_rx_ptr(frame) } {
                return Some(Packet { header });
            }
        }
    }
}

/// An iterator over available slots for transmitting packets.
///
/// Walks the TX ring from the first frame and yields every frame for which
/// `PacketHeader2::from_tx_ptr` returns a header.
#[derive(Debug)]
pub struct WriteIter<'s, 'w> {
    writer: &'w mut Writer<'s>,
    /// Index of the next frame to inspect.
    index: usize,
}

impl<'s, 'w> Iterator for WriteIter<'s, 'w> {
    type Item = Slot<'w>;

    fn next(&mut self) -> Option<Self::Item> {
        let socket = self.writer.socket;
        let ring = self.writer.ring;

        loop {
            if self.index >= socket.frame_count {
                return None;
            }
            let offset = self.index * socket.frame_size;
            self.index += 1;

            // SAFETY: `index < frame_count`, so the offset stays inside the
            // ring that `Socket::new` mapped.
            let frame = unsafe { ring.add(offset) };
            // SAFETY (assumed): `frame` is the start of a frame in the TX
            // ring, which is what `from_tx_ptr` is given here.
            if let Some(header) = unsafe { PacketHeader2::from_tx_ptr(frame) } {
                return Some(Slot { header });
            }
        }
    }
}

// ECHO server
//
// Receives frames, swaps payload bytes 14 and 15, and queues the result for
// transmission until `count` frames have been written.
fn server(socket: Socket, mut count: usize) -> io::Result<()> {
    let mut reader = socket.reader();
    let mut writer = socket.writer();

    while count > 0 {
        // Collect everything that is currently readable.
        let pending: VecDeque<_> = reader.wait()?.collect();

        for packet in pending {
            // NOTE(review): every free TX slot (up to `count`) receives a
            // copy of this packet, so one received frame may be echoed
            // several times. Confirm this is intended.
            for mut slot in writer.wait()?.take(count) {
                let mut frame = packet.payload().to_vec();
                // Only frames with EtherType 0x0800 are expected.
                assert_eq!(frame[12..14], [0x08, 0x00]);
                frame.swap(14, 15);

                slot.write(&frame);
                drop(slot);
                count -= 1;
            }
            drop(packet);
        }

        socket.flush()?;
    }

    Ok(())
}

// ECHO client
//
// Fills free TX slots with a fixed test frame until `count` frames have been
// written, and checks that every frame received afterwards carries the two
// payload bytes in swapped order.
fn client(socket: Socket, mut count: usize) -> io::Result<()> {
    const FRAME: [u8; 16] = [
        0xff, 0xff, 0xff, 0xff, 0xff, 0xff, // Destination
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Source
        0x08, 0x00, // Type (IPv4, but not really)
        0x13, 0x37, // Payload (some value)
    ];

    let mut reader = socket.reader();
    let mut writer = socket.writer();

    while count > 0 {
        for mut slot in writer.wait()?.take(count) {
            slot.write(&FRAME);
            drop(slot);
            count -= 1;
        }

        socket.flush()?;

        for packet in reader.wait()? {
            assert_eq!(packet.payload()[14..16], [0x37, 0x13]);
        }
    }

    Ok(())
}

pub fn main() -> io::Result<()> {
let mut args = env::args().skip(1);
let name = args.next().expect("name");
let mode = args.next().expect("mode");
let count = args.next().expect("count");

let socket = Socket::new(&name, 4096, 4, 2048)?;
let count = count.parse().unwrap();

match mode.as_str() {
"server" => server(socket, count),
"client" => client(socket, count),
_ => panic!("invalid mode"),
}
}
32 changes: 32 additions & 0 deletions examples/packet/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//! Packet MMAP.

// The real example lives in `inner.rs` and is only compiled when every
// feature it needs is enabled on Linux.
#[cfg(all(
    feature = "mm",
    feature = "net",
    feature = "event",
    feature = "std",
    target_os = "linux"
))]
mod inner;

#[cfg(all(
    feature = "mm",
    feature = "net",
    feature = "event",
    feature = "std",
    target_os = "linux"
))]
fn main() -> std::io::Result<()> {
    inner::main()
}

// Fallback for every configuration in which the example cannot be built.
#[cfg(not(all(
    feature = "mm",
    feature = "net",
    feature = "event",
    feature = "std",
    target_os = "linux"
)))]
fn main() -> Result<(), &'static str> {
    Err("This example requires --features=mm,net,event,std and is only supported on Linux.")
}
21 changes: 11 additions & 10 deletions src/backend/linux_raw/c.rs
Original file line number Diff line number Diff line change
@@ -55,13 +55,14 @@ pub(crate) use linux_raw_sys::{
cmsg_macros::*,
general::{O_CLOEXEC as SOCK_CLOEXEC, O_NONBLOCK as SOCK_NONBLOCK},
if_ether::*,
if_packet::*,
net::{
linger, msghdr, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_un, socklen_t, AF_DECnet,
__kernel_sa_family_t as sa_family_t, __kernel_sockaddr_storage as sockaddr_storage,
cmsghdr, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq, AF_APPLETALK,
AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN, AF_ECONET,
AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY, AF_LLC,
AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
cmsghdr, ifreq, in6_addr, in_addr, ip_mreq, ip_mreq_source, ip_mreqn, ipv6_mreq,
AF_APPLETALK, AF_ASH, AF_ATMPVC, AF_ATMSVC, AF_AX25, AF_BLUETOOTH, AF_BRIDGE, AF_CAN,
AF_ECONET, AF_IEEE802154, AF_INET, AF_INET6, AF_IPX, AF_IRDA, AF_ISDN, AF_IUCV, AF_KEY,
AF_LLC, AF_NETBEUI, AF_NETLINK, AF_NETROM, AF_PACKET, AF_PHONET, AF_PPPOX, AF_RDS, AF_ROSE,
AF_RXRPC, AF_SECURITY, AF_SNA, AF_TIPC, AF_UNIX, AF_UNSPEC, AF_WANPIPE, AF_X25, AF_XDP,
IP6T_SO_ORIGINAL_DST, IPPROTO_FRAGMENT, IPPROTO_ICMPV6, IPPROTO_MH, IPPROTO_ROUTING,
IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_FREEBIND, IPV6_MULTICAST_HOPS,
@@ -71,12 +72,12 @@ pub(crate) use linux_raw_sys::{
MSG_CMSG_CLOEXEC, MSG_CONFIRM, MSG_DONTROUTE, MSG_DONTWAIT, MSG_EOR, MSG_ERRQUEUE,
MSG_MORE, MSG_NOSIGNAL, MSG_OOB, MSG_PEEK, MSG_TRUNC, MSG_WAITALL, SCM_CREDENTIALS,
SCM_RIGHTS, SHUT_RD, SHUT_RDWR, SHUT_WR, SOCK_DGRAM, SOCK_RAW, SOCK_RDM, SOCK_SEQPACKET,
SOCK_STREAM, SOL_SOCKET, SOL_XDP, SO_ACCEPTCONN, SO_BROADCAST, SO_COOKIE, SO_DOMAIN,
SO_ERROR, SO_INCOMING_CPU, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE, SO_ORIGINAL_DST,
SO_PASSCRED, SO_PROTOCOL, SO_RCVBUF, SO_RCVTIMEO_NEW, SO_RCVTIMEO_NEW as SO_RCVTIMEO,
SO_RCVTIMEO_OLD, SO_REUSEADDR, SO_REUSEPORT, SO_SNDBUF, SO_SNDTIMEO_NEW,
SO_SNDTIMEO_NEW as SO_SNDTIMEO, SO_SNDTIMEO_OLD, SO_TYPE, TCP_CONGESTION, TCP_CORK,
TCP_KEEPCNT, TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_NODELAY, TCP_QUICKACK,
SOCK_STREAM, SOL_PACKET, SOL_SOCKET, SOL_XDP, SO_ACCEPTCONN, SO_BROADCAST, SO_COOKIE,
SO_DOMAIN, SO_ERROR, SO_INCOMING_CPU, SO_KEEPALIVE, SO_LINGER, SO_OOBINLINE,
SO_ORIGINAL_DST, SO_PASSCRED, SO_PROTOCOL, SO_RCVBUF, SO_RCVTIMEO_NEW,
SO_RCVTIMEO_NEW as SO_RCVTIMEO, SO_RCVTIMEO_OLD, SO_REUSEADDR, SO_REUSEPORT, SO_SNDBUF,
SO_SNDTIMEO_NEW, SO_SNDTIMEO_NEW as SO_SNDTIMEO, SO_SNDTIMEO_OLD, SO_TYPE, TCP_CONGESTION,
TCP_CORK, TCP_KEEPCNT, TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_NODELAY, TCP_QUICKACK,
TCP_THIN_LINEAR_TIMEOUTS, TCP_USER_TIMEOUT,
},
netlink::*,
13 changes: 13 additions & 0 deletions src/backend/linux_raw/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@ use crate::backend::c;
use crate::io;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6};
use core::mem::size_of;
use core::slice;
@@ -127,6 +129,10 @@ pub(crate) unsafe fn read_sockaddr(
u32::from_be(decode.sxdp_shared_umem_fd),
)))
}
#[cfg(target_os = "linux")]
c::AF_PACKET => {
todo!();
}
_ => Err(io::Errno::NOTSUP),
}
}
@@ -216,6 +222,13 @@ pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) -
u32::from_be(decode.sxdp_shared_umem_fd),
))
}
#[cfg(target_os = "linux")]
c::AF_PACKET => {
assert!(len >= size_of::<c::sockaddr_ll>());
// SocketAddrLink and sockaddr_ll have the same layout.
let addr = &*storage.cast::<SocketAddrLink>();
SocketAddrAny::Link(*addr)
}
other => unimplemented!("{:?}", other),
}
}
60 changes: 59 additions & 1 deletion src/backend/linux_raw/net/sockopt.rs
Original file line number Diff line number Diff line change
@@ -11,7 +11,9 @@ use crate::fd::BorrowedFd;
#[cfg(feature = "alloc")]
use crate::ffi::CStr;
use crate::io;
use crate::net::sockopt::Timeout;
#[cfg(target_os = "linux")]
use crate::net::packet::{PacketReqAny, PacketStats, PacketStats3, PacketStatsAny};
use crate::net::sockopt::{PacketVersion, Timeout};
#[cfg(target_os = "linux")]
use crate::net::xdp::{XdpMmapOffsets, XdpOptionsFlags, XdpRingOffset, XdpStatistics, XdpUmemReg};
use crate::net::{
@@ -969,6 +971,62 @@ pub(crate) fn get_xdp_options(fd: BorrowedFd<'_>) -> io::Result<XdpOptionsFlags>
getsockopt(fd, c::SOL_XDP, c::XDP_OPTIONS)
}

/// Sets `PACKET_RX_RING` at level `SOL_PACKET`, passing the request struct
/// carried by whichever `PacketReqAny` variant was supplied.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn set_packet_rx_ring(fd: BorrowedFd<'_>, value: &PacketReqAny) -> io::Result<()> {
    match value {
        // V1 and V2 carry the same request type.
        PacketReqAny::V1(req) | PacketReqAny::V2(req) => {
            setsockopt(fd, c::SOL_PACKET, c::PACKET_RX_RING, *req)
        }
        PacketReqAny::V3(req3) => setsockopt(fd, c::SOL_PACKET, c::PACKET_RX_RING, *req3),
    }
}

/// Sets `PACKET_TX_RING` at level `SOL_PACKET`, passing the request struct
/// carried by whichever `PacketReqAny` variant was supplied.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn set_packet_tx_ring(fd: BorrowedFd<'_>, value: &PacketReqAny) -> io::Result<()> {
    match value {
        // V1 and V2 carry the same request type.
        PacketReqAny::V1(req) | PacketReqAny::V2(req) => {
            setsockopt(fd, c::SOL_PACKET, c::PACKET_TX_RING, *req)
        }
        PacketReqAny::V3(req3) => setsockopt(fd, c::SOL_PACKET, c::PACKET_TX_RING, *req3),
    }
}

/// Sets the `PACKET_VERSION` option at level `SOL_PACKET` to `value`.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn set_packet_version(fd: BorrowedFd<'_>, value: PacketVersion) -> io::Result<()> {
    let level = c::SOL_PACKET;
    let optname = c::PACKET_VERSION;
    setsockopt(fd, level, optname, value)
}

/// Reads the `PACKET_VERSION` option at level `SOL_PACKET`.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn get_packet_version(fd: BorrowedFd<'_>) -> io::Result<PacketVersion> {
    let level = c::SOL_PACKET;
    let optname = c::PACKET_VERSION;
    getsockopt(fd, level, optname)
}

/// Reads `PACKET_STATISTICS` at level `SOL_PACKET`.
///
/// `version` selects which statistics struct is read and which
/// `PacketStatsAny` variant wraps it: V1 and V2 read `PacketStats`, V3 reads
/// `PacketStats3`.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn get_packet_stats(
    fd: BorrowedFd<'_>,
    version: PacketVersion,
) -> io::Result<PacketStatsAny> {
    let level = c::SOL_PACKET;
    let optname = c::PACKET_STATISTICS;

    let wrapped = match version {
        PacketVersion::V1 => {
            let raw: PacketStats = getsockopt(fd, level, optname)?;
            PacketStatsAny::V1(raw)
        }
        PacketVersion::V2 => {
            let raw: PacketStats = getsockopt(fd, level, optname)?;
            PacketStatsAny::V2(raw)
        }
        PacketVersion::V3 => {
            let raw: PacketStats3 = getsockopt(fd, level, optname)?;
            PacketStatsAny::V3(raw)
        }
    };
    Ok(wrapped)
}

#[inline]
fn to_ip_mreq(multiaddr: &Ipv4Addr, interface: &Ipv4Addr) -> c::ip_mreq {
c::ip_mreq {
81 changes: 81 additions & 0 deletions src/backend/linux_raw/net/syscalls.rs
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@ use super::msghdr::{
use super::read_sockaddr::{initialize_family_to_unspec, maybe_read_sockaddr_os, read_sockaddr_os};
use super::send_recv::{RecvFlags, SendFlags};
#[cfg(target_os = "linux")]
use super::write_sockaddr::encode_sockaddr_link;
#[cfg(target_os = "linux")]
use super::write_sockaddr::encode_sockaddr_xdp;
use super::write_sockaddr::{encode_sockaddr_v4, encode_sockaddr_v6};
use crate::backend::c;
@@ -23,6 +25,8 @@ use crate::backend::conv::{
use crate::fd::{BorrowedFd, OwnedFd};
use crate::io::{self, IoSlice, IoSliceMut};
#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
use crate::net::{
AddressFamily, Protocol, RecvAncillaryBuffer, RecvMsgReturn, SendAncillaryBuffer, Shutdown,
@@ -439,6 +443,18 @@ pub(crate) fn sendmsg_xdp(
})
}

/// `sendmsg` with a link-layer (`SocketAddrLink`) destination address.
///
/// Not implemented yet: the body is `todo!()`, so every call panics.
///
/// NOTE(review): `sendmsg_any` dispatches `SocketAddrAny::Link` to this
/// function, so the panic is reachable from the public API. Implement it or
/// return an error before this lands.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn sendmsg_link(
_sockfd: BorrowedFd<'_>,
_addr: &SocketAddrLink,
_iov: &[IoSlice<'_>],
_control: &mut SendAncillaryBuffer<'_, '_, '_>,
_msg_flags: SendFlags,
) -> io::Result<usize> {
todo!()
}

#[inline]
pub(crate) fn shutdown(fd: BorrowedFd<'_>, how: Shutdown) -> io::Result<()> {
#[cfg(not(target_arch = "x86"))]
@@ -660,6 +676,45 @@ pub(crate) fn sendto_xdp(
}
}

/// Sends `buf` on `fd` to the link-layer address `addr` using `sendto`.
///
/// The address is encoded with `encode_sockaddr_link` and passed together
/// with the size of `c::sockaddr_ll` as the address length. The syscall
/// result is converted with `ret_usize`.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn sendto_link(
fd: BorrowedFd<'_>,
buf: &[u8],
flags: SendFlags,
addr: &SocketAddrLink,
) -> io::Result<usize> {
let (buf_addr, buf_len) = slice(buf);

// Every architecture except x86 issues `sendto` directly.
#[cfg(not(target_arch = "x86"))]
unsafe {
ret_usize(syscall_readonly!(
__NR_sendto,
fd,
buf_addr,
buf_len,
flags,
by_ref(&encode_sockaddr_link(addr)),
size_of::<c::sockaddr_ll, _>()
))
}
// On x86 the call goes through `socketcall` with `SYS_SENDTO` and an
// argument array.
#[cfg(target_arch = "x86")]
unsafe {
ret_usize(syscall_readonly!(
__NR_socketcall,
x86_sys(SYS_SENDTO),
slice_just_addr::<ArgReg<'_, SocketArg>, _>(&[
fd.into(),
buf_addr,
buf_len,
flags.into(),
by_ref(&encode_sockaddr_link(addr)),
size_of::<c::sockaddr_ll, _>(),
])
))
}
}

#[inline]
pub(crate) unsafe fn recv(
fd: BorrowedFd<'_>,
@@ -931,6 +986,32 @@ pub(crate) fn bind_xdp(fd: BorrowedFd<'_>, addr: &SocketAddrXdp) -> io::Result<(
}
}

/// Binds `fd` to the link-layer address `addr` using `bind`.
///
/// The address is encoded with `encode_sockaddr_link` and passed together
/// with the size of `c::sockaddr_ll` as the address length.
#[cfg(target_os = "linux")]
#[inline]
pub(crate) fn bind_link(fd: BorrowedFd<'_>, addr: &SocketAddrLink) -> io::Result<()> {
// Every architecture except x86 issues `bind` directly.
#[cfg(not(target_arch = "x86"))]
unsafe {
ret(syscall_readonly!(
__NR_bind,
fd,
by_ref(&encode_sockaddr_link(addr)),
size_of::<c::sockaddr_ll, _>()
))
}
// On x86 the call goes through `socketcall` with `SYS_BIND` and an
// argument array.
#[cfg(target_arch = "x86")]
unsafe {
ret(syscall_readonly!(
__NR_socketcall,
x86_sys(SYS_BIND),
slice_just_addr::<ArgReg<'_, SocketArg>, _>(&[
fd.into(),
by_ref(&encode_sockaddr_link(addr)),
size_of::<c::sockaddr_ll, _>(),
])
))
}
}

#[inline]
pub(crate) fn connect_v4(fd: BorrowedFd<'_>, addr: &SocketAddrV4) -> io::Result<()> {
#[cfg(not(target_arch = "x86"))]
17 changes: 17 additions & 0 deletions src/backend/linux_raw/net/write_sockaddr.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@

use crate::backend::c;
#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
use crate::net::{SocketAddrAny, SocketAddrStorage, SocketAddrUnix, SocketAddrV4, SocketAddrV6};
use core::mem::size_of;
@@ -18,6 +20,8 @@ pub(crate) unsafe fn write_sockaddr(
SocketAddrAny::Unix(unix) => write_sockaddr_unix(unix, storage),
#[cfg(target_os = "linux")]
SocketAddrAny::Xdp(xdp) => write_sockaddr_xdp(xdp, storage),
#[cfg(target_os = "linux")]
SocketAddrAny::Link(link) => write_sockaddr_link(link, storage),
}
}

@@ -80,3 +84,16 @@ unsafe fn write_sockaddr_xdp(xdp: &SocketAddrXdp, storage: *mut SocketAddrStorag
core::ptr::write(storage.cast(), encoded);
size_of::<c::sockaddr_xdp>()
}

/// Copies a `SocketAddrLink` out as a `c::sockaddr_ll` by reading the same
/// bytes through a pointer of the target type.
#[cfg(target_os = "linux")]
pub(crate) fn encode_sockaddr_link(link: &SocketAddrLink) -> c::sockaddr_ll {
    let raw: *const SocketAddrLink = link;
    // SAFETY: both types have the same memory layout
    unsafe { core::ptr::read(raw.cast::<c::sockaddr_ll>()) }
}

/// Writes `link` into `storage` as a `c::sockaddr_ll` and returns the number
/// of bytes that make up the encoded address.
///
/// # Safety
///
/// `storage` must be valid for a write of `c::sockaddr_ll`; this function
/// performs the write without checking.
#[cfg(target_os = "linux")]
unsafe fn write_sockaddr_link(link: &SocketAddrLink, storage: *mut SocketAddrStorage) -> usize {
    let dst = storage.cast::<c::sockaddr_ll>();
    dst.write(encode_sockaddr_link(link));
    size_of::<c::sockaddr_ll>()
}
2 changes: 2 additions & 0 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
@@ -18,6 +18,8 @@ mod wsa;

#[cfg(linux_kernel)]
pub mod netdevice;
#[cfg(linux_kernel)]
pub mod packet;
pub mod sockopt;

pub use crate::maybe_polyfill::net::{
542 changes: 542 additions & 0 deletions src/net/packet.rs

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions src/net/send_recv/mod.rs
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@

use crate::buffer::split_init;
#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
#[cfg(unix)]
use crate::net::SocketAddrUnix;
@@ -265,6 +267,8 @@ fn _sendto_any(
SocketAddrAny::Unix(unix) => backend::net::syscalls::sendto_unix(fd, buf, flags, unix),
#[cfg(target_os = "linux")]
SocketAddrAny::Xdp(xdp) => backend::net::syscalls::sendto_xdp(fd, buf, flags, xdp),
#[cfg(target_os = "linux")]
SocketAddrAny::Link(link) => backend::net::syscalls::sendto_link(fd, buf, flags, link),
}
}

@@ -401,3 +405,22 @@ pub fn sendto_xdp<Fd: AsFd>(
) -> io::Result<usize> {
backend::net::syscalls::sendto_xdp(fd.as_fd(), buf, flags, addr)
}

/// `sendto(fd, buf, flags, addr, sizeof(struct sockaddr_ll))`—Writes data
/// to a socket to a specific link-layer address.
///
/// Returns the `usize` produced by the backend `sendto_link` call.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/sendto.2.html
#[cfg(target_os = "linux")]
#[inline]
#[doc(alias = "sendto")]
pub fn sendto_link<Fd: AsFd>(
    fd: Fd,
    buf: &[u8],
    flags: SendFlags,
    addr: &SocketAddrLink,
) -> io::Result<usize> {
    let borrowed = fd.as_fd();
    backend::net::syscalls::sendto_link(borrowed, buf, flags, addr)
}
4 changes: 4 additions & 0 deletions src/net/send_recv/msg.rs
Original file line number Diff line number Diff line change
@@ -776,6 +776,10 @@ pub fn sendmsg_any(
Some(SocketAddrAny::Xdp(addr)) => {
backend::net::syscalls::sendmsg_xdp(socket.as_fd(), addr, iov, control, flags)
}
#[cfg(target_os = "linux")]
Some(SocketAddrAny::Link(addr)) => {
backend::net::syscalls::sendmsg_link(socket.as_fd(), addr, iov, control, flags)
}
}
}

19 changes: 19 additions & 0 deletions src/net/socket.rs
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@ use crate::net::{SocketAddr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use crate::{backend, io};
use backend::fd::{AsFd, BorrowedFd};

#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
pub use crate::net::{AddressFamily, Protocol, Shutdown, SocketFlags, SocketType};
@@ -172,6 +174,8 @@ fn _bind_any(sockfd: BorrowedFd<'_>, addr: &SocketAddrAny) -> io::Result<()> {
SocketAddrAny::Unix(unix) => backend::net::syscalls::bind_unix(sockfd, unix),
#[cfg(target_os = "linux")]
SocketAddrAny::Xdp(xdp) => backend::net::syscalls::bind_xdp(sockfd, xdp),
#[cfg(target_os = "linux")]
SocketAddrAny::Link(link) => backend::net::syscalls::bind_link(sockfd, link),
}
}

@@ -289,6 +293,19 @@ pub fn bind_xdp<Fd: AsFd>(sockfd: Fd, addr: &SocketAddrXdp) -> io::Result<()> {
backend::net::syscalls::bind_xdp(sockfd.as_fd(), addr)
}

/// `bind(sockfd, addr, sizeof(struct sockaddr_ll))`—Binds a socket to a
/// link-layer address.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/bind.2.html
#[cfg(target_os = "linux")]
#[inline]
#[doc(alias = "bind")]
pub fn bind_link<Fd: AsFd>(sockfd: Fd, addr: &SocketAddrLink) -> io::Result<()> {
backend::net::syscalls::bind_link(sockfd.as_fd(), addr)
}

/// `connect(sockfd, addr)`—Initiates a connection to an IP address.
///
/// On Windows, a non-blocking socket returns [`Errno::WOULDBLOCK`] if the
@@ -370,6 +387,8 @@ fn _connect_any(sockfd: BorrowedFd<'_>, addr: &SocketAddrAny) -> io::Result<()>
SocketAddrAny::Unix(unix) => backend::net::syscalls::connect_unix(sockfd, unix),
#[cfg(target_os = "linux")]
SocketAddrAny::Xdp(_) => Err(io::Errno::OPNOTSUPP),
#[cfg(target_os = "linux")]
SocketAddrAny::Link(_) => Err(io::Errno::OPNOTSUPP),
}
}

17 changes: 17 additions & 0 deletions src/net/socket_addr_any.rs
Original file line number Diff line number Diff line change
@@ -9,6 +9,8 @@
//! OS-specific socket address representations in memory.
#![allow(unsafe_code)]

#[cfg(target_os = "linux")]
use crate::net::packet::SocketAddrLink;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
#[cfg(unix)]
@@ -35,6 +37,9 @@ pub enum SocketAddrAny {
/// `struct sockaddr_xdp`
#[cfg(target_os = "linux")]
Xdp(SocketAddrXdp),
/// `struct sockaddr_ll`
#[cfg(target_os = "linux")]
Link(SocketAddrLink),
}

impl From<SocketAddr> for SocketAddrAny {
@@ -69,6 +74,14 @@ impl From<SocketAddrUnix> for SocketAddrAny {
}
}

#[cfg(target_os = "linux")]
impl From<SocketAddrLink> for SocketAddrAny {
#[inline]
fn from(from: SocketAddrLink) -> Self {
Self::Link(from)
}
}

impl SocketAddrAny {
/// Return the address family of this socket address.
#[inline]
@@ -80,6 +93,8 @@ impl SocketAddrAny {
Self::Unix(_) => AddressFamily::UNIX,
#[cfg(target_os = "linux")]
Self::Xdp(_) => AddressFamily::XDP,
#[cfg(target_os = "linux")]
Self::Link(_) => AddressFamily::PACKET,
}
}

@@ -117,6 +132,8 @@ impl fmt::Debug for SocketAddrAny {
Self::Unix(unix) => unix.fmt(fmt),
#[cfg(target_os = "linux")]
Self::Xdp(xdp) => xdp.fmt(fmt),
#[cfg(target_os = "linux")]
Self::Link(link) => link.fmt(fmt),
}
}
}
76 changes: 76 additions & 0 deletions src/net/sockopt.rs
Original file line number Diff line number Diff line change
@@ -143,6 +143,8 @@
#![doc(alias = "getsockopt")]
#![doc(alias = "setsockopt")]

#[cfg(linux_kernel)]
use crate::net::packet::{PacketReqAny, PacketStatsAny};
#[cfg(target_os = "linux")]
use crate::net::xdp::{XdpMmapOffsets, XdpOptionsFlags, XdpStatistics, XdpUmemReg};
#[cfg(not(any(
@@ -1472,11 +1474,85 @@ pub fn get_xdp_options<Fd: AsFd>(fd: Fd) -> io::Result<XdpOptionsFlags> {
backend::net::sockopt::get_xdp_options(fd.as_fd())
}

/// `setsockopt(fd, SOL_PACKET, PACKET_RX_RING, value)`
///
/// # References
/// - [Linux]
///
/// [Linux]: https://www.kernel.org/doc/html/next/networking/packet_mmap.html#packet-mmap-settings
#[cfg(linux_kernel)]
#[doc(alias = "PACKET_RX_RING")]
pub fn set_packet_rx_ring<Fd: AsFd>(fd: Fd, value: &PacketReqAny) -> io::Result<()> {
backend::net::sockopt::set_packet_rx_ring(fd.as_fd(), value)
}

/// `setsockopt(fd, SOL_PACKET, PACKET_TX_RING, value)`
///
/// # References
/// - [Linux]
///
/// [Linux]: https://www.kernel.org/doc/html/next/networking/packet_mmap.html#packet-mmap-settings
#[cfg(linux_kernel)]
#[doc(alias = "PACKET_TX_RING")]
pub fn set_packet_tx_ring<Fd: AsFd>(fd: Fd, value: &PacketReqAny) -> io::Result<()> {
backend::net::sockopt::set_packet_tx_ring(fd.as_fd(), value)
}

/// Packet MMAP versions for use with [`set_packet_version`].
///
/// Gated on `linux_kernel` like the functions that use it, since the
/// `tpacket_versions` constants are not expected to exist on other targets.
#[cfg(linux_kernel)]
#[repr(u32)]
#[non_exhaustive]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub enum PacketVersion {
    /// `TPACKET_V1`
    V1 = c::tpacket_versions::TPACKET_V1 as _,
    /// `TPACKET_V2`
    V2 = c::tpacket_versions::TPACKET_V2 as _,
    /// `TPACKET_V3`
    V3 = c::tpacket_versions::TPACKET_V3 as _,
}

/// `setsockopt(fd, SOL_PACKET, PACKET_VERSION, value)`
///
/// # References
/// - [Linux]
///
/// [Linux]: https://www.kernel.org/doc/html/next/networking/packet_mmap.html#what-tpacket-versions-are-available-and-when-to-use-them
#[cfg(linux_kernel)]
#[doc(alias = "PACKET_VERSION")]
pub fn set_packet_version<Fd: AsFd>(fd: Fd, value: PacketVersion) -> io::Result<()> {
    let borrowed = fd.as_fd();
    backend::net::sockopt::set_packet_version(borrowed, value)
}

/// `getsockopt(fd, SOL_PACKET, PACKET_VERSION)`
///
/// # References
/// - [Linux]
///
/// [Linux]: https://www.kernel.org/doc/html/next/networking/packet_mmap.html#what-tpacket-versions-are-available-and-when-to-use-them
#[cfg(linux_kernel)]
#[doc(alias = "PACKET_VERSION")]
pub fn get_packet_version<Fd: AsFd>(fd: Fd) -> io::Result<PacketVersion> {
    let borrowed = fd.as_fd();
    backend::net::sockopt::get_packet_version(borrowed)
}

/// `getsockopt(fd, SOL_PACKET, PACKET_STATISTICS)`
///
/// `version` selects which statistics layout is read; it is forwarded to the
/// backend unchanged.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://www.kernel.org/doc/html/next/networking/packet_mmap.html
#[cfg(linux_kernel)]
#[doc(alias = "PACKET_STATISTICS")]
pub fn get_packet_stats<Fd: AsFd>(fd: Fd, version: PacketVersion) -> io::Result<PacketStatsAny> {
    let borrowed = fd.as_fd();
    backend::net::sockopt::get_packet_stats(borrowed, version)
}

/// Checks that option types handed to the backend as `c_int` have its size.
#[test]
fn test_sizes() {
    use c::c_int;

    // Backend code needs to cast these to `c_int` so make sure that cast
    // isn't lossy.
    assert_eq_size!(Timeout, c_int);
    // `PacketVersion` is only used by the `linux_kernel`-gated packet
    // functions, so only check it there.
    #[cfg(linux_kernel)]
    assert_eq_size!(PacketVersion, c_int);
}