From 4a230bba746bafb0c152ee08a759990e6ed6a008 Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Sun, 17 Nov 2024 21:26:15 +0200 Subject: [PATCH] Support ranges in `<[T]>::get_many_mut()` I implemented that with a separate trait and not within `SliceIndex`, because doing that via `SliceIndex` requires adding support for range types that are (almost) always overlapping e.g. `RangeFrom`, and also adding fake support code for `impl SliceIndex`. An inconvenience that I ran into was that slice indexing takes the index by value, but I only have it by reference. I could change slice indexing to take by ref, but this is pretty much the hottest code ever so I'm afraid to touch it. Instead I added a requirement for `Clone` (which all index types implement anyway) and cloned. This is an internal requirement the user won't see and the clone should always be optimized away. I also implemented `Clone`, `PartialEq` and `Eq` for the error type, since I noticed it does not do that when writing the tests and other errors in std seem to implement them. I didn't implement `Copy` because maybe we will want to put something non-`Copy` there. --- library/core/src/slice/mod.rs | 198 ++++++++++++++++++++++++++++++---- library/core/tests/slice.rs | 70 +++++++++++- 2 files changed, 249 insertions(+), 19 deletions(-) diff --git a/library/core/src/slice/mod.rs b/library/core/src/slice/mod.rs index ef52cc441268a..2fca2ca05efa2 100644 --- a/library/core/src/slice/mod.rs +++ b/library/core/src/slice/mod.rs @@ -10,10 +10,10 @@ use crate::cmp::Ordering::{self, Equal, Greater, Less}; use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub}; use crate::mem::{self, SizedTypeProperties}; use crate::num::NonZero; -use crate::ops::{Bound, OneSidedRange, Range, RangeBounds}; +use crate::ops::{Bound, OneSidedRange, Range, RangeBounds, RangeInclusive}; use crate::simd::{self, Simd}; use crate::ub_checks::assert_unsafe_precondition; -use crate::{fmt, hint, ptr, slice}; +use crate::{fmt, hint, ptr, range, slice}; #[unstable( feature = "slice_internals", @@ -4467,6 +4467,12 @@ impl [T] { /// Returns mutable references to many indices at once, without doing any checks. /// + /// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note + /// that this method takes an array, so all indices must be of the same type. + /// If passed an array of `usize`s this method gives back an array of mutable references + /// to single elements, while if passed an array of ranges it gives back an array of + /// mutable references to slices. + /// /// For a safe alternative see [`get_many_mut`]. /// /// # Safety @@ -4487,30 +4493,49 @@ impl [T] { /// *b *= 100; /// } /// assert_eq!(x, &[10, 2, 400]); + /// + /// unsafe { + /// let [a, b] = x.get_many_unchecked_mut([0..1, 1..3]); + /// a[0] = 8; + /// b[0] = 88; + /// b[1] = 888; + /// } + /// assert_eq!(x, &[8, 88, 888]); + /// + /// unsafe { + /// let [a, b] = x.get_many_unchecked_mut([1..=2, 0..=0]); + /// a[0] = 11; + /// a[1] = 111; + /// b[0] = 1; + /// } + /// assert_eq!(x, &[1, 11, 111]); /// ``` /// /// [`get_many_mut`]: slice::get_many_mut /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html #[unstable(feature = "get_many_mut", issue = "104642")] #[inline] - pub unsafe fn get_many_unchecked_mut( + pub unsafe fn get_many_unchecked_mut( &mut self, - indices: [usize; N], - ) -> [&mut T; N] { + indices: [I; N], + ) -> [&mut I::Output; N] + where + I: GetManyMutIndex + SliceIndex, + { // NB: This implementation is written as it is because any variation of // `indices.map(|i| self.get_unchecked_mut(i))` would make miri unhappy, // or generate worse code otherwise. This is also why we need to go // through a raw pointer here. let slice: *mut [T] = self; - let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit(); + let mut arr: mem::MaybeUninit<[&mut I::Output; N]> = mem::MaybeUninit::uninit(); let arr_ptr = arr.as_mut_ptr(); // SAFETY: We expect `indices` to contain disjunct values that are // in bounds of `self`. unsafe { for i in 0..N { - let idx = *indices.get_unchecked(i); - *(*arr_ptr).get_unchecked_mut(i) = &mut *slice.get_unchecked_mut(idx); + let idx = indices.get_unchecked(i).clone(); + arr_ptr.cast::<&mut I::Output>().add(i).write(&mut *slice.get_unchecked_mut(idx)); } arr.assume_init() } @@ -4518,8 +4543,18 @@ impl [T] { /// Returns mutable references to many indices at once. /// - /// Returns an error if any index is out-of-bounds, or if the same index was - /// passed more than once. + /// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note + /// that this method takes an array, so all indices must be of the same type. + /// If passed an array of `usize`s this method gives back an array of mutable references + /// to single elements, while if passed an array of ranges it gives back an array of + /// mutable references to slices. + /// + /// Returns an error if any index is out-of-bounds, or if there are overlapping indices. + /// An empty range is not considered to overlap if it is located at the beginning or at + /// the end of another range, but is considered to overlap if it is located in the middle. + /// + /// This method does a O(n^2) check to check that there are no overlapping indices, so be careful + /// when passing many indices. /// /// # Examples /// @@ -4532,13 +4567,30 @@ impl [T] { /// *b = 612; /// } /// assert_eq!(v, &[413, 2, 612]); + /// + /// if let Ok([a, b]) = v.get_many_mut([0..1, 1..3]) { + /// a[0] = 8; + /// b[0] = 88; + /// b[1] = 888; + /// } + /// assert_eq!(v, &[8, 88, 888]); + /// + /// if let Ok([a, b]) = v.get_many_mut([1..=2, 0..=0]) { + /// a[0] = 11; + /// a[1] = 111; + /// b[0] = 1; + /// } + /// assert_eq!(v, &[1, 11, 111]); /// ``` #[unstable(feature = "get_many_mut", issue = "104642")] #[inline] - pub fn get_many_mut( + pub fn get_many_mut( &mut self, - indices: [usize; N], - ) -> Result<[&mut T; N], GetManyMutError> { + indices: [I; N], + ) -> Result<[&mut I::Output; N], GetManyMutError> + where + I: GetManyMutIndex + SliceIndex, + { if !get_many_check_valid(&indices, self.len()) { return Err(GetManyMutError { _private: () }); } @@ -4883,14 +4935,15 @@ impl SlicePattern for [T; N] { /// /// This will do `binomial(N + 1, 2) = N * (N + 1) / 2 = 0, 1, 3, 6, 10, ..` /// comparison operations. -fn get_many_check_valid(indices: &[usize; N], len: usize) -> bool { +#[inline] +fn get_many_check_valid(indices: &[I; N], len: usize) -> bool { // NB: The optimizer should inline the loops into a sequence // of instructions without additional branching. let mut valid = true; - for (i, &idx) in indices.iter().enumerate() { - valid &= idx < len; - for &idx2 in &indices[..i] { - valid &= idx != idx2; + for (i, idx) in indices.iter().enumerate() { + valid &= idx.is_in_bounds(len); + for idx2 in &indices[..i] { + valid &= !idx.is_overlapping(idx2); } } valid @@ -4914,6 +4967,7 @@ fn get_many_check_valid(indices: &[usize; N], len: usize) -> boo #[unstable(feature = "get_many_mut", issue = "104642")] // NB: The N here is there to be forward-compatible with adding more details // to the error type at a later point +#[derive(Clone, PartialEq, Eq)] pub struct GetManyMutError { _private: (), } @@ -4931,3 +4985,111 @@ impl fmt::Display for GetManyMutError { fmt::Display::fmt("an index is out of bounds or appeared multiple times in the array", f) } } + +mod private_get_many_mut_index { + use super::{Range, RangeInclusive, range}; + + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + pub trait Sealed {} + + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + impl Sealed for usize {} + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + impl Sealed for Range {} + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + impl Sealed for RangeInclusive {} + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + impl Sealed for range::Range {} + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + impl Sealed for range::RangeInclusive {} +} + +/// A helper trait for `<[T]>::get_many_mut()`. +/// +/// # Safety +/// +/// If `is_in_bounds()` returns `true` and `is_overlapping()` returns `false`, +/// it must be safe to index the slice with the indices. +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +pub unsafe trait GetManyMutIndex: Clone + private_get_many_mut_index::Sealed { + /// Returns `true` if `self` is in bounds for `len` slice elements. + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + fn is_in_bounds(&self, len: usize) -> bool; + + /// Returns `true` if `self` overlaps with `other`. + /// + /// Note that we don't consider zero-length ranges to overlap at the beginning or the end, + /// but do consider them to overlap in the middle. + #[unstable(feature = "get_many_mut_helpers", issue = "none")] + fn is_overlapping(&self, other: &Self) -> bool; +} + +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly. +unsafe impl GetManyMutIndex for usize { + #[inline] + fn is_in_bounds(&self, len: usize) -> bool { + *self < len + } + + #[inline] + fn is_overlapping(&self, other: &Self) -> bool { + *self == *other + } +} + +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly. +unsafe impl GetManyMutIndex for Range { + #[inline] + fn is_in_bounds(&self, len: usize) -> bool { + (self.start <= self.end) & (self.end <= len) + } + + #[inline] + fn is_overlapping(&self, other: &Self) -> bool { + (self.start < other.end) & (other.start < self.end) + } +} + +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly. +unsafe impl GetManyMutIndex for RangeInclusive { + #[inline] + fn is_in_bounds(&self, len: usize) -> bool { + (self.start <= self.end) & (self.end < len) + } + + #[inline] + fn is_overlapping(&self, other: &Self) -> bool { + (self.start <= other.end) & (other.start <= self.end) + } +} + +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly. +unsafe impl GetManyMutIndex for range::Range { + #[inline] + fn is_in_bounds(&self, len: usize) -> bool { + Range::from(*self).is_in_bounds(len) + } + + #[inline] + fn is_overlapping(&self, other: &Self) -> bool { + Range::from(*self).is_overlapping(&Range::from(*other)) + } +} + +#[unstable(feature = "get_many_mut_helpers", issue = "none")] +// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly. +unsafe impl GetManyMutIndex for range::RangeInclusive { + #[inline] + fn is_in_bounds(&self, len: usize) -> bool { + RangeInclusive::from(*self).is_in_bounds(len) + } + + #[inline] + fn is_overlapping(&self, other: &Self) -> bool { + RangeInclusive::from(*self).is_overlapping(&RangeInclusive::from(*other)) + } +} diff --git a/library/core/tests/slice.rs b/library/core/tests/slice.rs index 9ae2bcc852649..510dd4967c961 100644 --- a/library/core/tests/slice.rs +++ b/library/core/tests/slice.rs @@ -2,6 +2,7 @@ use core::cell::Cell; use core::cmp::Ordering; use core::mem::MaybeUninit; use core::num::NonZero; +use core::ops::{Range, RangeInclusive}; use core::slice; #[test] @@ -2553,6 +2554,14 @@ fn test_get_many_mut_normal_2() { *a += 10; *b += 100; assert_eq!(v, vec![101, 2, 3, 14, 5]); + + let [a, b] = v.get_many_mut([0..=1, 2..=2]).unwrap(); + assert_eq!(a, &mut [101, 2][..]); + assert_eq!(b, &mut [3][..]); + a[0] += 10; + a[1] += 20; + b[0] += 100; + assert_eq!(v, vec![111, 22, 103, 14, 5]); } #[test] @@ -2563,12 +2572,23 @@ fn test_get_many_mut_normal_3() { *b += 100; *c += 1000; assert_eq!(v, vec![11, 2, 1003, 4, 105]); + + let [a, b, c] = v.get_many_mut([0..1, 4..5, 1..4]).unwrap(); + assert_eq!(a, &mut [11][..]); + assert_eq!(b, &mut [105][..]); + assert_eq!(c, &mut [2, 1003, 4][..]); + a[0] += 10; + b[0] += 100; + c[0] += 1000; + assert_eq!(v, vec![21, 1002, 1003, 4, 205]); } #[test] fn test_get_many_mut_empty() { let mut v = vec![1, 2, 3, 4, 5]; - let [] = v.get_many_mut([]).unwrap(); + let [] = v.get_many_mut::([]).unwrap(); + let [] = v.get_many_mut::, 0>([]).unwrap(); + let [] = v.get_many_mut::, 0>([]).unwrap(); assert_eq!(v, vec![1, 2, 3, 4, 5]); } @@ -2606,6 +2626,54 @@ fn test_get_many_mut_duplicate() { assert!(v.get_many_mut([1, 3, 3, 4]).is_err()); } +#[test] +fn test_get_many_mut_range_oob() { + let mut v = vec![1, 2, 3, 4, 5]; + assert!(v.get_many_mut([0..6]).is_err()); + assert!(v.get_many_mut([5..6]).is_err()); + assert!(v.get_many_mut([6..6]).is_err()); + assert!(v.get_many_mut([0..=5]).is_err()); + assert!(v.get_many_mut([0..=6]).is_err()); + assert!(v.get_many_mut([5..=5]).is_err()); +} + +#[test] +fn test_get_many_mut_range_overlapping() { + let mut v = vec![1, 2, 3, 4, 5]; + assert!(v.get_many_mut([0..1, 0..2]).is_err()); + assert!(v.get_many_mut([0..1, 1..2, 0..1]).is_err()); + assert!(v.get_many_mut([0..3, 1..1]).is_err()); + assert!(v.get_many_mut([0..3, 1..2]).is_err()); + assert!(v.get_many_mut([0..=0, 2..=2, 0..=1]).is_err()); + assert!(v.get_many_mut([0..=4, 0..=0]).is_err()); + assert!(v.get_many_mut([4..=4, 0..=0, 3..=4]).is_err()); +} + +#[test] +fn test_get_many_mut_range_empty_at_edge() { + let mut v = vec![1, 2, 3, 4, 5]; + assert_eq!( + v.get_many_mut([0..0, 0..5, 5..5]), + Ok([&mut [][..], &mut [1, 2, 3, 4, 5], &mut []]), + ); + assert_eq!( + v.get_many_mut([0..0, 0..1, 1..1, 1..2, 2..2, 2..3, 3..3, 3..4, 4..4, 4..5, 5..5]), + Ok([ + &mut [][..], + &mut [1], + &mut [], + &mut [2], + &mut [], + &mut [3], + &mut [], + &mut [4], + &mut [], + &mut [5], + &mut [], + ]), + ); +} + #[test] fn test_slice_from_raw_parts_in_const() { static FANCY: i32 = 4;