diff --git a/library/alloc/src/vec/extract_if.rs b/library/alloc/src/vec/extract_if.rs index 72d51e8904488..f722d89d4005a 100644 --- a/library/alloc/src/vec/extract_if.rs +++ b/library/alloc/src/vec/extract_if.rs @@ -1,4 +1,4 @@ -use core::{ptr, slice}; +use core::{ops::{Range, RangeBounds}, ptr, slice}; use super::Vec; use crate::alloc::{Allocator, Global}; @@ -14,7 +14,7 @@ use crate::alloc::{Allocator, Global}; /// #![feature(extract_if)] /// /// let mut v = vec![0, 1, 2]; -/// let iter: std::vec::ExtractIf<'_, _, _> = v.extract_if(|x| *x % 2 == 0); +/// let iter: std::vec::ExtractIf<'_, _, _> = v.extract_if(.., |x| *x % 2 == 0); /// ``` #[unstable(feature = "extract_if", reason = "recently added", issue = "43244")] #[derive(Debug)] @@ -30,6 +30,8 @@ pub struct ExtractIf< pub(super) vec: &'a mut Vec, /// The index of the item that will be inspected by the next call to `next`. pub(super) idx: usize, + /// Elements at and beyond this point will be retained. Must be equal or smaller than `old_len`. + pub(super) end: usize, /// The number of items that have been drained (removed) thus far. pub(super) del: usize, /// The original length of `vec` prior to draining. @@ -38,10 +40,21 @@ pub struct ExtractIf< pub(super) pred: F, } -impl ExtractIf<'_, T, F, A> +impl<'a, T, F, A: Allocator> ExtractIf<'a, T, F, A> where F: FnMut(&mut T) -> bool, { + pub(super) fn new>(vec: &'a mut Vec, pred: F, range: R) -> Self { + let old_len = vec.len(); + let Range { start, end } = slice::range(range, ..old_len); + + // Guard against the vec getting leaked (leak amplification) + unsafe { + vec.set_len(0); + } + ExtractIf { vec, idx: start, del: 0, end, old_len, pred } + } + /// Returns a reference to the underlying allocator. #[unstable(feature = "allocator_api", issue = "32838")] #[inline] @@ -59,7 +72,7 @@ where fn next(&mut self) -> Option { unsafe { - while self.idx < self.old_len { + while self.idx < self.end { let i = self.idx; let v = slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len); let drained = (self.pred)(&mut v[i]); @@ -82,7 +95,7 @@ where } fn size_hint(&self) -> (usize, Option) { - (0, Some(self.old_len - self.idx)) + (0, Some(self.end - self.idx)) } } diff --git a/library/alloc/src/vec/mod.rs b/library/alloc/src/vec/mod.rs index 07a1bd4932138..fa23d540961e6 100644 --- a/library/alloc/src/vec/mod.rs +++ b/library/alloc/src/vec/mod.rs @@ -3610,12 +3610,15 @@ impl Vec { Splice { drain: self.drain(range), replace_with: replace_with.into_iter() } } - /// Creates an iterator which uses a closure to determine if an element should be removed. + /// Creates an iterator which uses a closure to determine if element in the range should be removed. /// /// If the closure returns true, then the element is removed and yielded. /// If the closure returns false, the element will remain in the vector and will not be yielded /// by the iterator. /// + /// Only elements that fall in the provided range are considered for extraction, but any elements + /// after the range will still have to be moved if any element has been extracted. + /// /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating /// or the iteration short-circuits, then the remaining elements will be retained. /// Use [`retain`] with a negated predicate if you do not need the returned iterator. @@ -3625,10 +3628,12 @@ impl Vec { /// Using this method is equivalent to the following code: /// /// ``` + /// # use std::cmp::min; /// # let some_predicate = |x: &mut i32| { *x == 2 || *x == 3 || *x == 6 }; /// # let mut vec = vec![1, 2, 3, 4, 5, 6]; - /// let mut i = 0; - /// while i < vec.len() { + /// # let range = 1..4; + /// let mut i = range.start; + /// while i < min(vec.len(), range.end) { /// if some_predicate(&mut vec[i]) { /// let val = vec.remove(i); /// // your code here @@ -3643,8 +3648,12 @@ impl Vec { /// But `extract_if` is easier to use. `extract_if` is also more efficient, /// because it can backshift the elements of the array in bulk. /// - /// Note that `extract_if` also lets you mutate every element in the filter closure, - /// regardless of whether you choose to keep or remove it. + /// Note that `extract_if` also lets you mutate the elements passed to the filter closure, + /// regardless of whether you choose to keep or remove them. + /// + /// # Panics + /// + /// If `range`` is out of bounds. /// /// # Examples /// @@ -3654,25 +3663,29 @@ impl Vec { /// #![feature(extract_if)] /// let mut numbers = vec![1, 2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15]; /// - /// let evens = numbers.extract_if(|x| *x % 2 == 0).collect::>(); + /// let evens = numbers.extract_if(.., |x| *x % 2 == 0).collect::>(); /// let odds = numbers; /// /// assert_eq!(evens, vec![2, 4, 6, 8, 14]); /// assert_eq!(odds, vec![1, 3, 5, 9, 11, 13, 15]); /// ``` + /// + /// Using the range argument to only process a part of the vector: + /// + /// ``` + /// #![feature(extract_if)] + /// let mut items = vec![0, 0, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2]; + /// let ones = items.extract_if(7.., |x| *x == 1).collect::>(); + /// assert_eq!(items, vec![0, 0, 0, 0, 0, 0, 0, 2, 2, 2]); + /// assert_eq!(ones.len(), 3); + /// ``` #[unstable(feature = "extract_if", reason = "recently added", issue = "43244")] - pub fn extract_if(&mut self, filter: F) -> ExtractIf<'_, T, F, A> + pub fn extract_if(&mut self, range: R, filter: F) -> ExtractIf<'_, T, F, A> where F: FnMut(&mut T) -> bool, + R: RangeBounds, { - let old_len = self.len(); - - // Guard against us getting leaked (leak amplification) - unsafe { - self.set_len(0); - } - - ExtractIf { vec: self, idx: 0, del: 0, old_len, pred: filter } + ExtractIf::new(self, filter, range) } } diff --git a/library/alloc/tests/vec.rs b/library/alloc/tests/vec.rs index 0f27fdff3e182..75d36194c534a 100644 --- a/library/alloc/tests/vec.rs +++ b/library/alloc/tests/vec.rs @@ -1414,7 +1414,7 @@ fn extract_if_empty() { let mut vec: Vec = vec![]; { - let mut iter = vec.extract_if(|_| true); + let mut iter = vec.extract_if(.., |_| true); assert_eq!(iter.size_hint(), (0, Some(0))); assert_eq!(iter.next(), None); assert_eq!(iter.size_hint(), (0, Some(0))); @@ -1431,7 +1431,7 @@ fn extract_if_zst() { let initial_len = vec.len(); let mut count = 0; { - let mut iter = vec.extract_if(|_| true); + let mut iter = vec.extract_if(.., |_| true); assert_eq!(iter.size_hint(), (0, Some(initial_len))); while let Some(_) = iter.next() { count += 1; @@ -1454,7 +1454,7 @@ fn extract_if_false() { let initial_len = vec.len(); let mut count = 0; { - let mut iter = vec.extract_if(|_| false); + let mut iter = vec.extract_if(.., |_| false); assert_eq!(iter.size_hint(), (0, Some(initial_len))); for _ in iter.by_ref() { count += 1; @@ -1476,7 +1476,7 @@ fn extract_if_true() { let initial_len = vec.len(); let mut count = 0; { - let mut iter = vec.extract_if(|_| true); + let mut iter = vec.extract_if(.., |_| true); assert_eq!(iter.size_hint(), (0, Some(initial_len))); while let Some(_) = iter.next() { count += 1; @@ -1492,6 +1492,31 @@ fn extract_if_true() { assert_eq!(vec, vec![]); } +#[test] +fn extract_if_ranges() { + let mut vec = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + let mut count = 0; + let it = vec.extract_if(1..=3, |_| { + count += 1; + true + }); + assert_eq!(it.collect::>(), vec![1, 2, 3]); + assert_eq!(vec, vec![0, 4, 5, 6, 7, 8, 9, 10]); + assert_eq!(count, 3); + + let it = vec.extract_if(1..=3, |_| false); + assert_eq!(it.collect::>(), vec![]); + assert_eq!(vec, vec![0, 4, 5, 6, 7, 8, 9, 10]); +} + +#[test] +#[should_panic] +fn extraxt_if_out_of_bounds() { + let mut vec = vec![0, 1]; + let _ = vec.extract_if(5.., |_| true).for_each(drop); +} + #[test] fn extract_if_complex() { { @@ -1501,7 +1526,7 @@ fn extract_if_complex() { 39, ]; - let removed = vec.extract_if(|x| *x % 2 == 0).collect::>(); + let removed = vec.extract_if(.., |x| *x % 2 == 0).collect::>(); assert_eq!(removed.len(), 10); assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); @@ -1515,7 +1540,7 @@ fn extract_if_complex() { 2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36, 37, 39, ]; - let removed = vec.extract_if(|x| *x % 2 == 0).collect::>(); + let removed = vec.extract_if(.., |x| *x % 2 == 0).collect::>(); assert_eq!(removed.len(), 10); assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); @@ -1528,7 +1553,7 @@ fn extract_if_complex() { let mut vec = vec![2, 4, 6, 7, 9, 11, 13, 15, 17, 18, 20, 22, 24, 26, 27, 29, 31, 33, 34, 35, 36]; - let removed = vec.extract_if(|x| *x % 2 == 0).collect::>(); + let removed = vec.extract_if(.., |x| *x % 2 == 0).collect::>(); assert_eq!(removed.len(), 10); assert_eq!(removed, vec![2, 4, 6, 18, 20, 22, 24, 26, 34, 36]); @@ -1540,7 +1565,7 @@ fn extract_if_complex() { // [xxxxxxxxxx+++++++++++] let mut vec = vec![2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19]; - let removed = vec.extract_if(|x| *x % 2 == 0).collect::>(); + let removed = vec.extract_if(.., |x| *x % 2 == 0).collect::>(); assert_eq!(removed.len(), 10); assert_eq!(removed, vec![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]); @@ -1552,7 +1577,7 @@ fn extract_if_complex() { // [+++++++++++xxxxxxxxxx] let mut vec = vec![1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]; - let removed = vec.extract_if(|x| *x % 2 == 0).collect::>(); + let removed = vec.extract_if(.., |x| *x % 2 == 0).collect::>(); assert_eq!(removed.len(), 10); assert_eq!(removed, vec![2, 4, 6, 8, 10, 12, 14, 16, 18, 20]); @@ -1600,7 +1625,7 @@ fn extract_if_consumed_panic() { } c.index < 6 }; - let drain = data.extract_if(filter); + let drain = data.extract_if(.., filter); // NOTE: The ExtractIf is explicitly consumed drain.for_each(drop); @@ -1653,7 +1678,7 @@ fn extract_if_unconsumed_panic() { } c.index < 6 }; - let _drain = data.extract_if(filter); + let _drain = data.extract_if(.., filter); // NOTE: The ExtractIf is dropped without being consumed }); @@ -1669,7 +1694,7 @@ fn extract_if_unconsumed_panic() { #[test] fn extract_if_unconsumed() { let mut vec = vec![1, 2, 3, 4]; - let drain = vec.extract_if(|&mut x| x % 2 != 0); + let drain = vec.extract_if(.., |&mut x| x % 2 != 0); drop(drain); assert_eq!(vec, [1, 2, 3, 4]); }