Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions include/asyncpp/io/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace asyncpp::io {
class socket_accept_awaitable;
class socket_accept_error_code_awaitable;
class socket_send_awaitable;
class socket_send_exact_awaitable;
class socket_recv_awaitable;
class socket_recv_exact_awaitable;
class socket_recv_from_awaitable;
Expand All @@ -49,6 +50,7 @@ namespace asyncpp::io {
using socket_accept_error_code_cancellable_awaitable =
detail::cancellable_awaitable<socket_accept_error_code_awaitable>;
using socket_send_cancellable_awaitable = detail::cancellable_awaitable<socket_send_awaitable>;
using socket_send_exact_cancellable_awaitable = detail::cancellable_awaitable<socket_send_exact_awaitable>;
using socket_recv_cancellable_awaitable = detail::cancellable_awaitable<socket_recv_awaitable>;
using socket_recv_exact_cancellable_awaitable = detail::cancellable_awaitable<socket_recv_exact_awaitable>;
using socket_recv_from_cancellable_awaitable = detail::cancellable_awaitable<socket_recv_from_awaitable>;
Expand Down Expand Up @@ -123,6 +125,8 @@ namespace asyncpp::io {
[[nodiscard]] socket_accept_error_code_awaitable accept(std::error_code& ec) noexcept;
[[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size) noexcept;
[[nodiscard]] socket_send_awaitable send(const void* buffer, std::size_t size, std::error_code& ec) noexcept;
[[nodiscard]] socket_send_exact_awaitable send_exact(const void* buffer, std::size_t size) noexcept;
[[nodiscard]] socket_send_exact_awaitable send_exact(const void* buffer, std::size_t size, std::error_code& ec) noexcept;
[[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size) noexcept;
[[nodiscard]] socket_recv_awaitable recv(void* buffer, std::size_t size, std::error_code& ec) noexcept;
[[nodiscard]] socket_recv_exact_awaitable recv_exact(void* buffer, std::size_t size) noexcept;
Expand All @@ -146,6 +150,10 @@ namespace asyncpp::io {
asyncpp::stop_token st) noexcept;
[[nodiscard]] socket_send_cancellable_awaitable send(const void* buffer, std::size_t size,
asyncpp::stop_token st, std::error_code& ec) noexcept;
[[nodiscard]] socket_send_exact_cancellable_awaitable send_exact(const void* buffer, std::size_t size,
asyncpp::stop_token st) noexcept;
[[nodiscard]] socket_send_exact_cancellable_awaitable send_exact(const void* buffer, std::size_t size,
asyncpp::stop_token st, std::error_code& ec) noexcept;
[[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size,
asyncpp::stop_token st) noexcept;
[[nodiscard]] socket_recv_cancellable_awaitable recv(void* buffer, std::size_t size, asyncpp::stop_token st,
Expand Down Expand Up @@ -173,6 +181,9 @@ namespace asyncpp::io {
template<typename FN>
requires(std::is_invocable_v<FN, size_t, std::error_code>)
void send(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {});
template<typename FN>
requires(std::is_invocable_v<FN, size_t, std::error_code>)
void send_exact(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {});
template<typename FN>
requires(std::is_invocable_v<FN, size_t, std::error_code>)
void recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st = {});
Expand Down Expand Up @@ -260,6 +271,21 @@ namespace asyncpp::io {
void await_resume();
};

class socket_send_exact_awaitable : public detail::socket_awaitable_base {
std::byte const* m_buffer;
std::size_t const m_size;
std::size_t m_remaining;
asyncpp::coroutine_handle<> m_handle;
std::error_code* const m_ec;

public:
socket_send_exact_awaitable(socket& sock, const void* buffer, size_t size,
std::error_code* ec = nullptr) noexcept
: socket_awaitable_base{sock}, m_buffer{static_cast<const std::byte*>(buffer)}, m_size{size}, m_remaining{size}, m_ec{ec} {}
bool await_suspend(coroutine_handle<> hdl);
size_t await_resume();
};

class socket_recv_awaitable : public detail::socket_awaitable_base {
void* const m_buffer;
std::size_t const m_size;
Expand Down Expand Up @@ -355,6 +381,15 @@ namespace asyncpp::io {
return socket_send_awaitable(*this, buffer, size, &ec);
}

[[nodiscard]] inline socket_send_exact_awaitable socket::send_exact(const void* buffer, std::size_t size) noexcept {
return socket_send_exact_awaitable(*this, buffer, size);
}

[[nodiscard]] inline socket_send_exact_awaitable socket::send_exact(const void* buffer, std::size_t size,
std::error_code& ec) noexcept {
return socket_send_exact_awaitable(*this, buffer, size, &ec);
}

[[nodiscard]] inline socket_recv_awaitable socket::recv(void* buffer, std::size_t size) noexcept {
return socket_recv_awaitable(*this, buffer, size);
}
Expand Down Expand Up @@ -421,6 +456,16 @@ namespace asyncpp::io {
return socket_send_cancellable_awaitable(std::move(st), *this, buffer, size, &ec);
}

[[nodiscard]] inline socket_send_exact_cancellable_awaitable socket::send_exact(const void* buffer, std::size_t size,
asyncpp::stop_token st) noexcept {
return socket_send_exact_cancellable_awaitable(std::move(st), *this, buffer, size);
}

[[nodiscard]] inline socket_send_exact_cancellable_awaitable
socket::send_exact(const void* buffer, std::size_t size, asyncpp::stop_token st, std::error_code& ec) noexcept {
return socket_send_exact_cancellable_awaitable(std::move(st), *this, buffer, size, &ec);
}

[[nodiscard]] inline socket_recv_cancellable_awaitable socket::recv(void* buffer, std::size_t size,
asyncpp::stop_token st) noexcept {
return socket_recv_cancellable_awaitable(std::move(st), *this, buffer, size);
Expand Down Expand Up @@ -487,6 +532,46 @@ namespace asyncpp::io {
*m_ec = m_completion.result;
}

inline bool socket_send_exact_awaitable::await_suspend(coroutine_handle<> hdl) {
m_completion.callback = [](void* ptr) {
auto that = static_cast<socket_send_exact_awaitable*>(ptr);
auto engine = that->m_socket.service().engine();
do {
if (that->m_completion.result_size == 0) {
that->m_completion.result = std::make_error_code(std::errc::not_connected);
}
if (that->m_completion.result) {
that->m_handle.resume();
break;
}
that->m_buffer += that->m_completion.result_size;
that->m_remaining -= that->m_completion.result_size;
if (that->m_remaining == 0) {
that->m_handle.resume();
break;
}
} while (engine->enqueue_send(that->m_socket.native_handle(), that->m_buffer, that->m_remaining,
&that->m_completion));
};
m_completion.userdata = this;
m_handle = hdl;
auto engine = m_socket.service().engine();
while (engine->enqueue_send(m_socket.native_handle(), m_buffer, m_remaining, &m_completion)) {
if (m_completion.result) return false;
m_buffer += m_completion.result_size;
m_remaining -= m_completion.result_size;
if (m_remaining == 0) return false;
}
return true;
}

inline size_t socket_send_exact_awaitable::await_resume() {
if (!m_completion.result) return m_size - m_remaining;
if (m_ec == nullptr) throw std::system_error(m_completion.result);
*m_ec = m_completion.result;
return m_size - m_remaining;
}

inline bool socket_recv_awaitable::await_suspend(coroutine_handle<> hdl) {
m_completion.callback = [](void* ptr) { coroutine_handle<>::from_address(ptr).resume(); };
m_completion.userdata = hdl.address();
Expand Down Expand Up @@ -663,6 +748,47 @@ namespace asyncpp::io {
if (service().engine()->enqueue_send(native_handle(), buffer, size, info)) { data::handle(info); }
}

template<typename FN>
requires(std::is_invocable_v<FN, size_t, std::error_code>)
inline void socket::send_exact(const void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) {
struct data : detail::io_engine::completion_data {
FN m_real_cb;
const std::span<const std::byte> m_buffer;
size_t m_size_sent;
asyncpp::stop_callback<detail::cancel_io_stop_callback> m_stop_cb;
socket& m_socket;

data(FN&& cb, std::span<const std::byte> buf, asyncpp::stop_token st, socket& socket)
: completion_data{&handle, this}, m_real_cb(std::move(cb)), m_buffer(buf), m_size_sent(0),
m_stop_cb(std::move(st), detail::cancel_io_stop_callback{this, socket.service().engine()}),
m_socket(socket) {}

static void handle(void* ptr) {
auto that = static_cast<data*>(ptr);
if (that->result)
that->m_real_cb(0, that->result);
else {
that->m_size_sent += that->result_size;
if (that->m_size_sent < that->m_buffer.size())
return that->send_some(); // Early out without self-delete
else
that->m_real_cb(that->m_size_sent, {});
}
delete that;
};

void send_some() {
auto engine = m_socket.service().engine();
if (engine->enqueue_send(m_socket.native_handle(), m_buffer.data() + m_size_sent,
m_buffer.size() - m_size_sent, this)) {
data::handle(this);
}
}
};
auto info = new data(std::move(cb), std::span(static_cast<const std::byte*>(buffer), size), std::move(st), *this);
info->send_some();
}

template<typename FN>
requires(std::is_invocable_v<FN, size_t, std::error_code>)
inline void socket::recv(void* buffer, std::size_t size, FN&& cb, asyncpp::stop_token st) {
Expand Down
Loading