diff --git a/examples/stackoverflow.cpp b/examples/stackoverflow.cpp index 2200387..6902f2d 100644 --- a/examples/stackoverflow.cpp +++ b/examples/stackoverflow.cpp @@ -11,25 +11,25 @@ struct task { using completion_signatures = ex::completion_signatures; struct base { - virtual void complete_value() noexcept = 0; + virtual void complete_value() noexcept = 0; + virtual void complete_stopped() noexcept = 0; }; struct promise_type { struct final_awaiter { base* data; bool await_ready() noexcept { return false; } - auto await_suspend(auto h) noexcept { - std::cout << "final_awaiter\n"; - this->data->complete_value(); - std::cout << "completed\n"; - }; - void await_resume() noexcept {} + auto await_suspend(auto h) noexcept { this->data->complete_value(); }; + void await_resume() noexcept {} }; std::suspend_always initial_suspend() const noexcept { return {}; } final_awaiter final_suspend() const noexcept { return {this->data}; } void unhandled_exception() const noexcept {} - std::coroutine_handle<> unhandled_stopped() { return std::coroutine_handle<>(); } - auto return_void() {} + std::coroutine_handle<> unhandled_stopped() { + this->data->complete_stopped(); + return std::noop_coroutine(); + } + auto return_void() {} auto get_return_object() { return task{std::coroutine_handle::from_promise(*this)}; } template <::beman::execution::sender Sender> auto await_transform(Sender&& sender) noexcept { @@ -51,7 +51,14 @@ struct task { this->handle.promise().data = this; this->handle.resume(); } - void complete_value() noexcept override { ex::set_value(std::move(this->r)); } + void complete_value() noexcept override { + this->handle.destroy(); + ex::set_value(std::move(this->r)); + } + void complete_stopped() noexcept override { + this->handle.destroy(); + ex::set_stopped(std::move(this->r)); + } }; std::coroutine_handle handle; @@ -63,11 +70,23 @@ struct task { }; int main(int ac, char*[]) { + std::cout << std::unitbuf; + using on_exit = std::unique_ptr; static_assert(ex::sender); ex::sync_wait([](int n) -> task { - for (int i{}; i < n; ++i) { - std::cout << "await=" << (co_await ex::just(i)) << "\n"; - } - co_return; + on_exit msg("coro run to the end"); + if constexpr (true) + for (int i{}; i < n; ++i) { + std::cout << "await just=" << (co_await ex::just(i)) << "\n"; + } + if constexpr (false) + for (int i{}; i < n; ++i) { + try { + co_await ex::just_error(i); + } catch (int x) { + std::cout << "await error=" << x << "\n"; + } + } + co_await ex::just_stopped(); }(ac < 2 ? 3 : 30000)); } diff --git a/include/beman/execution/detail/sender_awaitable.hpp b/include/beman/execution/detail/sender_awaitable.hpp index 2078f05..c49c56d 100644 --- a/include/beman/execution/detail/sender_awaitable.hpp +++ b/include/beman/execution/detail/sender_awaitable.hpp @@ -22,6 +22,8 @@ #include #include #include +#include +#include namespace beman::execution::detail { template @@ -31,58 +33,78 @@ class sender_awaitable { ::beman::execution::detail::single_sender_value_type>; using result_type = ::std::conditional_t<::std::is_void_v, unit, value_type>; using variant_type = ::std::variant<::std::monostate, result_type, ::std::exception_ptr>; + using data_type = ::std::tuple, ::std::coroutine_handle>; + struct awaitable_receiver { using receiver_concept = ::beman::execution::receiver_t; + void resume() { + if (::std::get<1>(*result_ptr_).exchange(true, std::memory_order_acq_rel)) { + ::std::get<2>(*result_ptr_).resume(); + } + } + template requires ::std::constructible_from void set_value(Args&&... args) && noexcept { try { - result_ptr_->template emplace<1>(::std::forward(args)...); + ::std::get<0>(*result_ptr_).template emplace<1>(::std::forward(args)...); } catch (...) { - result_ptr_->template emplace<2>(::std::current_exception()); + ::std::get<0>(*result_ptr_).template emplace<2>(::std::current_exception()); } - continuation_.resume(); + this->resume(); } - template void set_error(Error&& error) && noexcept { - result_ptr_->template emplace<2>(::beman::execution::detail::as_except_ptr(::std::forward(error))); - continuation_.resume(); + ::std::get<0>(*result_ptr_) + .template emplace<2>(::beman::execution::detail::as_except_ptr(::std::forward(error))); + this->resume(); } void set_stopped() && noexcept { - static_cast<::std::coroutine_handle<>>(continuation_.promise().unhandled_stopped()).resume(); + if (::std::get<1>(*result_ptr_).exchange(true, ::std::memory_order_acq_rel)) { + static_cast<::std::coroutine_handle<>>(::std::get<2>(*result_ptr_).promise().unhandled_stopped()) + .resume(); + } } auto get_env() const noexcept { - return ::beman::execution::detail::fwd_env{::beman::execution::get_env(continuation_.promise())}; + return ::beman::execution::detail::fwd_env{ + ::beman::execution::get_env(::std::get<2>(*result_ptr_).promise())}; } - variant_type* result_ptr_; - ::std::coroutine_handle continuation_; + data_type* result_ptr_; }; using op_state_type = ::beman::execution::connect_result_t; - variant_type result{}; + data_type result{}; op_state_type state; public: sender_awaitable(Sndr&& sndr, Promise& p) - : state{::beman::execution::connect( - ::std::forward(sndr), - awaitable_receiver{::std::addressof(result), ::std::coroutine_handle::from_promise(p)})} {} + : result{::std::monostate{}, false, ::std::coroutine_handle::from_promise(p)}, + state{::beman::execution::connect(::std::forward(sndr), + sender_awaitable::awaitable_receiver{::std::addressof(result)})} {} static constexpr bool await_ready() noexcept { return false; } - void await_suspend(::std::coroutine_handle) noexcept { ::beman::execution::start(state); } + bool await_suspend(::std::coroutine_handle) noexcept { + ::beman::execution::start(state); + if (::std::get<1>(this->result).exchange(true, std::memory_order_acq_rel)) { + if (::std::holds_alternative<::std::monostate>(::std::get<0>(this->result))) { + return bool(::std::get<2>(this->result).promise().unhandled_stopped()); + } + return false; + } + return true; + } value_type await_resume() { - if (::std::holds_alternative<::std::exception_ptr>(result)) { - ::std::rethrow_exception(::std::get<::std::exception_ptr>(result)); + if (::std::holds_alternative<::std::exception_ptr>(::std::get<0>(result))) { + ::std::rethrow_exception(::std::get<::std::exception_ptr>(::std::get<0>(result))); } if constexpr (::std::is_void_v) { return; } else { - return ::std::get(std::move(result)); + return ::std::get(std::move(::std::get<0>(result))); } } };