Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub use into_iter::IntoIter;
mod drain;
pub use drain::Drain;

mod retain;

#[cfg(feature = "serde")]
mod serde_impl;
#[cfg(all(test, feature = "serde"))]
Expand Down Expand Up @@ -744,6 +746,42 @@ impl<T, const N: usize> WordVec<T, N> {
}
}

/// Retains only the elements specified by the predicate.
///
/// In other words, remove all elements `e` for which `predicate(&mut e)` returns false.
/// This method operates in place, visiting each element exactly once in the original order,
/// and preserves the order of the retained elements.
pub fn retain<F>(&mut self, mut should_retain: F)
where
F: FnMut(&mut T) -> bool,
{
let mut retain = retain::Retain::new(self);
loop {
if let retain::NextResult::Exhausted = retain.next(&mut should_retain) {
break;
}
}
}

/// Creates an iterator which uses a closure to determine if an element should be removed.
///
/// If the closure returns `true`, the element is removed from the vector
/// and yielded. If the closure returns `false`, or panics, the element
/// remains in the vector and will not be yielded.
///
/// If the returned iterator is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
///
/// [`retain`]: Self::retain
#[doc(alias = "drain_filter")]
pub fn extract_if(
&mut self,
should_remove: impl FnMut(&mut T) -> bool,
) -> impl Iterator<Item = T> {
retain::ExtractIf { retain: retain::Retain::new(self), should_remove }
}

/// Resizes the vector so that its length is equal to `len`.
///
/// If `len` is greater than the current length,
Expand Down
116 changes: 116 additions & 0 deletions src/retain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use core::mem::{self, MaybeUninit};

pub(super) struct Retain<'a, T> {
set_len: super::LengthSetter<'a>,
init_slice: &'a mut [MaybeUninit<T>],
read_len: usize,
written_len: usize,
}

impl<'a, T> Retain<'a, T> {
pub(super) fn new<const N: usize>(vec: &'a mut super::WordVec<T, N>) -> Self {
let (capacity_slice, old_len, mut set_len) = vec.as_uninit_slice_with_length_setter();

// SAFETY: length 0 is always safe
unsafe { set_len.set_len(0) };

Self { set_len, init_slice: &mut capacity_slice[..old_len], read_len: 0, written_len: 0 }
}
}

impl<T> Drop for Retain<'_, T> {
fn drop(&mut self) {
// Shift all unvisited elements forward.
let data_len = self.init_slice.len();
let data_ptr = self.init_slice.as_mut_ptr();
let moved_len = data_len - self.read_len;
unsafe {
core::ptr::copy(data_ptr.add(self.read_len), data_ptr.add(self.written_len), moved_len);
}

// SAFETY: ensured by target_len setters
unsafe {
self.set_len.set_len(self.written_len + moved_len);
}
}
}

impl<T> Retain<'_, T> {
pub(super) fn next(&mut self, should_retain: impl FnOnce(&mut T) -> bool) -> NextResult<T> {
let Some(item_uninit) = self.init_slice.get_mut(self.read_len) else {
return NextResult::Exhausted;
};

// SAFETY: init_slice[read_len..] are always initialized
let item_mut = unsafe { item_uninit.assume_init_mut() };

// If `should_retain` panics, `item` is no longer referenced,
// so the state of this struct is just as if the current `next` call never happened.
// Thus the destructor will work as expected.
let retain = should_retain(item_mut);

if retain {
let src_index = self.read_len;
let dest_index = self.written_len;

// init_slice[read_len] is moved to init_slice[written_len] after this step.
// If read_len == written_len, this just retains the item in place
// and has no safety implications.
// If read_len != written_len, by contract read_len > written_len,
// so init_slice[written_len..read_len] is uninitialized,
// and after this operation, init_slice[written_len] becomes initialized
// while init_slice[read_len] becomes uninitialized.
self.read_len += 1;
self.written_len += 1;

if src_index != dest_index {
unsafe {
// SAFETY: read_len != written_len checked in condition
let [src, dest] =
self.init_slice.get_disjoint_unchecked_mut([src_index, dest_index]);
dest.write(mem::replace(src, MaybeUninit::uninit()).assume_init());
}
}
// If src_index == dest_index, this move would be a no-op

NextResult::Retained
} else {
// this never overflows because read_len < init_slice.len() <= usize::MAX
self.read_len += 1;

// SAFETY: item can be safely moved out as an initialized value.
let item = mem::replace(item_uninit, MaybeUninit::uninit());
let item = unsafe { item.assume_init() };

NextResult::Removed(item)
}
}
}

pub(super) enum NextResult<T> {
Exhausted,
Retained,
Removed(T),
}

pub(super) struct ExtractIf<'a, T, F> {
pub(super) retain: Retain<'a, T>,
pub(super) should_remove: F,
}

impl<T, F> Iterator for ExtractIf<'_, T, F>
where
F: FnMut(&mut T) -> bool,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
loop {
return match self.retain.next(|elem| !(self.should_remove)(elem)) {
NextResult::Exhausted => None,
NextResult::Retained => continue,
NextResult::Removed(item) => Some(item),
};
}
}
}
157 changes: 157 additions & 0 deletions src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::cell::Cell;
use core::mem;
use core::panic::AssertUnwindSafe;

use crate::WordVec;

Expand Down Expand Up @@ -539,6 +541,161 @@ fn test_drain_long_long_short_early_drop_back() {
assert_eq!(wv.as_slice(), &[0, 1, 2, 6, 7]);
}

fn test_retain_with<const N: usize>(
initial_len: usize,
mut predicate: impl FnMut(&str) -> bool,
expect_retain_drops: usize,
expect_after_retain: &[&str],
retain_fn: impl FnOnce(&mut WordVec<AssertDrop<'_>, N>, &mut dyn FnMut(&mut AssertDrop<'_>) -> bool),
) {
let counter = &Cell::new(0);
let mut wv = (0..initial_len)
.map(|i| AssertDrop { string: i.to_string(), counter })
.collect::<WordVec<_, N>>();
retain_fn(&mut wv, &mut |d| predicate(d.string.as_str()));
assert_eq!(counter.get(), expect_retain_drops);
assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::<Vec<_>>(), expect_after_retain);
drop(wv);
assert_eq!(counter.get(), initial_len);
}

fn test_retain<const N: usize>(
initial_len: usize,
mut predicate: impl FnMut(&str) -> bool,
expect_retain_drops: usize,
expect_after_retain: &[&str],
) {
test_retain_with::<N>(
initial_len,
&mut predicate,
expect_retain_drops,
expect_after_retain,
|wv, predicate| wv.retain(|d| predicate(d)),
);
}

fn test_extract_if<const N: usize>(
initial_len: usize,
mut predicate: impl FnMut(&str) -> bool,
expect_retain_drops: usize,
expect_after_retain: &[&str],
) {
test_retain_with::<N>(
initial_len,
&mut predicate,
expect_retain_drops,
expect_after_retain,
|wv, predicate| wv.extract_if(|d| !predicate(d)).for_each(drop),
);
}

#[test]
fn test_retain_empty() { test_retain::<4>(0, |_| unreachable!(), 0, &[]); }

#[test]
fn test_retain_everything() { test_retain::<4>(3, |_| true, 0, &["0", "1", "2"]); }

#[test]
fn test_retain_nothing() { test_retain::<4>(3, |_| false, 3, &[]); }

#[test]
fn test_retain_tft() {
let mut retain_seq = [true, false, true].into_iter();
test_retain::<4>(3, |_| retain_seq.next().unwrap(), 1, &["0", "2"]);
}

#[test]
fn test_retain_ftf() {
let mut retain_seq = [false, true, false].into_iter();
test_retain::<4>(3, |_| retain_seq.next().unwrap(), 2, &["1"]);
}

fn test_retain_panic(retain_prev: bool, expect_retain_drops: usize, expect_after_retain: &[&str]) {
extern crate std;

let counter = &Cell::new(0);
let mut wv =
(0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::<WordVec<_, 4>>();

_ = std::panic::catch_unwind({
let mut wv = AssertUnwindSafe(&mut wv);
move || {
let mut next_index = 0;
wv.retain(|_| {
let index = next_index;
next_index += 1;

#[expect(clippy::manual_assert, reason = "clarity")]
if index == 1 {
panic!("intentional panic");
}

retain_prev
});
}
});

assert_eq!(counter.get(), expect_retain_drops);
assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::<Vec<_>>(), expect_after_retain);
drop(wv);
assert_eq!(counter.get(), 3);
}

#[test]
fn test_retain_shifted_panic() { test_retain_panic(false, 1, &["1", "2"]); }

#[test]
fn test_retain_unshifted_panic() { test_retain_panic(true, 0, &["0", "1", "2"]); }

#[test]
fn test_extract_if_empty() { test_extract_if::<4>(0, |_| unreachable!(), 0, &[]); }

#[test]
fn test_extract_if_everything() { test_extract_if::<4>(3, |_| true, 0, &["0", "1", "2"]); }

#[test]
fn test_extract_if_nothing() { test_extract_if::<4>(3, |_| false, 3, &[]); }

#[test]
fn test_extract_if_tft() {
let mut extract_if_seq = [true, false, true].into_iter();
test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 1, &["0", "2"]);
}

#[test]
fn test_extract_if_ftf() {
let mut extract_if_seq = [false, true, false].into_iter();
test_extract_if::<4>(3, |_| extract_if_seq.next().unwrap(), 2, &["1"]);
}

fn test_extract_if_drop(
retain_first: bool,
expect_extract_result: &str,
expect_retain_drops: usize,
expect_after_retain: &[&str],
) {
let counter = &Cell::new(0);
let mut wv =
(0..3).map(|i| AssertDrop { string: i.to_string(), counter }).collect::<WordVec<_, 4>>();

{
let mut retain = retain_first;
let mut iter = wv.extract_if(|_| !mem::replace(&mut retain, false));
assert_eq!(iter.next().unwrap().string, expect_extract_result);
}

assert_eq!(counter.get(), expect_retain_drops);
assert_eq!(wv.iter().map(|d| d.string.as_str()).collect::<Vec<_>>(), expect_after_retain);
drop(wv);
assert_eq!(counter.get(), 3);
}

#[test]
fn test_extract_if_shifted_drop() { test_extract_if_drop(true, "1", 1, &["0", "2"]); }

#[test]
fn test_extract_if_unshifted_drop() { test_extract_if_drop(false, "0", 1, &["1", "2"]); }

fn assert_resize<const N: usize>(
initial_len: usize,
resize_len: usize,
Expand Down
Loading