Skip to content

Commit fb7e9fa

Browse files
committed
guard against infinite recursion from too many inline completions in task
1 parent 5076be2 commit fb7e9fa

File tree

1 file changed

+150
-77
lines changed

1 file changed

+150
-77
lines changed

include/stdexec/__detail/__as_awaitable.hpp

Lines changed: 150 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "__execution_fwd.hpp"
1919

20+
#include "__atomic.hpp"
2021
#include "__awaitable.hpp"
2122
#include "__completion_signatures_of.hpp"
2223
#include "__concepts.hpp"
@@ -31,6 +32,9 @@
3132
#include <system_error>
3233
#include <variant>
3334

35+
STDEXEC_PRAGMA_PUSH()
36+
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
37+
3438
namespace STDEXEC
3539
{
3640
#if !STDEXEC_NO_STDCPP_COROUTINES()
@@ -46,11 +50,11 @@ namespace STDEXEC
4650
inline constexpr __mconst<void> __as_single<0>;
4751

4852
template <class... _Values>
49-
using __single_value = __minvoke<decltype(__as_single<sizeof...(_Values)>), _Values...>;
53+
using __single_value_t = __minvoke<decltype(__as_single<sizeof...(_Values)>), _Values...>;
5054

5155
template <class _Sender, class _Promise>
5256
using __value_t = __decay_t<
53-
__value_types_of_t<_Sender, env_of_t<_Promise&>, __q<__single_value>, __msingle_or<void>>>;
57+
__value_types_of_t<_Sender, env_of_t<_Promise&>, __q<__single_value_t>, __msingle_or<void>>>;
5458

5559
inline constexpr auto __get_await_completion_adaptor =
5660
__with_default{get_await_completion_adaptor, std::identity{}};
@@ -75,7 +79,7 @@ namespace STDEXEC
7579
namespace __as_awaitable
7680
{
7781
struct __void
78-
{};
82+
{ };
7983

8084
template <class _Value>
8185
using __value_or_void_t = __if_c<__same_as<_Value, void>, __void, _Value>;
@@ -93,6 +97,45 @@ namespace STDEXEC
9397
&& __completes_inline_for<set_error_t, _Sender, _Env...>
9498
&& __completes_inline_for<set_stopped_t, _Sender, _Env...>;
9599

100+
template <class _Value, bool _Inline>
101+
struct __sender_awaitable_base;
102+
103+
template <class _Value>
104+
struct __sender_awaitable_base<_Value, true>
105+
{
106+
static constexpr auto await_ready() noexcept -> bool
107+
{
108+
return false;
109+
}
110+
111+
constexpr auto await_resume() -> _Value
112+
{
113+
// If the operation completed with set_stopped (as denoted by the monostate
114+
// alternative being active), we should not be resuming this coroutine at all.
115+
STDEXEC_ASSERT(__result_.index() != 0);
116+
if (__result_.index() == 2)
117+
{
118+
// The operation completed with set_error, so we need to rethrow the exception.
119+
std::rethrow_exception(std::move(std::get<2>(__result_)));
120+
}
121+
// The operation completed with set_value, so we can just return the value, which
122+
// may be void.
123+
using __reference_t = std::add_rvalue_reference_t<_Value>;
124+
return static_cast<__reference_t>(std::get<1>(__result_));
125+
}
126+
127+
__std::coroutine_handle<> __continuation_;
128+
__expected_t<_Value> __result_{};
129+
};
130+
131+
// When the sender is not statically known to complete inline, we need to use atomic
132+
// state to guard against too many inline completions causing a stack overflow.
133+
template <class _Value>
134+
struct __sender_awaitable_base<_Value, false> : __sender_awaitable_base<_Value, true>
135+
{
136+
__std::atomic<bool> __ready_{false};
137+
};
138+
96139
template <class _Value>
97140
struct __receiver_base
98141
{
@@ -103,36 +146,38 @@ namespace STDEXEC
103146
{
104147
STDEXEC_TRY
105148
{
106-
__result_.template emplace<1>(static_cast<_Us&&>(__us)...);
149+
__awaiter_.__result_.template emplace<1>(static_cast<_Us&&>(__us)...);
107150
}
108151
STDEXEC_CATCH_ALL
109152
{
110-
__result_.template emplace<2>(std::current_exception());
153+
__awaiter_.__result_.template emplace<2>(std::current_exception());
111154
}
112155
}
113156

114157
template <class _Error>
115158
void set_error(_Error&& __err) noexcept
116159
{
117160
if constexpr (__decays_to<_Error, std::exception_ptr>)
118-
__result_.template emplace<2>(static_cast<_Error&&>(__err));
161+
__awaiter_.__result_.template emplace<2>(static_cast<_Error&&>(__err));
119162
else if constexpr (__decays_to<_Error, std::error_code>)
120-
__result_.template emplace<2>(std::make_exception_ptr(std::system_error(__err)));
163+
__awaiter_.__result_.template emplace<2>(
164+
std::make_exception_ptr(std::system_error(__err)));
121165
else
122-
__result_.template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
166+
__awaiter_.__result_.template emplace<2>(
167+
std::make_exception_ptr(static_cast<_Error&&>(__err)));
123168
}
124169

125-
__expected_t<_Value>& __result_;
170+
__sender_awaitable_base<_Value, true>& __awaiter_;
126171
};
127172

128173
template <class _Promise, class _Value>
129174
struct __sync_receiver : __receiver_base<_Value>
130175
{
131-
constexpr explicit __sync_receiver(__expected_t<_Value>& __result,
132-
__std::coroutine_handle<_Promise> __continuation) noexcept
133-
: __receiver_base<_Value>{__result}
134-
, __continuation_{__continuation}
135-
{}
176+
using __awaiter_t = __sender_awaitable_base<_Value, true>;
177+
178+
constexpr explicit __sync_receiver(__awaiter_t& __awaiter) noexcept
179+
: __receiver_base<_Value>{__awaiter}
180+
{ }
136181

137182
void set_stopped() noexcept
138183
{
@@ -141,35 +186,37 @@ namespace STDEXEC
141186
}
142187

143188
// Forward get_env query to the coroutine promise
189+
[[nodiscard]]
144190
constexpr auto get_env() const noexcept -> env_of_t<_Promise&>
145191
{
146-
return STDEXEC::get_env(__continuation_.promise());
192+
auto __pcoro = this->__awaiter_.__continuation_.address();
193+
auto __hcoro = __std::coroutine_handle<_Promise>::from_address(__pcoro);
194+
return STDEXEC::get_env(__hcoro.promise());
147195
}
148-
149-
__std::coroutine_handle<_Promise> __continuation_;
150196
};
151197

152198
// The receiver type used to connect to senders that could complete asynchronously.
153199
template <class _Promise, class _Value>
154200
struct __async_receiver : __sync_receiver<_Promise, _Value>
155201
{
156-
constexpr explicit __async_receiver(__expected_t<_Value>& __result,
157-
__std::coroutine_handle<_Promise> __continuation) noexcept
158-
: __sync_receiver<_Promise, _Value>{__result, __continuation}
159-
{}
202+
using __awaiter_t = __sender_awaitable_base<_Value, false>;
203+
204+
constexpr explicit __async_receiver(__awaiter_t& __awaiter) noexcept
205+
: __sync_receiver<_Promise, _Value>{__awaiter}
206+
{ }
160207

161208
template <class... _Us>
162209
void set_value(_Us&&... __us) noexcept
163210
{
164211
this->__sync_receiver<_Promise, _Value>::set_value(static_cast<_Us&&>(__us)...);
165-
this->__continuation_.resume();
212+
__done();
166213
}
167214

168215
template <class _Error>
169216
void set_error(_Error&& __err) noexcept
170217
{
171218
this->__sync_receiver<_Promise, _Value>::set_error(static_cast<_Error&&>(__err));
172-
this->__continuation_.resume();
219+
__done();
173220
}
174221

175222
constexpr void set_stopped() noexcept
@@ -179,14 +226,36 @@ namespace STDEXEC
179226
// Resuming the stopped continuation unwinds the coroutine stack until we reach
180227
// a promise that can handle the stopped signal. The coroutine referred to by
181228
// __continuation_ will never be resumed.
182-
__std::coroutine_handle<> __on_stopped =
183-
this->__continuation_.promise().unhandled_stopped();
229+
auto __pcoro = this->__awaiter_.__continuation_.address();
230+
auto __hcoro = __std::coroutine_handle<_Promise>::from_address(__pcoro);
231+
auto& __promise = __hcoro.promise();
232+
__std::coroutine_handle<> __on_stopped = __promise.unhandled_stopped();
184233
__on_stopped.resume();
185234
}
186235
STDEXEC_CATCH_ALL
187236
{
188-
this->__result_.template emplace<2>(std::current_exception());
189-
this->__continuation_.resume();
237+
this->__awaiter_.__result_.template emplace<2>(std::current_exception());
238+
this->__awaiter_.__continuation_.resume();
239+
}
240+
}
241+
242+
private:
243+
void __done() noexcept
244+
{
245+
// If __ready_ is still false, then we are completing inline. Update the
246+
// value of __ready_. Otherwise, we are completing asynchronously, so we invoke
247+
// the continuation without unwinding the stack.
248+
auto __expected = false;
249+
auto& __awaiter = static_cast<__awaiter_t&>(this->__awaiter_);
250+
if (!__awaiter.__ready_.compare_exchange_strong(__expected,
251+
true,
252+
__std::memory_order_release,
253+
__std::memory_order_acquire))
254+
{
255+
// We get here if __ready_ was true. It got set to true in await_suspend()
256+
// immediately after the operation was started, which implies that the operation
257+
// completed asynchronously, so we need to resume the continuation.
258+
__awaiter.__continuation_.resume();
190259
}
191260
}
192261
};
@@ -197,49 +266,49 @@ namespace STDEXEC
197266
template <class _Sender, class _Promise>
198267
using __async_receiver_t = __async_receiver<_Promise, __detail::__value_t<_Sender, _Promise>>;
199268

200-
template <class _Value>
201-
struct __sender_awaitable_base
202-
{
203-
static constexpr auto await_ready() noexcept -> bool
204-
{
205-
return false;
206-
}
207-
208-
constexpr auto await_resume() -> _Value
209-
{
210-
// If the operation completed with set_stopped (as denoted by the monostate
211-
// alternative being active), we should not be resuming this coroutine at all.
212-
STDEXEC_ASSERT(__result_.index() != 0);
213-
if (__result_.index() == 2)
214-
{
215-
// The operation completed with set_error, so we need to rethrow the exception.
216-
std::rethrow_exception(std::move(std::get<2>(__result_)));
217-
}
218-
// The operation completed with set_value, so we can just return the value, which
219-
// may be void.
220-
return static_cast<std::add_rvalue_reference_t<_Value>>(std::get<1>(__result_));
221-
}
222-
223-
protected:
224-
__expected_t<_Value> __result_{};
225-
};
226-
227269
//////////////////////////////////////////////////////////////////////////////////////
228270
// __sender_awaitable: awaitable type returned by as_awaitable when given a sender
229271
// that does not have an as_awaitable member function
230272
template <class _Promise, class _Sender>
231-
struct __sender_awaitable : __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>>
273+
struct __sender_awaitable
274+
: __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>, false>
232275
{
276+
using __value_t = __detail::__value_t<_Sender, _Promise>;
277+
233278
constexpr explicit __sender_awaitable(_Sender&& __sndr,
234279
__std::coroutine_handle<_Promise> __hcoro)
235280
noexcept(__nothrow_connectable<_Sender, __receiver_t>)
236-
: __opstate_(STDEXEC::connect(static_cast<_Sender&&>(__sndr),
237-
__receiver_t(this->__result_, __hcoro)))
238-
{}
281+
: __sender_awaitable_base<__value_t, false>{__hcoro}
282+
, __opstate_(STDEXEC::connect(static_cast<_Sender&&>(__sndr), __receiver_t(*this)))
283+
{ }
239284

240-
constexpr void await_suspend(__std::coroutine_handle<_Promise>) noexcept
285+
constexpr auto
286+
await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro) noexcept //
287+
-> __std::coroutine_handle<>
241288
{
289+
STDEXEC_ASSERT(this->__continuation_ == __hcoro);
290+
291+
// Start the operation.
242292
STDEXEC::start(__opstate_);
293+
294+
auto __expected = false;
295+
if (this->__ready_.compare_exchange_strong(__expected,
296+
true,
297+
__std::memory_order_release,
298+
__std::memory_order_acquire))
299+
{
300+
// If __ready_ is still false, then the operation has not completed inline. The
301+
// continuation will be resumed when the operation completes, so we return a
302+
// noop_coroutine to suspend the current coroutine.
303+
return __std::noop_coroutine();
304+
}
305+
else
306+
{
307+
// The operation completed inline with set_value or set_error, so we can just
308+
// resume the current coroutine. await_resume will either return the value or
309+
// throw as appropriate.
310+
return __hcoro;
311+
}
243312
}
244313

245314
private:
@@ -252,18 +321,23 @@ namespace STDEXEC
252321
template <class _Promise, class _Sender>
253322
requires __completes_inline<_Sender, env_of_t<_Promise&>>
254323
struct __sender_awaitable<_Promise, _Sender>
255-
: __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>>
324+
: __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>, true>
256325
{
257-
constexpr explicit __sender_awaitable(_Sender&& sndr, __ignore)
326+
using __value_t = __detail::__value_t<_Sender, _Promise>;
327+
328+
constexpr explicit __sender_awaitable(_Sender&& sndr,
329+
__std::coroutine_handle<_Promise> __hcoro)
258330
noexcept(__nothrow_move_constructible<_Sender>)
259-
: __sndr_(static_cast<_Sender&&>(sndr))
260-
{}
331+
: __sender_awaitable_base<__value_t, true>{__hcoro}
332+
, __sndr_(static_cast<_Sender&&>(sndr))
333+
{ }
261334

262-
bool await_suspend(__std::coroutine_handle<_Promise> __hcoro)
335+
auto await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro)
336+
-> __std::coroutine_handle<>
263337
{
338+
STDEXEC_ASSERT(this->__continuation_ == __hcoro);
264339
{
265-
auto __opstate = STDEXEC::connect(static_cast<_Sender&&>(__sndr_),
266-
__receiver_t(this->__result_, __hcoro));
340+
auto __opstate = STDEXEC::connect(static_cast<_Sender&&>(__sndr_), __receiver_t(*this));
267341
// The following call to start will complete synchronously, writing its result
268342
// into the __result_ variant.
269343
STDEXEC::start(__opstate);
@@ -275,18 +349,15 @@ namespace STDEXEC
275349
// unhandled_stopped() on the promise to propagate the stop signal. That will
276350
// result in the coroutine being torn down, so beware. We then resume the
277351
// returned coroutine handle (which may be a noop_coroutine).
278-
__std::coroutine_handle<> __on_stopped = __hcoro.promise().unhandled_stopped();
279-
__on_stopped.resume();
280-
281-
// By returning true, we indicate that the coroutine should not be resumed
282-
// (because it no longer exists).
283-
return true;
352+
return __hcoro.promise().unhandled_stopped();
353+
}
354+
else
355+
{
356+
// The operation completed with set_value or set_error, so we can just resume
357+
// the current coroutine. await_resume will either return the value or throw as
358+
// appropriate.
359+
return __hcoro;
284360
}
285-
286-
// The operation completed with set_value or set_error, so we can just resume the
287-
// current coroutine. await_resume with either return the value or throw as
288-
// appropriate.
289-
return false;
290361
}
291362

292363
private:
@@ -413,3 +484,5 @@ namespace STDEXEC
413484
inline constexpr as_awaitable_t as_awaitable{};
414485
#endif
415486
} // namespace STDEXEC
487+
488+
STDEXEC_PRAGMA_POP()

0 commit comments

Comments
 (0)