diff --git a/include/tmc/detail/awaitable_customizer.hpp b/include/tmc/detail/awaitable_customizer.hpp index 64d69da0..eba50a57 100644 --- a/include/tmc/detail/awaitable_customizer.hpp +++ b/include/tmc/detail/awaitable_customizer.hpp @@ -68,54 +68,17 @@ struct awaitable_customizer_base { awaitable_customizer_base() noexcept : continuation{nullptr}, continuation_executor{tmc::detail::this_thread::executor()}, - done_count{nullptr}, flags{tmc::detail::this_thread::this_task().prio} {} - - // Either returns the awaiting coroutine (continuation) to be resumed - // directly, or submits that awaiting coroutine to the continuation executor - // to be resumed. This should be called exactly once, after the awaitable is - // complete and any results are ready. - // - // The overload taking a Handle destroys the coroutine via `self.destroy()` - // BEFORE performing any atomic operation that could allow a parent task to - // resume. This is required when this task is HALO'd into a parent task's - // allocation: after the atomic (e.g., done_count.fetch_sub), another child - // could resume the parent, which could then complete and destroy its - // allocation (including this task's HALO'd frame). By copying all needed data - // to stack locals and destroying ourselves first, we avoid use-after-free. - TMC_FORCE_INLINE inline std::coroutine_handle<> - resume_continuation(std::coroutine_handle<> self) noexcept { - // Copy all needed fields to stack locals FIRST, before destroying self. - void* lContinuationExecutor = continuation_executor; - void* lContinuation = continuation; - void* lDoneCount = done_count; - size_t lFlags = flags; - // Destroy the coroutine BEFORE the atomic that could allow parent - // destruction. After this point, `this` is INVALID - use only locals. - self.destroy(); - return resume_continuation_impl( - lContinuationExecutor, lContinuation, lDoneCount, lFlags - ); - } - - // The no-argument overload is for non-coroutine awaitables that don't need - // to destroy themselves. - TMC_FORCE_INLINE inline std::coroutine_handle<> - resume_continuation() noexcept { - // There's no risk of use-after-free here, so just pass the fields directly. - return resume_continuation_impl( - continuation_executor, continuation, done_count, flags - ); + done_count{nullptr}, flags{tmc::detail::this_thread::this_task().prio} { } -private: - // Implementation that works only with stack locals - no access to `this`. - TMC_FORCE_INLINE inline static std::coroutine_handle<> - resume_continuation_impl( - void* ContinuationExecutor, void* Continuation, void* DoneCount, + // Return the raw continuation, and if the ContinuationExecutor should be + // indirected (it's dispatched through a group), updates the parameter with + // the final value, ready to be cast to a tmc::ex_any*. + // May return nullptr. + TMC_FORCE_INLINE inline static std::coroutine_handle<> get_continuation( + void*& ContinuationExecutor, void* Continuation, void* DoneCount, size_t Flags ) noexcept { - tmc::ex_any* continuationExecutor = - static_cast(ContinuationExecutor); std::coroutine_handle<> finalContinuation; if (DoneCount == nullptr) { // being awaited alone, or detached @@ -150,7 +113,7 @@ struct awaitable_customizer_base { ) == 0; } if (shouldResume) { - continuationExecutor = + ContinuationExecutor = *static_cast(ContinuationExecutor); finalContinuation = *(static_cast*>(Continuation)); @@ -159,25 +122,139 @@ struct awaitable_customizer_base { } } - // Common submission and continuation logic + // Single return to satisfy NRVO + return finalContinuation; + } + + // Gets the awaiting coroutine (continuation) and then posts it to the + // continuation executor - no symmetric transfer. This should be called + // exactly once, after the awaitable is complete and any results are ready. + // + // The overload taking a Handle destroys the coroutine via `self.destroy()` + // BEFORE performing any atomic operation that could allow a parent task to + // resume. This is required when this task is HALO'd into a parent task's + // allocation: after the atomic (e.g., done_count.fetch_sub), another child + // could resume the parent, which could then complete and destroy its + // allocation (including this task's HALO'd frame). By copying all needed data + // to stack locals and destroying ourselves first, we avoid use-after-free. + TMC_FORCE_INLINE inline void + post_continuation(std::coroutine_handle<> self) noexcept { + // Copy all needed fields to stack locals FIRST, before destroying self. + void* lContinuationExecutor = continuation_executor; + void* lContinuation = continuation; + void* lDoneCount = done_count; + size_t lFlags = flags; + + // Destroy the coroutine BEFORE the atomic that could allow parent + // destruction. After this point, `this` is INVALID - use only locals. + self.destroy(); + + post_continuation_impl( + lContinuationExecutor, lContinuation, lDoneCount, lFlags + ); + } + + // Gets the awaiting coroutine (continuation) and then posts it to the + // continuation executor - no symmetric transfer. This should be called + // exactly once, after the awaitable is complete and any results are ready. + // + // The no-argument overload is for non-coroutine awaitables that don't need + // to destroy themselves. + TMC_FORCE_INLINE inline void post_continuation() noexcept { + // There's no risk of use-after-free here, so just pass the fields directly. + post_continuation_impl( + continuation_executor, continuation, done_count, flags + ); + } + + // Either returns the awaiting coroutine (continuation) to be resumed + // directly, or submits that awaiting coroutine to the continuation executor + // to be resumed. This should be called exactly once, after the awaitable is + // complete and any results are ready. + // + // The overload taking a Handle destroys the coroutine via `self.destroy()` + // BEFORE performing any atomic operation that could allow a parent task to + // resume. This is required when this task is HALO'd into a parent task's + // allocation: after the atomic (e.g., done_count.fetch_sub), another child + // could resume the parent, which could then complete and destroy its + // allocation (including this task's HALO'd frame). By copying all needed data + // to stack locals and destroying ourselves first, we avoid use-after-free. + TMC_FORCE_INLINE inline std::coroutine_handle<> + resume_continuation(std::coroutine_handle<> self) noexcept { + // Copy all needed fields to stack locals FIRST, before destroying self. + void* lContinuationExecutor = continuation_executor; + void* lContinuation = continuation; + void* lDoneCount = done_count; + size_t lFlags = flags; + + // Destroy the coroutine BEFORE the atomic that could allow parent + // destruction. After this point, `this` is INVALID - use only locals. + self.destroy(); + + return resume_continuation_impl( + lContinuationExecutor, lContinuation, lDoneCount, lFlags + ); + } + + // Either returns the awaiting coroutine (continuation) to be resumed + // directly, or submits that awaiting coroutine to the continuation executor + // to be resumed. This should be called exactly once, after the awaitable is + // complete and any results are ready. + // + // The no-argument overload is for non-coroutine awaitables that don't need + // to destroy themselves. + TMC_FORCE_INLINE inline std::coroutine_handle<> + resume_continuation() noexcept { + // There's no risk of use-after-free here, so just pass the fields directly. + return resume_continuation_impl( + continuation_executor, continuation, done_count, flags + ); + } + +private: + // Implementation that works only with stack locals - no access to `this`. + TMC_FORCE_INLINE inline static void post_continuation_impl( + void* ContinuationExecutor, void* Continuation, void* DoneCount, + size_t Flags + ) noexcept { + auto finalContinuation = + get_continuation(ContinuationExecutor, Continuation, DoneCount, Flags); + + if (finalContinuation == nullptr) { + return; + } + + size_t continuationPriority = Flags & task_flags::PRIORITY_MASK; + auto exec = static_cast(ContinuationExecutor); + tmc::detail::post_checked( + exec, std::move(finalContinuation), continuationPriority + ); + } + + // Implementation that works only with stack locals - no access to `this`. + TMC_FORCE_INLINE inline static std::coroutine_handle<> + resume_continuation_impl( + void* ContinuationExecutor, void* Continuation, void* DoneCount, + size_t Flags + ) noexcept { + auto finalContinuation = + get_continuation(ContinuationExecutor, Continuation, DoneCount, Flags); + + // Determine if we are allowed to symmetric transfer to the continuation if (finalContinuation == nullptr) { finalContinuation = std::noop_coroutine(); } else { size_t continuationPriority = Flags & task_flags::PRIORITY_MASK; - if (continuationExecutor != nullptr && - !tmc::detail::this_thread::exec_prio_is( - continuationExecutor, continuationPriority - )) { + auto exec = static_cast(ContinuationExecutor); + if (exec != nullptr && + !tmc::detail::this_thread::exec_prio_is(exec, continuationPriority)) { // post_checked is redundant with the prior check at the moment tmc::detail::post_checked( - continuationExecutor, std::move(finalContinuation), - continuationPriority + exec, std::move(finalContinuation), continuationPriority ); finalContinuation = std::noop_coroutine(); } } - - // Single return to satisfy NRVO return finalContinuation; } }; diff --git a/include/tmc/detail/waiter_list.hpp b/include/tmc/detail/waiter_list.hpp index dfa4fa5b..41e9f3fe 100644 --- a/include/tmc/detail/waiter_list.hpp +++ b/include/tmc/detail/waiter_list.hpp @@ -36,6 +36,19 @@ struct waiter_list_waiter { try_symmetric_transfer(std::coroutine_handle<> Outer) noexcept; }; +/// 1. Checks ToWake's executor and priority for symmetric transfer eligibility. +/// If eligible, returns ToWake and posts Continuation to its executor. +/// 2. Checks Continuation's executor and priority for symmetric transfer +/// eligibility. If eligible, returns Continuation and posts ToWake to its +/// executor. +/// 3. If neither is eligible, posts them both to their executors and returns +/// std::noop_coroutine(). +/// Also checks both for null (counts as ineligible). +[[nodiscard]] TMC_DECL std::coroutine_handle<> try_symmetric_transfer2_waiter( + waiter_list_waiter* ToWake, std::coroutine_handle<> Continuation, + tmc::ex_any* Executor, size_t Priority +) noexcept; + struct waiter_data_base; struct waiter_list_node { diff --git a/include/tmc/detail/waiter_list.ipp b/include/tmc/detail/waiter_list.ipp index d731199f..acf196e9 100644 --- a/include/tmc/detail/waiter_list.ipp +++ b/include/tmc/detail/waiter_list.ipp @@ -7,6 +7,7 @@ #include "tmc/detail/impl.hpp" // IWYU pragma: keep +#include "tmc/current.hpp" #include "tmc/detail/thread_locals.hpp" #include "tmc/detail/waiter_list.hpp" @@ -83,6 +84,40 @@ reverse_chain(tmc::detail::waiter_list_node* curr) noexcept { } } // namespace +std::coroutine_handle<> try_symmetric_transfer2_waiter( + waiter_list_waiter* ToWake, std::coroutine_handle<> Continuation, + tmc::ex_any* Executor, size_t Priority +) noexcept { + if (ToWake != nullptr) { + std::coroutine_handle<> toContinuation = ToWake->continuation; + tmc::ex_any* toExecutor = ToWake->continuation_executor; + size_t toPriority = ToWake->continuation_priority; + // If we can transfer to primary, then do so, and post backup. + if (tmc::detail::this_thread::exec_prio_is(toExecutor, toPriority)) { + if (Continuation != nullptr) { + tmc::detail::post_checked(Executor, std::move(Continuation), Priority); + } + return toContinuation; + } + + // Transfer to primary disallowed + tmc::detail::post_checked( + toExecutor, std::move(toContinuation), toPriority + ); + } + + if (Continuation != nullptr) { + // Try to transfer to backup + if (tmc::detail::this_thread::exec_prio_is(Executor, Priority)) { + return Continuation; + } + + // Transfer to backup disallowed + tmc::detail::post_checked(Executor, std::move(Continuation), Priority); + } + return std::noop_coroutine(); +} + void waiter_list::add_waiter(tmc::detail::waiter_list_node& w) noexcept { auto h = input.load(std::memory_order_acquire); do { diff --git a/include/tmc/mutex.hpp b/include/tmc/mutex.hpp index c94dfba1..e9792e44 100644 --- a/include/tmc/mutex.hpp +++ b/include/tmc/mutex.hpp @@ -6,11 +6,14 @@ #pragma once #include "tmc/detail/impl.hpp" // IWYU pragma: keep +#include "tmc/detail/awaitable_customizer.hpp" #include "tmc/detail/concepts_awaitable.hpp" #include "tmc/detail/waiter_list.hpp" #include +#include #include +#include namespace tmc::tests { class waiter_count_accessor; @@ -94,11 +97,54 @@ class [[nodiscard( aw_mutex_co_unlock& operator=(aw_mutex_co_unlock&&) = delete; }; +template +class [[nodiscard( + "You must co_await aw_mutex_co_unlock_return for it to have any effect." +)]] aw_mutex_co_unlock_return : tmc::detail::AwaitTagNoGroupAsIs { + mutex& parent; + + // Store lvalues by reference. Move rvalues into this. + using ReturnValueStorage = std::conditional_t< + std::is_lvalue_reference_v, Result, std::remove_cvref_t>; + + // Handle value return and void return. + struct empty {}; + using ResultStorage = + std::conditional_t, empty, ReturnValueStorage>; + TMC_NO_UNIQUE_ADDRESS ResultStorage result; + + friend class mutex; + + // For result return + template + inline aw_mutex_co_unlock_return(mutex& Parent, ResultArg&& ResultIn) noexcept + : parent(Parent), result(static_cast(ResultIn)) {} + + // For void return + inline aw_mutex_co_unlock_return(mutex& Parent) noexcept : parent(Parent) {} + +public: + inline bool await_ready() noexcept { return false; } + + template + std::coroutine_handle<> + await_suspend(std::coroutine_handle Outer) noexcept; + + [[maybe_unused]] inline void await_resume() noexcept {} + + aw_mutex_co_unlock_return(aw_mutex_co_unlock_return const&) = delete; + aw_mutex_co_unlock_return& + operator=(aw_mutex_co_unlock_return const&) = delete; + aw_mutex_co_unlock_return(aw_mutex_co_unlock_return&&) = delete; + aw_mutex_co_unlock_return& operator=(aw_mutex_co_unlock_return&&) = delete; +}; + /// An async version of std::mutex. class mutex : protected tmc::detail::waiter_data_base { friend class aw_acquire; friend class aw_mutex_lock_scope; friend class aw_mutex_co_unlock; + template friend class aw_mutex_co_unlock_return; friend class ::tmc::tests::waiter_count_accessor; static inline constexpr tmc::detail::half_word LOCKED = 0; @@ -137,6 +183,63 @@ class mutex : protected tmc::detail::waiter_data_base { return aw_mutex_co_unlock(*this); } + /// Unlocks the mutex. If there are any awaiters, an awaiter will be resumed + /// and the lock will be re-locked and transferred to that awaiter. Also + /// completes this coroutine immediately, returns the result back to its + /// parent coroutine, and resumes the parent coroutine. Both the resuming + /// awaiter and the parent coroutine will be checked for symmetric transfer + /// eligibility; otherwise they will be posted back to their respective + /// executors. + /// + /// This effectively contains a `co_return` statement, ending the current + /// coroutine; nothing will be executed after it in the current scope. + /// + /// The purpose of this is to skip a round-trip through the executor when + /// you want to unlock this mutex immediately before returning. + /// + /// ``` + /// // You can replace this: + /// co_await mut.co_unlock(); + /// co_return result; + /// + /// // With this: + /// co_await mut.co_unlock_return(result); + /// std::unreachable(); + /// ``` + template + inline aw_mutex_co_unlock_return + co_unlock_return(Result&& result) noexcept { + return aw_mutex_co_unlock_return( + *this, static_cast(result) + ); + } + + /// Unlocks the mutex. If there are any awaiters, an awaiter will be resumed + /// and the lock will be re-locked and transferred to that awaiter. Also + /// completes this coroutine immediately and resumes the parent coroutine. + /// Both the resuming awaiter and the parent coroutine will be checked for + /// symmetric transfer eligibility; otherwise they will be posted back to + /// their respective executors. + /// + /// This effectively contains a `co_return` statement, ending the current + /// coroutine; nothing will be executed after it in the current scope. + /// + /// The purpose of this is to skip a round-trip through the executor when + /// you want to unlock this mutex immediately before returning. + /// + /// ``` + /// // You can replace this: + /// co_await mut.co_unlock(); + /// co_return; + /// + /// // With this: + /// co_await mut.co_unlock_return(); + /// std::unreachable(); + /// ``` + inline aw_mutex_co_unlock_return co_unlock_return() noexcept { + return aw_mutex_co_unlock_return(*this); + } + /// Tries to acquire the mutex. If it is locked by another task, will /// suspend until it can be locked by this task, then transfer the /// ownership to this task. Not re-entrant. @@ -155,6 +258,50 @@ class mutex : protected tmc::detail::waiter_data_base { TMC_DECL ~mutex(); }; +template +template +std::coroutine_handle<> aw_mutex_co_unlock_return::await_suspend( + std::coroutine_handle Outer +) noexcept { + assert(parent.is_locked()); + if constexpr (std::is_void_v) { + Outer.promise().return_void(); + } else { + Outer.promise().return_value(static_cast(result)); + } + + // Unlock the mutex normally and capture the continuation + size_t old = + parent.value.fetch_or(mutex::UNLOCKED, std::memory_order_acq_rel); + size_t v = mutex::UNLOCKED | old; + auto toWake = parent.waiters.maybe_wake(parent.value, v, old, true); + + // Capture these values locally before destroying the coroutine frame + auto& customizer = Outer.promise().customizer; + void* continuationExecutor = customizer.continuation_executor; + void* continuationPtr = customizer.continuation; + void* doneCount = customizer.done_count; + size_t flags = customizer.flags; + + // Destroy the coroutine *before* calling get_continuation, which could allow + // the continuation to be stolen by another parent, which could then complete + // and destroy the frame of this coroutine if it is HALO'd into that parent. + // By destroying ourselves first, we avoid use-after-free. This is the same + // protocol that tmc::task's final_suspend follows. + Outer.destroy(); + + std::coroutine_handle<> continuation = + tmc::detail::awaitable_customizer_base::get_continuation( + continuationExecutor, continuationPtr, doneCount, flags + ); + + size_t continuationPriority = flags & tmc::detail::task_flags::PRIORITY_MASK; + return tmc::detail::try_symmetric_transfer2_waiter( + toWake, continuation, static_cast(continuationExecutor), + continuationPriority + ); +} + namespace detail { template <> struct awaitable_traits { static constexpr configure_mode mode = WRAPPER;