diff --git a/include/asyncpp/io/socket.h b/include/asyncpp/io/socket.h index 7123fea..1a45f78 100644 --- a/include/asyncpp/io/socket.h +++ b/include/asyncpp/io/socket.h @@ -38,6 +38,7 @@ namespace asyncpp::io { class socket_accept_awaitable; class socket_accept_error_code_awaitable; class socket_send_awaitable; + class socket_send_exact_awaitable; class socket_recv_awaitable; class socket_recv_exact_awaitable; class socket_recv_from_awaitable; @@ -49,6 +50,7 @@ namespace asyncpp::io { using socket_accept_error_code_cancellable_awaitable = detail::cancellable_awaitable; using socket_send_cancellable_awaitable = detail::cancellable_awaitable; + using socket_send_exact_cancellable_awaitable = detail::cancellable_awaitable; using socket_recv_cancellable_awaitable = detail::cancellable_awaitable; using socket_recv_exact_cancellable_awaitable = detail::cancellable_awaitable; using socket_recv_from_cancellable_awaitable = detail::cancellable_awaitable; @@ -123,6 +125,8 @@ namespace asyncpp::io { [[nodiscard]] socket_accept_error_code_awaitable accept(std::error_code& ec) noexcept; [[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size) noexcept; [[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size, std::error_code& ec) noexcept; + [[nodiscard]] socket_send_exact_awaitable send_exact(const void* buffer, std::size_t size) noexcept; + [[nodiscard]] socket_send_exact_awaitable send_exact(const void* buffer, std::size_t size, std::error_code& ec) noexcept; [[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size) noexcept; [[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size, std::error_code& ec) noexcept; [[nodiscard]] socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size) noexcept; @@ -146,6 +150,10 @@ namespace asyncpp::io { asyncpp::stop_token st) noexcept; [[nodiscard]] socket_send_cancellable_awaitable send(const void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept; + [[nodiscard]] socket_send_exact_cancellable_awaitable send_exact(const void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept; + [[nodiscard]] socket_send_exact_cancellable_awaitable send_exact(const void* buffer, std::size_t size, + asyncpp::stop_token st, std::error_code& ec) noexcept; [[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size, asyncpp::stop_token st) noexcept; [[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size, asyncpp::stop_token st, @@ -173,6 +181,9 @@ namespace asyncpp::io { template requires(std::is_invocable_v) void send(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); + template + requires(std::is_invocable_v) + void send_exact(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); template requires(std::is_invocable_v) void recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {}); @@ -260,6 +271,21 @@ namespace asyncpp::io { void await_resume(); }; + class socket_send_exact_awaitable : public detail::socket_awaitable_base { + std::byte const* m_buffer; + std::size_t const m_size; + std::size_t m_remaining; + asyncpp::coroutine_handle<> m_handle; + std::error_code* const m_ec; + + public: + socket_send_exact_awaitable(socket& sock, const void* buffer, size_t size, + std::error_code* ec = nullptr) noexcept + : socket_awaitable_base{sock}, m_buffer{static_cast(buffer)}, m_size{size}, m_remaining{size}, m_ec{ec} {} + bool await_suspend(coroutine_handle<> hdl); + size_t await_resume(); + }; + class socket_recv_awaitable : public detail::socket_awaitable_base { void* const m_buffer; std::size_t const m_size; @@ -355,6 +381,15 @@ namespace asyncpp::io { return socket_send_awaitable(*this, buffer, size, &ec); } + [[nodiscard]] inline socket_send_exact_awaitable socket::send_exact(const void* buffer, std::size_t size) noexcept { + return socket_send_exact_awaitable(*this, buffer, size); + } + + [[nodiscard]] inline socket_send_exact_awaitable socket::send_exact(const void* buffer, std::size_t size, + std::error_code& ec) noexcept { + return socket_send_exact_awaitable(*this, buffer, size, &ec); + } + [[nodiscard]] inline socket_recv_awaitable socket::recv(void* buffer, std::size_t size) noexcept { return socket_recv_awaitable(*this, buffer, size); } @@ -421,6 +456,16 @@ namespace asyncpp::io { return socket_send_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); } + [[nodiscard]] inline socket_send_exact_cancellable_awaitable socket::send_exact(const void* buffer, std::size_t size, + asyncpp::stop_token st) noexcept { + return socket_send_exact_cancellable_awaitable(std::move(st), *this, buffer, size); + } + + [[nodiscard]] inline socket_send_exact_cancellable_awaitable + socket::send_exact(const void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept { + return socket_send_exact_cancellable_awaitable(std::move(st), *this, buffer, size, &ec); + } + [[nodiscard]] inline socket_recv_cancellable_awaitable socket::recv(void* buffer, std::size_t size, asyncpp::stop_token st) noexcept { return socket_recv_cancellable_awaitable(std::move(st), *this, buffer, size); @@ -487,6 +532,46 @@ namespace asyncpp::io { *m_ec = m_completion.result; } + inline bool socket_send_exact_awaitable::await_suspend(coroutine_handle<> hdl) { + m_completion.callback = [](void* ptr) { + auto that = static_cast(ptr); + auto engine = that->m_socket.service().engine(); + do { + if (that->m_completion.result_size == 0) { + that->m_completion.result = std::make_error_code(std::errc::not_connected); + } + if (that->m_completion.result) { + that->m_handle.resume(); + break; + } + that->m_buffer += that->m_completion.result_size; + that->m_remaining -= that->m_completion.result_size; + if (that->m_remaining == 0) { + that->m_handle.resume(); + break; + } + } while (engine->enqueue_send(that->m_socket.native_handle(), that->m_buffer, that->m_remaining, + &that->m_completion)); + }; + m_completion.userdata = this; + m_handle = hdl; + auto engine = m_socket.service().engine(); + while (engine->enqueue_send(m_socket.native_handle(), m_buffer, m_remaining, &m_completion)) { + if (m_completion.result) return false; + m_buffer += m_completion.result_size; + m_remaining -= m_completion.result_size; + if (m_remaining == 0) return false; + } + return true; + } + + inline size_t socket_send_exact_awaitable::await_resume() { + if (!m_completion.result) return m_size - m_remaining; + if (m_ec == nullptr) throw std::system_error(m_completion.result); + *m_ec = m_completion.result; + return m_size - m_remaining; + } + inline bool socket_recv_awaitable::await_suspend(coroutine_handle<> hdl) { m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); }; m_completion.userdata = hdl.address(); @@ -663,6 +748,47 @@ namespace asyncpp::io { if (service().engine()->enqueue_send(native_handle(), buffer, size, info)) { data::handle(info); } } + template + requires(std::is_invocable_v) + inline void socket::send_exact(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) { + struct data : detail::io_engine::completion_data { + FN m_real_cb; + const std::span m_buffer; + size_t m_size_sent; + asyncpp::stop_callback m_stop_cb; + socket& m_socket; + + data(FN&& cb, std::span buf, asyncpp::stop_token st, socket& socket) + : completion_data{&handle, this}, m_real_cb(std::move(cb)), m_buffer(buf), m_size_sent(0), + m_stop_cb(std::move(st), detail::cancel_io_stop_callback{this, socket.service().engine()}), + m_socket(socket) {} + + static void handle(void* ptr) { + auto that = static_cast(ptr); + if (that->result) + that->m_real_cb(0, that->result); + else { + that->m_size_sent += that->result_size; + if (that->m_size_sent < that->m_buffer.size()) + return that->send_some(); // Early out without self-delete + else + that->m_real_cb(that->m_size_sent, {}); + } + delete that; + }; + + void send_some() { + auto engine = m_socket.service().engine(); + if (engine->enqueue_send(m_socket.native_handle(), m_buffer.data() + m_size_sent, + m_buffer.size() - m_size_sent, this)) { + data::handle(this); + } + } + }; + auto info = new data(std::move(cb), std::span(static_cast(buffer), size), std::move(st), *this); + info->send_some(); + } + template requires(std::is_invocable_v) inline void socket::recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) {