diff --git a/include/async/cancellation.hpp b/include/async/cancellation.hpp index bbf350b..2052a5c 100644 --- a/include/async/cancellation.hpp +++ b/include/async/cancellation.hpp @@ -1,7 +1,9 @@ #pragma once #include +#include #include + #include #include "basic.hpp" @@ -349,4 +351,114 @@ inline suspend_indefinitely_sender suspend_indefinitely(C&&... cts return {std::array{cts...}}; } +//--------------------------------------------------------------------------------------- +// with_cancel_cb() +//--------------------------------------------------------------------------------------- + +template +requires Receives +struct with_cancel_cb_operation { + using value_type = typename S::value_type; + + with_cancel_cb_operation(S s, Cb cb, cancellation_token ct, R dr) + : op_{execution::connect(std::move(s), intermediate_receiver{this})}, + cb_{std::move(cb)}, + dr_{std::move(dr)}, + ct_{ct} { } + + with_cancel_cb_operation(const with_cancel_cb_operation &) = delete; + + with_cancel_cb_operation &operator=(const with_cancel_cb_operation &) = delete; + + void start() { + cobs_.force_set(ct_); + execution::start(op_); + } + +private: + struct intermediate_receiver { + template + void set_value(Args &&...args) { + self->value_.emplace(std::forward(args)...); + + // If try_reset() succeeds, the operation was not cancelled and cancel_state_ is irrelevant. + if (self->cobs_.try_reset() + || self->cancel_state_.fetch_sub(1, std::memory_order_acq_rel) == 1) + self->complete_(); + } + + with_cancel_cb_operation *self; + }; + static_assert(Receives); + + struct cancel_handler { + void operator()() { + self->cb_(); + + if (self->cancel_state_.fetch_sub(1, std::memory_order_acq_rel) == 1) + self->complete_(); + } + + with_cancel_cb_operation *self; + }; + + void complete_() { + assert(value_); + if constexpr (std::is_same_v) { + execution::set_value(std::move(dr_)); + } else { + execution::set_value(std::move(dr_), std::move(*value_)); + } + } + + execution::operation_t op_; + Cb cb_; + R dr_; + cancellation_token ct_; + cancellation_observer cobs_{cancel_handler{this}}; + + // Valid if cancellation is triggered before try_reset(): + // 2: Both cb_() and the operation are still running. + // 1: Either cb_() or the operation (both not both) are still running. + // 0: Both cb_() and the operation are done. + std::atomic cancel_state_{2}; + + struct empty { }; + + std::optional< + std::conditional_t< + std::is_same_v, + empty, + value_type + > + > value_; +}; + +template +struct [[nodiscard]] with_cancel_cb_sender { + using value_type = typename S::value_type; + + template + requires Receives + friend with_cancel_cb_operation + connect(with_cancel_cb_sender s, R r) { + return {std::move(s.s), std::move(s.cb), s.ct, std::move(r)}; + } + + S s; + Cb cb; + cancellation_token ct; +}; + +template +with_cancel_cb_sender with_cancel_cb(S s, Cb cb, cancellation_token ct) { + return {std::move(s), std::move(cb), ct}; +} + +template +sender_awaiter, typename S::value_type> +operator co_await(with_cancel_cb_sender s) { + return {std::move(s)}; +} + } // namespace async diff --git a/tests/meson.build b/tests/meson.build index 42fa799..f5711b6 100644 --- a/tests/meson.build +++ b/tests/meson.build @@ -21,7 +21,8 @@ sources = files( 'oneshot.cpp', 'promise.cpp', 'sequenced.cpp', - 'post-ack.cpp' + 'post-ack.cpp', + 'with_cancel_cb.cpp', ) exe = executable('gtests', diff --git a/tests/with_cancel_cb.cpp b/tests/with_cancel_cb.cpp new file mode 100644 index 0000000..ab8c4f6 --- /dev/null +++ b/tests/with_cancel_cb.cpp @@ -0,0 +1,78 @@ +#include +#include +#include +#include + +TEST(Algorithm, WithCancelCbHappy) { + bool cb_called = false; + async::cancellation_event ce; + + int v = async::run(async::with_cancel_cb( + []() -> async::result { + co_return 42; + }(), + [&] { + cb_called = true; + }, + async::cancellation_token{ce} + )); + + ASSERT_EQ(v, 42); + ASSERT_FALSE(cb_called); +} + +TEST(Algorithm, WithCancelCbHappyVoid) { + bool cb_called = false; + async::cancellation_event ce; + + async::run(async::with_cancel_cb( + []() -> async::result { + co_return; + }(), + [&] { + cb_called = true; + }, + async::cancellation_token{ce} + )); + + ASSERT_FALSE(cb_called); +} + +TEST(Algorithm, WithCancelCbCancelledBefore) { + bool cb_called = false; + async::cancellation_event ce; + ce.cancel(); + + int v = async::run(async::with_cancel_cb( + []() -> async::result { + co_return 42; + }(), + [&] { + cb_called = true; + }, + async::cancellation_token{ce} + )); + + ASSERT_EQ(v, 42); + ASSERT_TRUE(cb_called); +} + +TEST(Algorithm, WithCancelCbCancelledDuring) { + bool cb_called = false; + async::cancellation_event ce; + ce.cancel(); + + int v = async::run(async::with_cancel_cb( + [](async::cancellation_event *evp) -> async::result { + evp->cancel(); + co_return 42; + }(&ce), + [&] { + cb_called = true; + }, + async::cancellation_token{ce} + )); + + ASSERT_EQ(v, 42); + ASSERT_TRUE(cb_called); +}