From 41222b2980eb827e7e4d4e69ee4f8850849efe6f Mon Sep 17 00:00:00 2001 From: SOFe Date: Sat, 1 Nov 2025 12:07:47 +0800 Subject: [PATCH 1/2] feat: add retain method --- src/lib.rs | 19 +++++++++++ src/retain.rs | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/tests.rs | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+) create mode 100644 src/retain.rs diff --git a/src/lib.rs b/src/lib.rs index c155fec..9704236 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,8 @@ pub use into_iter::IntoIter; mod drain; pub use drain::Drain; +mod retain; + #[cfg(feature = "serde")] mod serde_impl; #[cfg(all(test, feature = "serde"))] @@ -744,6 +746,23 @@ impl WordVec { } } + /// Retains only the elements specified by the predicate. + /// + /// In other words, remove all elements `e` for which `predicate(&mut e)` returns false. + /// This method operates in place, visiting each element exactly once in the original order, + /// and preserves the order of the retained elements. + pub fn retain(&mut self, mut predicate: F) + where + F: FnMut(&mut T) -> bool, + { + let mut retain = retain::Retain::new(self); + loop { + if let retain::NextResult::Exhausted = retain.next(&mut predicate) { + break; + } + } + } + /// Resizes the vector so that its length is equal to `len`. /// /// If `len` is greater than the current length, diff --git a/src/retain.rs b/src/retain.rs new file mode 100644 index 0000000..24ad111 --- /dev/null +++ b/src/retain.rs @@ -0,0 +1,94 @@ +use core::mem::{self, MaybeUninit}; + +pub(super) struct Retain<'a, T> { + set_len: super::LengthSetter<'a>, + init_slice: &'a mut [MaybeUninit], + read_len: usize, + written_len: usize, +} + +impl<'a, T> Retain<'a, T> { + pub(super) fn new(vec: &'a mut super::WordVec) -> Self { + let (capacity_slice, old_len, mut set_len) = vec.as_uninit_slice_with_length_setter(); + + // SAFETY: length 0 is always safe + unsafe { set_len.set_len(0) }; + + Self { set_len, init_slice: &mut capacity_slice[..old_len], read_len: 0, written_len: 0 } + } +} + +impl Drop for Retain<'_, T> { + fn drop(&mut self) { + // Shift all unvisited elements forward. + let data_len = self.init_slice.len(); + let data_ptr = self.init_slice.as_mut_ptr(); + let moved_len = data_len - self.read_len; + unsafe { + core::ptr::copy(data_ptr.add(self.read_len), data_ptr.add(self.written_len), moved_len); + } + + // SAFETY: ensured by target_len setters + unsafe { + self.set_len.set_len(self.written_len + moved_len); + } + } +} + +impl Retain<'_, T> { + pub(super) fn next(&mut self, should_retain: impl FnOnce(&mut T) -> bool) -> NextResult { + let Some(item_uninit) = self.init_slice.get_mut(self.read_len) else { + return NextResult::Exhausted; + }; + + // SAFETY: init_slice[read_len..] are always initialized + let item_mut = unsafe { item_uninit.assume_init_mut() }; + + // If `should_retain` panics, `item` is no longer referenced, + // so the state of this struct is just as if the current `next` call never happened. + // Thus the destructor will work as expected. + let retain = should_retain(item_mut); + + if retain { + let src_index = self.read_len; + let dest_index = self.written_len; + + // init_slice[read_len] is moved to init_slice[written_len] after this step. + // If read_len == written_len, this just retains the item in place + // and has no safety implications. + // If read_len != written_len, by contract read_len > written_len, + // so init_slice[written_len..read_len] is uninitialized, + // and after this operation, init_slice[written_len] becomes initialized + // while init_slice[read_len] becomes uninitialized. + self.read_len += 1; + self.written_len += 1; + + if src_index != dest_index { + unsafe { + // SAFETY: read_len != written_len checked in condition + let [src, dest] = + self.init_slice.get_disjoint_unchecked_mut([src_index, dest_index]); + dest.write(mem::replace(src, MaybeUninit::uninit()).assume_init()); + } + } + // If src_index == dest_index, this move would be a no-op + + NextResult::Retained + } else { + // this never overflows because read_len < init_slice.len() <= usize::MAX + self.read_len += 1; + + // SAFETY: item can be safely moved out as an initialized value. + let item = mem::replace(item_uninit, MaybeUninit::uninit()); + let item = unsafe { item.assume_init() }; + + NextResult::Removed(item) + } + } +} + +pub(super) enum NextResult { + Exhausted, + Retained, + Removed(T), +} diff --git a/src/tests.rs b/src/tests.rs index 61f08cd..9e3a9b1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,7 @@ use alloc::string::{String, ToString}; use alloc::vec::Vec; use core::cell::Cell; +use core::panic::AssertUnwindSafe; use crate::WordVec; @@ -539,6 +540,98 @@ fn test_drain_long_long_short_early_drop_back() { assert_eq!(wv.as_slice(), &[0, 1, 2, 6, 7]); } +#[test] +fn test_retain_empty() { + let mut wv = WordVec::::new(); + wv.retain(|_| unreachable!()); + assert!(wv.as_slice().is_empty()); +} + +#[test] +fn test_retain_everything() { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + wv.retain(|_| true); + assert_eq!(counter.get(), 0); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["0", "1", "2"]); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_retain_nothing() { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + wv.retain(|_| false); + assert_eq!(counter.get(), 3); + assert!(wv.as_slice().is_empty()); +} + +#[test] +fn test_retain_tft() { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + let mut retain_seq = [true, false, true].into_iter(); + wv.retain(|_| retain_seq.next().unwrap()); + assert_eq!(counter.get(), 1); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["0", "2"]); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_retain_ftf() { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + let mut retain_seq = [false, true, false].into_iter(); + wv.retain(|_| retain_seq.next().unwrap()); + assert_eq!(counter.get(), 2); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["1"]); + drop(wv); + assert_eq!(counter.get(), 3); +} + +fn test_retain_panic(retain_prev: bool, expect_retain_drops: usize, expect_after_retain: &[&str]) { + extern crate std; + + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + + _ = std::panic::catch_unwind({ + let mut wv = AssertUnwindSafe(&mut wv); + move || { + let mut next_index = 0; + wv.retain(|_| { + let index = next_index; + next_index += 1; + + #[expect(clippy::manual_assert, reason = "clarity")] + if index == 1 { + panic!("intentional panic"); + } + + retain_prev + }); + } + }); + + assert_eq!(counter.get(), expect_retain_drops); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_retain_shifted_panic() { test_retain_panic(false, 1, &["1", "2"]); } + +#[test] +fn test_retain_unshifted_panic() { test_retain_panic(true, 0, &["0", "1", "2"]); } + fn assert_resize( initial_len: usize, resize_len: usize, From b1ce787c3d19354a84a7abd58904637e836a6a1c Mon Sep 17 00:00:00 2001 From: SOFe Date: Sat, 1 Nov 2025 12:34:16 +0800 Subject: [PATCH 2/2] feat: add extract_if --- src/lib.rs | 23 +++++++- src/retain.rs | 22 ++++++++ src/tests.rs | 142 ++++++++++++++++++++++++++++++++++++-------------- 3 files changed, 146 insertions(+), 41 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9704236..9bf4f04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -751,18 +751,37 @@ impl WordVec { /// In other words, remove all elements `e` for which `predicate(&mut e)` returns false. /// This method operates in place, visiting each element exactly once in the original order, /// and preserves the order of the retained elements. - pub fn retain(&mut self, mut predicate: F) + pub fn retain(&mut self, mut should_retain: F) where F: FnMut(&mut T) -> bool, { let mut retain = retain::Retain::new(self); loop { - if let retain::NextResult::Exhausted = retain.next(&mut predicate) { + if let retain::NextResult::Exhausted = retain.next(&mut should_retain) { break; } } } + /// Creates an iterator which uses a closure to determine if an element should be removed. + /// + /// If the closure returns `true`, the element is removed from the vector + /// and yielded. If the closure returns `false`, or panics, the element + /// remains in the vector and will not be yielded. + /// + /// If the returned iterator is not exhausted, e.g. because it is dropped without iterating + /// or the iteration short-circuits, then the remaining elements will be retained. + /// Use [`retain`] with a negated predicate if you do not need the returned iterator. + /// + /// [`retain`]: Self::retain + #[doc(alias = "drain_filter")] + pub fn extract_if( + &mut self, + should_remove: impl FnMut(&mut T) -> bool, + ) -> impl Iterator { + retain::ExtractIf { retain: retain::Retain::new(self), should_remove } + } + /// Resizes the vector so that its length is equal to `len`. /// /// If `len` is greater than the current length, diff --git a/src/retain.rs b/src/retain.rs index 24ad111..d54cf19 100644 --- a/src/retain.rs +++ b/src/retain.rs @@ -92,3 +92,25 @@ pub(super) enum NextResult { Retained, Removed(T), } + +pub(super) struct ExtractIf<'a, T, F> { + pub(super) retain: Retain<'a, T>, + pub(super) should_remove: F, +} + +impl Iterator for ExtractIf<'_, T, F> +where + F: FnMut(&mut T) -> bool, +{ + type Item = T; + + fn next(&mut self) -> Option { + loop { + return match self.retain.next(|elem| !(self.should_remove)(elem)) { + NextResult::Exhausted => None, + NextResult::Retained => continue, + NextResult::Removed(item) => Some(item), + }; + } + } +} diff --git a/src/tests.rs b/src/tests.rs index 9e3a9b1..6b86f16 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,6 +1,7 @@ use alloc::string::{String, ToString}; use alloc::vec::Vec; use core::cell::Cell; +use core::mem; use core::panic::AssertUnwindSafe; use crate::WordVec; @@ -540,59 +541,73 @@ fn test_drain_long_long_short_early_drop_back() { assert_eq!(wv.as_slice(), &[0, 1, 2, 6, 7]); } -#[test] -fn test_retain_empty() { - let mut wv = WordVec::::new(); - wv.retain(|_| unreachable!()); - assert!(wv.as_slice().is_empty()); -} - -#[test] -fn test_retain_everything() { +fn test_retain_with( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], + retain_fn: impl FnOnce(&mut WordVec, N>, &mut dyn FnMut(&mut AssertDrop<'_>) -> bool), +) { let counter = &Cell::new(0); - let mut wv = - (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); - wv.retain(|_| true); - assert_eq!(counter.get(), 0); - assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["0", "1", "2"]); + let mut wv = (0..initial_len) + .map(|i| AssertDrop { string: i.to_string(), counter }) + .collect::>(); + retain_fn(&mut wv, &mut |d| predicate(d.string.as_str())); + assert_eq!(counter.get(), expect_retain_drops); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); drop(wv); - assert_eq!(counter.get(), 3); + assert_eq!(counter.get(), initial_len); } -#[test] -fn test_retain_nothing() { - let counter = &Cell::new(0); - let mut wv = - (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); - wv.retain(|_| false); - assert_eq!(counter.get(), 3); - assert!(wv.as_slice().is_empty()); +fn test_retain( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + test_retain_with::( + initial_len, + &mut predicate, + expect_retain_drops, + expect_after_retain, + |wv, predicate| wv.retain(|d| predicate(d)), + ); } +fn test_extract_if( + initial_len: usize, + mut predicate: impl FnMut(&str) -> bool, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + test_retain_with::( + initial_len, + &mut predicate, + expect_retain_drops, + expect_after_retain, + |wv, predicate| wv.extract_if(|d| !predicate(d)).for_each(drop), + ); +} + +#[test] +fn test_retain_empty() { test_retain::<4>(0, |_| unreachable!(), 0, &[]); } + +#[test] +fn test_retain_everything() { test_retain::<4>(3, |_| true, 0, &["0", "1", "2"]); } + +#[test] +fn test_retain_nothing() { test_retain::<4>(3, |_| false, 3, &[]); } + #[test] fn test_retain_tft() { - let counter = &Cell::new(0); - let mut wv = - (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); let mut retain_seq = [true, false, true].into_iter(); - wv.retain(|_| retain_seq.next().unwrap()); - assert_eq!(counter.get(), 1); - assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["0", "2"]); - drop(wv); - assert_eq!(counter.get(), 3); + test_retain::<4>(3, |_| retain_seq.next().unwrap(), 1, &["0", "2"]); } #[test] fn test_retain_ftf() { - let counter = &Cell::new(0); - let mut wv = - (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); let mut retain_seq = [false, true, false].into_iter(); - wv.retain(|_| retain_seq.next().unwrap()); - assert_eq!(counter.get(), 2); - assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), &["1"]); - drop(wv); - assert_eq!(counter.get(), 3); + test_retain::<4>(3, |_| retain_seq.next().unwrap(), 2, &["1"]); } fn test_retain_panic(retain_prev: bool, expect_retain_drops: usize, expect_after_retain: &[&str]) { @@ -632,6 +647,55 @@ fn test_retain_shifted_panic() { test_retain_panic(false, 1, &["1", "2"]); } #[test] fn test_retain_unshifted_panic() { test_retain_panic(true, 0, &["0", "1", "2"]); } +#[test] +fn test_extract_if_empty() { test_extract_if::<4>(0, |_| unreachable!(), 0, &[]); } + +#[test] +fn test_extract_if_everything() { test_extract_if::<4>(3, |_| true, 0, &["0", "1", "2"]); } + +#[test] +fn test_extract_if_nothing() { test_extract_if::<4>(3, |_| false, 3, &[]); } + +#[test] +fn test_extract_if_tft() { + let mut extract_if_seq = [true, false, true].into_iter(); + test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 1, &["0", "2"]); +} + +#[test] +fn test_extract_if_ftf() { + let mut extract_if_seq = [false, true, false].into_iter(); + test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 2, &["1"]); +} + +fn test_extract_if_drop( + retain_first: bool, + expect_extract_result: &str, + expect_retain_drops: usize, + expect_after_retain: &[&str], +) { + let counter = &Cell::new(0); + let mut wv = + (0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::>(); + + { + let mut retain = retain_first; + let mut iter = wv.extract_if(|_| !mem::replace(&mut retain, false)); + assert_eq!(iter.next().unwrap().string, expect_extract_result); + } + + assert_eq!(counter.get(), expect_retain_drops); + assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::>(), expect_after_retain); + drop(wv); + assert_eq!(counter.get(), 3); +} + +#[test] +fn test_extract_if_shifted_drop() { test_extract_if_drop(true, "1", 1, &["0", "2"]); } + +#[test] +fn test_extract_if_unshifted_drop() { test_extract_if_drop(false, "0", 1, &["1", "2"]); } + fn assert_resize( initial_len: usize, resize_len: usize,