Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

## Unreleased

### Bug fixes

* Clean up cancelled broadcast receive futures registered through the internal `WaitSet`.

## [0.6.3] - 2026-01-21

### Improvements
Expand Down
13 changes: 10 additions & 3 deletions mea/src/barrier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ use std::task::Poll;

use crate::internal::Mutex;
use crate::internal::WaitSet;
use crate::internal::WaiterId;

#[cfg(test)]
mod tests;
Expand Down Expand Up @@ -228,7 +229,11 @@ impl Barrier {
if state.arrived == self.n {
state.arrived = 0;
state.generation += 1;
state.waiters.wake_all();
let wakers = state.waiters.take_wakers();
drop(state);
for waker in wakers {
waker.wake();
}
return BarrierWaitResult(true);
}

Expand All @@ -250,7 +255,7 @@ impl Barrier {
/// This future will complete when all tasks have reached the barrier point.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct BarrierWait<'a> {
idx: Option<usize>,
idx: Option<WaiterId>,
generation: usize,
barrier: &'a Barrier,
}
Expand All @@ -277,7 +282,9 @@ impl Future for BarrierWait<'_> {
if *generation < state.generation {
Poll::Ready(())
} else {
state.waiters.register_waker(idx, cx);
let waker = state.waiters.register_waker(idx, cx);
drop(state);
drop(waker);
Poll::Pending
}
}
Expand Down
50 changes: 44 additions & 6 deletions mea/src/broadcast/overflow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ use std::task::Poll;
use crate::internal::Mutex;
use crate::internal::RwLock;
use crate::internal::WaitSet;
use crate::internal::WaiterId;

#[cfg(test)]
mod tests;
Expand Down Expand Up @@ -218,7 +219,13 @@ impl<T> Drop for Sender<T> {
1 => {
// If this is the last sender, we need to wake up the receiver so it can
// observe the disconnected state.
self.shared.waiters.lock().wake_all();
let wakers = {
let mut waiters = self.shared.waiters.lock();
waiters.take_wakers()
};
for waker in wakers {
waker.wake();
}
}
_ => {
// there are still other senders left, do nothing
Expand Down Expand Up @@ -254,7 +261,13 @@ impl<T> Sender<T> {
}

// Notify all waiting receivers.
self.shared.waiters.lock().wake_all();
let wakers = {
let mut waiters = self.shared.waiters.lock();
waiters.take_wakers()
};
for waker in wakers {
waker.wake();
}
}

/// Creates a new receiver that starts receiving messages from the current tail of the channel.
Expand Down Expand Up @@ -435,7 +448,22 @@ impl<T> Receiver<T> {

struct Recv<'a, T> {
receiver: &'a mut Receiver<T>,
index: Option<usize>,
index: Option<WaiterId>,
}

impl<T> Drop for Recv<'_, T> {
fn drop(&mut self) {
// Ready paths clear the waiter ID, so only a cancelled pending receive takes this lock.
if self.index.is_none() {
return;
}

let waker = {
let mut waiters = self.receiver.shared.waiters.lock();
waiters.remove_waker(&mut self.index)
};
drop(waker);
}
}

impl<T: Clone> Future for Recv<'_, T> {
Expand All @@ -446,9 +474,16 @@ impl<T: Clone> Future for Recv<'_, T> {

loop {
match receiver.try_recv() {
Ok(val) => return Poll::Ready(Ok(val)),
Err(TryRecvError::Lagged(n)) => return Poll::Ready(Err(RecvError::Lagged(n))),
Ok(val) => {
*index = None;
return Poll::Ready(Ok(val));
}
Err(TryRecvError::Lagged(n)) => {
*index = None;
return Poll::Ready(Err(RecvError::Lagged(n)));
}
Err(TryRecvError::Disconnected) => {
*index = None;
return Poll::Ready(Err(RecvError::Disconnected));
}
Err(TryRecvError::Empty) => {}
Expand All @@ -468,11 +503,14 @@ impl<T: Clone> Future for Recv<'_, T> {
// Check for Closed
// Use Acquire to ensure we see all writes before the sender dropped.
if shared.senders.load(Ordering::Acquire) == 0 {
*index = None;
return Poll::Ready(Err(RecvError::Disconnected));
}

// Register Waker
waiters.register_waker(index, cx);
let waker = waiters.register_waker(index, cx);
drop(waiters);
drop(waker);
return Poll::Pending;
}
}
Expand Down
21 changes: 21 additions & 0 deletions mea/src/broadcast/overflow/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::sync::atomic::Ordering;
use std::task::Context;

use super::*;
use crate::count_waker;

#[tokio::test]
async fn test_broadcast_basic() {
Expand Down Expand Up @@ -77,6 +82,22 @@ async fn test_wait_mechanism() {
assert_eq!(handle.await.unwrap(), Ok(42));
}

#[tokio::test]
async fn test_recv_cancellation_removes_waiter() {
let (tx, mut rx) = channel::<i32>(10);
let (waker, wake_count) = count_waker();
let mut cx = Context::from_waker(&waker);
let mut recv = Box::pin(rx.recv());

assert!(recv.as_mut().poll(&mut cx).is_pending());

drop(recv);
tx.send(1);

assert_eq!(wake_count.load(Ordering::Relaxed), 0);
assert_eq!(rx.try_recv(), Ok(1));
}

#[tokio::test]
async fn test_subscribe() {
let (tx, _rx) = channel(10);
Expand Down
24 changes: 17 additions & 7 deletions mea/src/internal/countdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::task::Context;

use crate::internal::Mutex;
use crate::internal::WaitSet;
use crate::internal::WaiterId;

#[derive(Debug)]
pub(crate) struct CountdownState {
Expand Down Expand Up @@ -56,17 +57,26 @@ impl CountdownState {

/// Drain and wake up all waiters.
pub(crate) fn wake_all(&self) {
let mut waiters = self.waiters.lock();
waiters.wake_all();
let wakers = {
let mut waiters = self.waiters.lock();
waiters.take_wakers()
};

for waker in wakers {
waker.wake();
}
}

/// Registers a waker to be woken up when the countdown reaches zero.
///
/// `idx` must be `None` when the waker is not registered, or `Some(key)` where `key` is
/// a value previously returned by this method.
pub(crate) fn register_waker(&self, idx: &mut Option<usize>, cx: &mut Context<'_>) {
let mut waiters = self.waiters.lock();
waiters.register_waker(idx, cx);
/// `id` must be `None` when the waker is not registered, or `Some` with a waiter ID previously
/// stored by this method.
pub(crate) fn register_waker(&self, id: &mut Option<WaiterId>, cx: &mut Context<'_>) {
let waker = {
let mut waiters = self.waiters.lock();
waiters.register_waker(id, cx)
};
drop(waker);
}

/// Returns `Ok(())` if the counter is zero, otherwise returns `Err(s)` where `s` is the current
Expand Down
Loading