diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b4bcfc..481adcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. ## Unreleased +### Bug fixes + +* Clean up cancelled broadcast receive futures registered through the internal `WaitSet`. + ## [0.6.3] - 2026-01-21 ### Improvements diff --git a/mea/src/barrier/mod.rs b/mea/src/barrier/mod.rs index 76d12d3..99dc303 100644 --- a/mea/src/barrier/mod.rs +++ b/mea/src/barrier/mod.rs @@ -75,6 +75,7 @@ use std::task::Poll; use crate::internal::Mutex; use crate::internal::WaitSet; +use crate::internal::WaiterId; #[cfg(test)] mod tests; @@ -228,7 +229,11 @@ impl Barrier { if state.arrived == self.n { state.arrived = 0; state.generation += 1; - state.waiters.wake_all(); + let wakers = state.waiters.take_wakers(); + drop(state); + for waker in wakers { + waker.wake(); + } return BarrierWaitResult(true); } @@ -250,7 +255,7 @@ impl Barrier { /// This future will complete when all tasks have reached the barrier point. #[must_use = "futures do nothing unless you `.await` or poll them"] struct BarrierWait<'a> { - idx: Option, + idx: Option, generation: usize, barrier: &'a Barrier, } @@ -277,7 +282,9 @@ impl Future for BarrierWait<'_> { if *generation < state.generation { Poll::Ready(()) } else { - state.waiters.register_waker(idx, cx); + let waker = state.waiters.register_waker(idx, cx); + drop(state); + drop(waker); Poll::Pending } } diff --git a/mea/src/broadcast/overflow/mod.rs b/mea/src/broadcast/overflow/mod.rs index 34b7036..68281ff 100644 --- a/mea/src/broadcast/overflow/mod.rs +++ b/mea/src/broadcast/overflow/mod.rs @@ -73,6 +73,7 @@ use std::task::Poll; use crate::internal::Mutex; use crate::internal::RwLock; use crate::internal::WaitSet; +use crate::internal::WaiterId; #[cfg(test)] mod tests; @@ -218,7 +219,13 @@ impl Drop for Sender { 1 => { // If this is the last sender, we need to wake up the receiver so it can // observe the disconnected state. - self.shared.waiters.lock().wake_all(); + let wakers = { + let mut waiters = self.shared.waiters.lock(); + waiters.take_wakers() + }; + for waker in wakers { + waker.wake(); + } } _ => { // there are still other senders left, do nothing @@ -254,7 +261,13 @@ impl Sender { } // Notify all waiting receivers. - self.shared.waiters.lock().wake_all(); + let wakers = { + let mut waiters = self.shared.waiters.lock(); + waiters.take_wakers() + }; + for waker in wakers { + waker.wake(); + } } /// Creates a new receiver that starts receiving messages from the current tail of the channel. @@ -435,7 +448,22 @@ impl Receiver { struct Recv<'a, T> { receiver: &'a mut Receiver, - index: Option, + index: Option, +} + +impl Drop for Recv<'_, T> { + fn drop(&mut self) { + // Ready paths clear the waiter ID, so only a cancelled pending receive takes this lock. + if self.index.is_none() { + return; + } + + let waker = { + let mut waiters = self.receiver.shared.waiters.lock(); + waiters.remove_waker(&mut self.index) + }; + drop(waker); + } } impl Future for Recv<'_, T> { @@ -446,9 +474,16 @@ impl Future for Recv<'_, T> { loop { match receiver.try_recv() { - Ok(val) => return Poll::Ready(Ok(val)), - Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))), + Ok(val) => { + *index = None; + return Poll::Ready(Ok(val)); + } + Err(TryRecvError::Lagged(n)) => { + *index = None; + return Poll::Ready(Err(RecvError::Lagged(n))); + } Err(TryRecvError::Disconnected) => { + *index = None; return Poll::Ready(Err(RecvError::Disconnected)); } Err(TryRecvError::Empty) => {} @@ -468,11 +503,14 @@ impl Future for Recv<'_, T> { // Check for Closed // Use Acquire to ensure we see all writes before the sender dropped. if shared.senders.load(Ordering::Acquire) == 0 { + *index = None; return Poll::Ready(Err(RecvError::Disconnected)); } // Register Waker - waiters.register_waker(index, cx); + let waker = waiters.register_waker(index, cx); + drop(waiters); + drop(waker); return Poll::Pending; } } diff --git a/mea/src/broadcast/overflow/tests.rs b/mea/src/broadcast/overflow/tests.rs index 0a1bb9b..78b279e 100644 --- a/mea/src/broadcast/overflow/tests.rs +++ b/mea/src/broadcast/overflow/tests.rs @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; +use std::sync::atomic::Ordering; +use std::task::Context; + use super::*; +use crate::count_waker; #[tokio::test] async fn test_broadcast_basic() { @@ -77,6 +82,22 @@ async fn test_wait_mechanism() { assert_eq!(handle.await.unwrap(), Ok(42)); } +#[tokio::test] +async fn test_recv_cancellation_removes_waiter() { + let (tx, mut rx) = channel::(10); + let (waker, wake_count) = count_waker(); + let mut cx = Context::from_waker(&waker); + let mut recv = Box::pin(rx.recv()); + + assert!(recv.as_mut().poll(&mut cx).is_pending()); + + drop(recv); + tx.send(1); + + assert_eq!(wake_count.load(Ordering::Relaxed), 0); + assert_eq!(rx.try_recv(), Ok(1)); +} + #[tokio::test] async fn test_subscribe() { let (tx, _rx) = channel(10); diff --git a/mea/src/internal/countdown.rs b/mea/src/internal/countdown.rs index 66ecd81..6c4c855 100644 --- a/mea/src/internal/countdown.rs +++ b/mea/src/internal/countdown.rs @@ -18,6 +18,7 @@ use std::task::Context; use crate::internal::Mutex; use crate::internal::WaitSet; +use crate::internal::WaiterId; #[derive(Debug)] pub(crate) struct CountdownState { @@ -56,17 +57,26 @@ impl CountdownState { /// Drain and wake up all waiters. pub(crate) fn wake_all(&self) { - let mut waiters = self.waiters.lock(); - waiters.wake_all(); + let wakers = { + let mut waiters = self.waiters.lock(); + waiters.take_wakers() + }; + + for waker in wakers { + waker.wake(); + } } /// Registers a waker to be woken up when the countdown reaches zero. /// - /// `idx` must be `None` when the waker is not registered, or `Some(key)` where `key` is - /// a value previously returned by this method. - pub(crate) fn register_waker(&self, idx: &mut Option, cx: &mut Context<'_>) { - let mut waiters = self.waiters.lock(); - waiters.register_waker(idx, cx); + /// `id` must be `None` when the waker is not registered, or `Some` with a waiter ID previously + /// stored by this method. + pub(crate) fn register_waker(&self, id: &mut Option, cx: &mut Context<'_>) { + let waker = { + let mut waiters = self.waiters.lock(); + waiters.register_waker(id, cx) + }; + drop(waker); } /// Returns `Ok(())` if the counter is zero, otherwise returns `Err(s)` where `s` is the current diff --git a/mea/src/internal/waitset.rs b/mea/src/internal/waitset.rs index d57096f..99fa330 100644 --- a/mea/src/internal/waitset.rs +++ b/mea/src/internal/waitset.rs @@ -17,9 +17,26 @@ use std::task::Waker; use slab::Slab; +/// Identifies a registered waker in a [`WaitSet`]. +/// +/// The generation distinguishes reused slab slots so stale waiter IDs cannot remove or update a +/// newer waiter that happens to occupy the same index. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct WaiterId { + index: usize, + generation: u64, +} + +#[derive(Debug)] +struct Waiter { + generation: u64, + waker: Waker, +} + #[derive(Debug)] pub(crate) struct WaitSet { - waiters: Slab, + waiters: Slab, + next_generation: u64, } impl WaitSet { @@ -27,6 +44,7 @@ impl WaitSet { pub const fn new() -> Self { Self { waiters: Slab::new(), + next_generation: 0, } } @@ -34,47 +52,176 @@ impl WaitSet { pub fn with_capacity(capacity: usize) -> Self { Self { waiters: Slab::with_capacity(capacity), + next_generation: 0, } } - /// Drain and wake up all waiters. - pub(crate) fn wake_all(&mut self) { - for w in self.waiters.drain() { - w.wake(); + /// Drains all wakers from the wait set. + /// + /// Callers should release any lock that protects the wait set before calling [`Waker::wake`] on + /// the returned wakers, to avoid running user-provided waker code with that lock held. + pub(crate) fn take_wakers(&mut self) -> Vec { + self.waiters.drain().map(|waiter| waiter.waker).collect() + } + + /// Removes a previously registered waker from the wait set, returning it if it was still + /// registered. + /// + /// The returned waker should be dropped after releasing any lock that protects the wait set, + /// so that user-provided `Waker::Drop` does not run with that lock held. + pub(crate) fn remove_waker(&mut self, id: &mut Option) -> Option { + let key = id.take()?; + let waiter = self.waiters.get(key.index)?; + + if waiter.generation == key.generation { + Some(self.waiters.remove(key.index).waker) + } else { + None } } /// Registers a waker to the wait set. /// - /// `idx` must be `None` when the waker is not registered, or `Some(key)` where `key` is - /// a value previously returned by this method. - pub(crate) fn register_waker(&mut self, idx: &mut Option, cx: &mut Context<'_>) { - match *idx { - None => { - let key = self.waiters.insert(cx.waker().clone()); - *idx = Some(key); - } - Some(key) => { - if self.waiters.contains(key) { - if !self.waiters[key].will_wake(cx.waker()) { - self.waiters[key] = cx.waker().clone(); + /// `id` must be `None` when the waker is not registered, or `Some` with a waiter ID previously + /// stored by this method. + /// + /// If a stored waker is replaced by a different one, the old waker is returned. The caller + /// should drop the returned waker after releasing any lock that protects the wait set, so + /// that user-provided `Waker::Drop` does not run with that lock held. + pub(crate) fn register_waker( + &mut self, + id: &mut Option, + cx: &mut Context<'_>, + ) -> Option { + if let Some(key) = *id { + if let Some(waiter) = self.waiters.get_mut(key.index) { + if waiter.generation == key.generation { + if !waiter.waker.will_wake(cx.waker()) { + return Some(std::mem::replace(&mut waiter.waker, cx.waker().clone())); } - } else { - // DEFENSIVE NOTE: - // - // This is possible if latch/waitgroup is fired between the first and second - // state check. - // - // In this case, it does not harm to re-register the waker. Because - // the second state check will finish the future and the WaitSet gets - // dropped. - // - // Barrier holds the lock during check and register, so the race condition - // above won't happen. - let key = self.waiters.insert(cx.waker().clone()); - *idx = Some(key); + return None; } } } + + // The stored WaiterId may be stale if the waiter was removed, drained by `take_wakers`, + // or if the slab slot has since been reused by another waiter. Register the current waker + // again and replace the stale ID. + *id = Some(self.insert_waker(cx.waker())); + None + } + + /// Allocates a fresh waiter ID and stores the waker in the wait set. + fn insert_waker(&mut self, waker: &Waker) -> WaiterId { + let generation = self.next_generation; + + // Do not wrap the generation counter: wrapping could make a stale WaiterId valid again + // after enough insertions. + self.next_generation = self + .next_generation + .checked_add(1) + .expect("wait set generation counter overflowed"); + + let index = self.waiters.insert(Waiter { + generation, + waker: waker.clone(), + }); + WaiterId { index, generation } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + use std::task::Context; + use std::task::Wake; + use std::task::Waker; + + use super::WaitSet; + + struct DropWake { + dropped: Arc, + wake_count: AtomicUsize, + } + + impl Wake for DropWake { + fn wake(self: Arc) { + self.wake_count.fetch_add(1, Ordering::Relaxed); + } + } + + impl Drop for DropWake { + fn drop(&mut self) { + self.dropped.store(true, Ordering::Relaxed); + } + } + + fn drop_waker() -> (Waker, Arc) { + let dropped = Arc::new(AtomicBool::new(false)); + let waker = Waker::from(Arc::new(DropWake { + dropped: dropped.clone(), + wake_count: AtomicUsize::new(0), + })); + (waker, dropped) + } + + #[test] + fn test_remove_waker_delays_drop() { + let mut waiters = WaitSet::new(); + let mut id = None; + let (waker, dropped) = drop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(waiters.register_waker(&mut id, &mut cx).is_none()); + drop(waker); + + let removed = waiters.remove_waker(&mut id).unwrap(); + assert!(!dropped.load(Ordering::Relaxed)); + + drop(removed); + assert!(dropped.load(Ordering::Relaxed)); + } + + #[test] + fn test_replace_waker_delays_drop() { + let mut waiters = WaitSet::new(); + let mut id = None; + let (waker, dropped) = drop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(waiters.register_waker(&mut id, &mut cx).is_none()); + drop(waker); + + let (replacement, _) = drop_waker(); + let mut cx = Context::from_waker(&replacement); + let replaced = waiters.register_waker(&mut id, &mut cx).unwrap(); + assert!(!dropped.load(Ordering::Relaxed)); + + drop(replaced); + assert!(dropped.load(Ordering::Relaxed)); + } + + #[test] + fn test_stale_waiter_id_does_not_remove_reused_slot() { + let mut waiters = WaitSet::new(); + let mut stale_id = None; + let (waker, _) = drop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(waiters.register_waker(&mut stale_id, &mut cx).is_none()); + let stale_id_value = stale_id.unwrap(); + drop(waiters.take_wakers()); + + let mut current_id = None; + assert!(waiters.register_waker(&mut current_id, &mut cx).is_none()); + let current_id_value = current_id.unwrap(); + assert_eq!(stale_id_value.index, current_id_value.index); + assert_ne!(stale_id_value.generation, current_id_value.generation); + + assert!(waiters.remove_waker(&mut stale_id).is_none()); + assert!(waiters.remove_waker(&mut current_id).is_some()); } } diff --git a/mea/src/latch/mod.rs b/mea/src/latch/mod.rs index d82b5ce..e644ed7 100644 --- a/mea/src/latch/mod.rs +++ b/mea/src/latch/mod.rs @@ -60,6 +60,7 @@ use std::task::Context; use std::task::Poll; use crate::internal::CountdownState; +use crate::internal::WaiterId; #[cfg(test)] mod tests; @@ -252,7 +253,7 @@ impl Latch { } impl Latch { - fn intern_poll(&self, idx: &mut Option, cx: &mut Context<'_>) -> Poll<()> { + fn intern_poll(&self, idx: &mut Option, cx: &mut Context<'_>) -> Poll<()> { // register waker if the counter is not zero if self.state.spin_wait(16).is_err() { self.state.register_waker(idx, cx); @@ -271,7 +272,7 @@ impl Latch { /// This future will complete when the latch count reaches zero. #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct LatchWait<'a> { - idx: Option, + idx: Option, latch: &'a Latch, } @@ -295,7 +296,7 @@ impl Future for LatchWait<'_> { /// This future will complete when the latch count reaches zero. #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct OwnedLatchWait { - idx: Option, + idx: Option, latch: Arc, } diff --git a/mea/src/lib.rs b/mea/src/lib.rs index 969c5d4..7dab7a7 100644 --- a/mea/src/lib.rs +++ b/mea/src/lib.rs @@ -95,6 +95,38 @@ fn test_runtime() -> &'static tokio::runtime::Runtime { RT.get_or_init(|| Runtime::new().unwrap()) } +#[cfg(test)] +use std::sync::Arc; +#[cfg(test)] +use std::sync::atomic::AtomicUsize; +#[cfg(test)] +use std::sync::atomic::Ordering; +#[cfg(test)] +use std::task::Wake; +#[cfg(test)] +use std::task::Waker; + +#[cfg(test)] +struct CountWake(Arc); + +#[cfg(test)] +impl Wake for CountWake { + fn wake(self: Arc) { + self.0.fetch_add(1, Ordering::Relaxed); + } + + fn wake_by_ref(self: &Arc) { + self.0.fetch_add(1, Ordering::Relaxed); + } +} + +#[cfg(test)] +fn count_waker() -> (Waker, Arc) { + let wake_count = Arc::new(AtomicUsize::new(0)); + let waker = Waker::from(Arc::new(CountWake(wake_count.clone()))); + (waker, wake_count) +} + #[cfg(test)] mod tests { use crate::barrier::Barrier; diff --git a/mea/src/once/once/mod.rs b/mea/src/once/once/mod.rs index cb5efc3..76712a2 100644 --- a/mea/src/once/once/mod.rs +++ b/mea/src/once/once/mod.rs @@ -18,6 +18,7 @@ use std::task::Context; use std::task::Poll; use crate::internal::CountdownState; +use crate::internal::WaiterId; use crate::semaphore::Semaphore; #[cfg(test)] @@ -215,7 +216,7 @@ impl Once { } struct OnceWait<'a> { - idx: Option, + idx: Option, once: &'a Once, } diff --git a/mea/src/waitgroup/mod.rs b/mea/src/waitgroup/mod.rs index 66fc42a..9b273ef 100644 --- a/mea/src/waitgroup/mod.rs +++ b/mea/src/waitgroup/mod.rs @@ -59,6 +59,7 @@ use std::task::Context; use std::task::Poll; use crate::internal::CountdownState; +use crate::internal::WaiterId; #[cfg(test)] mod tests; @@ -145,7 +146,7 @@ impl IntoFuture for WaitGroup { /// will complete when the WaitGroup counter reaches zero. #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Wait { - idx: Option, + idx: Option, state: Arc, }