From aa4a24bb9080a3773b7aeb299f0841d80b9b214c Mon Sep 17 00:00:00 2001 From: "Xin Wang (from Dev Box)" Date: Tue, 14 Apr 2026 02:04:49 +0800 Subject: [PATCH 1/2] Localhost Relay: non-blocking IO on Linux side, IOCP with fully async reads and writes on Windows side --- src/linux/init/localhost.cpp | 175 +++++++++++++++++-- src/windows/common/relay.cpp | 323 ++++++++++++++++++++++++++--------- 2 files changed, 399 insertions(+), 99 deletions(-) diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index f1af32182..e0e7c6182 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -30,6 +31,39 @@ namespace { +// Per-direction relay buffer for the non-blocking socket relay. The read limit +// for each direction is reduced by the amount of data pending from an incomplete +// write, establishing back-pressure through SOCK_STREAM flow control to throttle +// an abusive peer and bound memory usage. +struct RelayDirection +{ + int srcFd; + int dstFd; + std::vector buf; + size_t head; + size_t tail; + bool srcEof; + bool done; + + size_t Pending() const { return tail - head; } + size_t Available() const { return buf.size() - tail; } + + void Compact() + { + if (head > 0) + { + auto pending = Pending(); + if (pending > 0) + { + memmove(buf.data(), buf.data() + head, pending); + } + + head = 0; + tail = pending; + } + } +}; + void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket) { pollfd pollDescriptors[] = {{listenSocket, POLLIN}}; @@ -103,40 +137,145 @@ void ListenThread(sockaddr_vm hvSocketAddress, int listenSocket) return; } - // Resize the buffer to be the requested size. - buffer.resize(message->BufferSize); + // Switch both sockets to non-blocking for the relay loop. + for (int fd : {tcpSocket.get(), relaySocket.get()}) + { + int flags = fcntl(fd, F_GETFL, 0); + THROW_LAST_ERROR_IF(flags < 0); + THROW_LAST_ERROR_IF(fcntl(fd, F_SETFL, flags | O_NONBLOCK) < 0); + } - // Begin relaying data. - int outFd[2] = {tcpSocket.get(), relaySocket.get()}; - pollfd pollDescriptors[] = {{relaySocket.get(), POLLIN}, {tcpSocket.get(), POLLIN}}; + const auto bufferSize = message->BufferSize; + buffer.resize(bufferSize); + RelayDirection dirs[2] = { + {relaySocket.get(), tcpSocket.get(), std::move(buffer), 0, 0, false, false}, + {tcpSocket.get(), relaySocket.get(), std::vector(bufferSize), 0, 0, false, false}, + }; + + pollfd pfds[4] = {}; + int pollDirIndex[4] = {}; for (;;) { - if ((pollDescriptors[0].fd == -1) || (pollDescriptors[1].fd == -1)) + // Complete directions where the source hit EOF and all + // pending data has been flushed to the destination. + for (auto& d : dirs) + { + if (!d.done && d.srcEof && d.Pending() == 0) + { + shutdown(d.dstFd, SHUT_WR); + d.done = true; + } + } + + if (dirs[0].done && dirs[1].done) { return; } - THROW_LAST_ERROR_IF(poll(pollDescriptors, COUNT_OF(pollDescriptors), -1) < 0); + // Build the poll set based on current state. + int nfds = 0; - bytesRead = 0; - for (int Index = 0; Index < COUNT_OF(pollDescriptors); Index += 1) + for (int i = 0; i < 2; i++) { - if (pollDescriptors[Index].revents & POLLIN) + auto& d = dirs[i]; + if (d.done) { - bytesRead = UtilReadBuffer(pollDescriptors[Index].fd, buffer); - if (bytesRead == 0) + continue; + } + + // Poll for read when the source is open and the buffer has space. + if (!d.srcEof && d.Available() > 0) + { + pfds[nfds] = {d.srcFd, POLLIN, 0}; + pollDirIndex[nfds] = i; + nfds++; + } + + // Poll for write when there is data waiting to go out. + if (d.Pending() > 0) + { + pfds[nfds] = {d.dstFd, POLLOUT, 0}; + pollDirIndex[nfds] = i; + nfds++; + } + } + + if (nfds == 0) + { + return; + } + + THROW_LAST_ERROR_IF(poll(pfds, nfds, -1) < 0); + + for (int j = 0; j < nfds; j++) + { + auto& d = dirs[pollDirIndex[j]]; + + if (pfds[j].events & POLLOUT) + { + // can't write to dstFd any more + if (pfds[j].revents & (POLLERR | POLLHUP)) + { + d.done = true; + continue; + } + + if (!(pfds[j].revents & POLLOUT)) + { + continue; + } + + auto written = TEMP_FAILURE_RETRY(write(d.dstFd, d.buf.data() + d.head, d.Pending())); + if (written < 0) { - pollDescriptors[Index].fd = -1; - shutdown(outFd[Index], SHUT_WR); + if (errno == EAGAIN || errno == EWOULDBLOCK) + { + continue; + } + + d.done = true; + continue; } - else if (bytesRead < 0) + + d.head += written; + if (d.Pending() == 0) { - return; + d.head = d.tail = 0; + } + } + else + { + if (!(pfds[j].revents & POLLIN)) + { + // No data to read; if the source is gone, mark EOF. + if (pfds[j].revents & (POLLERR | POLLHUP)) + { + d.srcEof = true; + } + + continue; + } + + d.Compact(); + auto nread = TEMP_FAILURE_RETRY(read(d.srcFd, d.buf.data() + d.tail, d.Available())); + if (nread == 0) + { + d.srcEof = true; + } + else if (nread < 0) + { + if (errno == EAGAIN || errno == EWOULDBLOCK) + { + continue; + } + + d.srcEof = true; + continue; } - else if (UtilWriteBuffer(outFd[Index], buffer.data(), bytesRead) < 0) + else { - return; + d.tail += nread; } } } diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 6d2dce21f..8cf3430ba 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -38,6 +38,128 @@ LARGE_INTEGER InitializeFileOffset(HANDLE File) return Offset; } +// Types and helpers for the IOCP-based BidirectionalRelay. + +enum class IoOp +{ + Read, + Write +}; + +// Extended OVERLAPPED for IOCP dispatch. OVERLAPPED must be +// the first member so that reinterpret_cast(overlapped) +// is valid when recovering context from GetQueuedCompletionStatus. +struct IoContext +{ + OVERLAPPED Overlapped; + int DirIndex; + IoOp Op; +}; + +// Per-direction state for the bidirectional relay. Each direction has +// its own buffer, pending I/O tracking, and file offsets. The read +// limit is reduced by the amount of data pending from an incomplete +// write, establishing back-pressure through TCP flow control. +struct RelayDirection +{ + HANDLE SrcHandle; + HANDLE DstHandle; + std::vector Buffer; + size_t Head = 0; + size_t Tail = 0; + IoContext ReadCtx{}; + IoContext WriteCtx{}; + LARGE_INTEGER ReadOffset{}; + LARGE_INTEGER WriteOffset{}; + bool ReadPending = false; + bool WritePending = false; + bool SrcEof = false; + bool Done = false; + bool DstIsSocket = false; + + size_t Pending() const { return Tail - Head; } + size_t Available() const { return Buffer.size() - Tail; } +}; + +void TryIssueRead(RelayDirection& d) +{ + if (d.ReadPending || d.SrcEof || d.Done || d.Available() == 0) + { + return; + } + + d.ReadCtx.Overlapped = {}; + d.ReadCtx.Overlapped.Offset = d.ReadOffset.LowPart; + d.ReadCtx.Overlapped.OffsetHigh = d.ReadOffset.HighPart; + + DWORD bytesRead = 0; + if (!ReadFile(d.SrcHandle, d.Buffer.data() + d.Tail, gsl::narrow_cast(d.Available()), &bytesRead, &d.ReadCtx.Overlapped)) + { + const auto error = GetLastError(); + if (error == ERROR_IO_PENDING) + { + d.ReadPending = true; + return; + } + + if (error == ERROR_HANDLE_EOF || error == ERROR_BROKEN_PIPE) + { + d.SrcEof = true; + return; + } + + THROW_WIN32(error); + } + + d.ReadPending = true; +} + +void TryIssueWrite(RelayDirection& d) +{ + if (d.WritePending || d.Done || d.Pending() == 0) + { + return; + } + + d.WriteCtx.Overlapped = {}; + d.WriteCtx.Overlapped.Offset = d.WriteOffset.LowPart; + d.WriteCtx.Overlapped.OffsetHigh = d.WriteOffset.HighPart; + + DWORD bytesWritten = 0; + if (!WriteFile(d.DstHandle, d.Buffer.data() + d.Head, gsl::narrow_cast(d.Pending()), &bytesWritten, &d.WriteCtx.Overlapped)) + { + const auto error = GetLastError(); + if (error == ERROR_IO_PENDING) + { + d.WritePending = true; + return; + } + + if (error == ERROR_NO_DATA || error == ERROR_BROKEN_PIPE) + { + d.Done = true; + return; + } + + THROW_WIN32(error); + } + + d.WritePending = true; +} + +void CheckDirectionDone(RelayDirection& d) +{ + if (!d.Done && d.SrcEof && d.Pending() == 0 && !d.WritePending && !d.ReadPending) + { + if (d.DstIsSocket) + { + LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(d.DstHandle), SD_SEND) == SOCKET_ERROR); + } + + d.Done = true; + } +} + } // namespace std::thread wsl::windows::common::relay::CreateThread(_In_ HANDLE InputHandle, _In_ HANDLE OutputHandle, _In_opt_ HANDLE ExitHandle, _In_ size_t BufferSize) @@ -257,128 +379,167 @@ wsl::windows::common::relay::InterruptableWrite( void wsl::windows::common::relay::BidirectionalRelay(_In_ HANDLE LeftHandle, _In_ HANDLE RightHandle, _In_ size_t BufferSize, _In_ RelayFlags Flags) { - std::vector leftBuffer(BufferSize); - const auto leftReadSpan = gsl::make_span(leftBuffer); - OVERLAPPED leftOverlapped = {0}; - const wil::unique_event leftOverlappedEvent(wil::EventOptions::None); - leftOverlapped.hEvent = leftOverlappedEvent.get(); - LARGE_INTEGER leftOffset{}; - - std::vector rightBuffer(BufferSize); - const auto rightReadSpan = gsl::make_span(rightBuffer); - OVERLAPPED rightOverlapped = {0}; - const wil::unique_event rightOverlappedEvent(wil::EventOptions::None); - rightOverlapped.hEvent = rightOverlappedEvent.get(); - LARGE_INTEGER rightOffset{}; + // Create a completion port and associate both handles. + wil::unique_handle iocp(CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 1)); + THROW_LAST_ERROR_IF(!iocp); + THROW_LAST_ERROR_IF_NULL(CreateIoCompletionPort(LeftHandle, iocp.get(), 0, 0)); + THROW_LAST_ERROR_IF_NULL(CreateIoCompletionPort(RightHandle, iocp.get(), 0, 0)); + + // Initialize per-direction state. + RelayDirection dirs[2] = {}; + + // Direction 0: Left → Right. + dirs[0].SrcHandle = LeftHandle; + dirs[0].DstHandle = RightHandle; + dirs[0].Buffer.resize(BufferSize); + dirs[0].ReadCtx.DirIndex = 0; + dirs[0].ReadCtx.Op = IoOp::Read; + dirs[0].WriteCtx.DirIndex = 0; + dirs[0].WriteCtx.Op = IoOp::Write; + dirs[0].ReadOffset = InitializeFileOffset(LeftHandle); + dirs[0].WriteOffset = InitializeFileOffset(RightHandle); + dirs[0].DstIsSocket = WI_IsFlagSet(Flags, RelayFlags::RightIsSocket); + + // Direction 1: Right → Left. + dirs[1].SrcHandle = RightHandle; + dirs[1].DstHandle = LeftHandle; + dirs[1].Buffer.resize(BufferSize); + dirs[1].ReadCtx.DirIndex = 1; + dirs[1].ReadCtx.Op = IoOp::Read; + dirs[1].WriteCtx.DirIndex = 1; + dirs[1].WriteCtx.Op = IoOp::Write; + dirs[1].ReadOffset = InitializeFileOffset(RightHandle); + dirs[1].WriteOffset = InitializeFileOffset(LeftHandle); + dirs[1].DstIsSocket = WI_IsFlagSet(Flags, RelayFlags::LeftIsSocket); + + // Cancel all pending I/O and drain completions on exit. + auto cancelPending = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { + int pendingCount = 0; + for (auto& d : dirs) + { + if (d.ReadPending) + { + CancelIoEx(d.SrcHandle, &d.ReadCtx.Overlapped); + pendingCount++; + } - bool leftReadPending = false; - bool rightReadPending = false; - auto cancelReads = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { - DWORD bytes; - if (leftReadPending) - { - CancelIoEx(LeftHandle, &leftOverlapped); - GetOverlappedResult(LeftHandle, &leftOverlapped, &bytes, TRUE); + if (d.WritePending) + { + CancelIoEx(d.DstHandle, &d.WriteCtx.Overlapped); + pendingCount++; + } } - if (rightReadPending) + for (int i = 0; i < pendingCount; i++) { - CancelIoEx(RightHandle, &rightOverlapped); - GetOverlappedResult(RightHandle, &rightOverlapped, &bytes, TRUE); + DWORD bytes = 0; + ULONG_PTR key = 0; + LPOVERLAPPED ov = nullptr; + GetQueuedCompletionStatus(iocp.get(), &bytes, &key, &ov, INFINITE); } }); - DWORD bytesWritten; - const HANDLE waitObjects[] = {leftOverlapped.hEvent, rightOverlapped.hEvent}; + // Issue initial reads. + for (auto& d : dirs) + { + TryIssueRead(d); + CheckDirectionDone(d); + } + for (;;) { - if ((LeftHandle == nullptr) || (RightHandle == nullptr)) + if (dirs[0].Done && dirs[1].Done) { break; } - DWORD leftBytesRead = 0; - if (!leftReadPending && LeftHandle) + // If no operations are pending, nothing to wait for. + bool anyPending = false; + for (const auto& d : dirs) { - if (!ReadFile(LeftHandle, leftReadSpan.data(), gsl::narrow_cast(leftReadSpan.size()), &leftBytesRead, &leftOverlapped)) + if (d.ReadPending || d.WritePending) { - THROW_LAST_ERROR_IF(GetLastError() != ERROR_IO_PENDING); + anyPending = true; + break; } - - leftReadPending = true; } - DWORD rightBytesRead = 0; - if (!rightReadPending && RightHandle) + if (!anyPending) { - if (!ReadFile(RightHandle, rightReadSpan.data(), gsl::narrow_cast(rightReadSpan.size()), &rightBytesRead, &rightOverlapped)) - { - THROW_LAST_ERROR_IF(GetLastError() != ERROR_IO_PENDING); - } + break; + } + + DWORD bytesTransferred = 0; + ULONG_PTR completionKey = 0; + LPOVERLAPPED overlapped = nullptr; + const BOOL success = GetQueuedCompletionStatus(iocp.get(), &bytesTransferred, &completionKey, &overlapped, INFINITE); - rightReadPending = true; + if (!overlapped) + { + THROW_LAST_ERROR(); } - const DWORD waitResult = WaitForMultipleObjects(RTL_NUMBER_OF(waitObjects), waitObjects, FALSE, INFINITE); - if (waitResult == WAIT_OBJECT_0) + auto* ctx = reinterpret_cast(overlapped); + auto& d = dirs[ctx->DirIndex]; + const DWORD error = success ? ERROR_SUCCESS : GetLastError(); + + if (ctx->Op == IoOp::Read) { - LOG_LAST_ERROR_IF_MSG( - !GetOverlappedResult(LeftHandle, &leftOverlapped, &leftBytesRead, FALSE), "WSAGetLastError %d", WSAGetLastError()); + d.ReadPending = false; - leftReadPending = false; - if (leftBytesRead == 0) + if (!success) { - LeftHandle = nullptr; - if (WI_IsFlagSet(Flags, RelayFlags::RightIsSocket)) + if (error == ERROR_HANDLE_EOF || error == ERROR_BROKEN_PIPE || error == ERROR_OPERATION_ABORTED) { - LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(RightHandle), SD_SEND) == SOCKET_ERROR); + d.SrcEof = true; } - } - else if (RightHandle != nullptr) - { - auto writeSpan = leftReadSpan.first(leftBytesRead); - bytesWritten = InterruptableWrite(RightHandle, writeSpan, {}, &leftOverlapped); - if (bytesWritten == 0) + else { - break; + LOG_WIN32_MSG(error, "Read completion failed"); + d.SrcEof = true; } - - leftOffset.QuadPart += leftBytesRead; - leftOverlapped.Offset = leftOffset.LowPart; - leftOverlapped.OffsetHigh = leftOffset.HighPart; + } + else if (bytesTransferred == 0) + { + d.SrcEof = true; + } + else + { + d.Tail += bytesTransferred; + d.ReadOffset.QuadPart += bytesTransferred; } } - else if (waitResult == (WAIT_OBJECT_0 + 1)) + else { - LOG_LAST_ERROR_IF_MSG( - !GetOverlappedResult(RightHandle, &rightOverlapped, &rightBytesRead, FALSE), "WSAGetLastError %d", WSAGetLastError()); + d.WritePending = false; - rightReadPending = false; - if (rightBytesRead == 0) + if (!success) { - RightHandle = nullptr; - if (WI_IsFlagSet(Flags, RelayFlags::LeftIsSocket)) - { - LOG_LAST_ERROR_IF(shutdown(reinterpret_cast(LeftHandle), SD_SEND) == SOCKET_ERROR); - } + LOG_WIN32_MSG(error, "Write completion failed"); + d.Done = true; } - else if (LeftHandle != nullptr) + else { - auto writeSpan = rightReadSpan.first(rightBytesRead); - bytesWritten = InterruptableWrite(LeftHandle, writeSpan, {}, &rightOverlapped); - if (bytesWritten == 0) + d.Head += bytesTransferred; + d.WriteOffset.QuadPart += bytesTransferred; + + if (d.Pending() == 0 && !d.ReadPending) { - break; + d.Head = d.Tail = 0; } - - rightOffset.QuadPart += rightBytesRead; - rightOverlapped.Offset = rightOffset.LowPart; - rightOverlapped.OffsetHigh = rightOffset.HighPart; } } - else + + // Advance all directions: issue writes first to free buffer + // space, then reads, then check for completion. + for (auto& dir : dirs) { - THROW_HR_MSG(E_FAIL, "WaitForMultipleObjects %d", waitResult); + if (!dir.Done) + { + TryIssueWrite(dir); + TryIssueRead(dir); + CheckDirectionDone(dir); + } } } } From 0a6b8d5e3d9032ee6d2b74d3cc4ba03e75359f82 Mon Sep 17 00:00:00 2001 From: "Xin Wang (from Dev Box)" Date: Tue, 14 Apr 2026 02:09:58 +0800 Subject: [PATCH 2/2] format code --- src/linux/init/localhost.cpp | 10 ++++++++-- src/windows/common/relay.cpp | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/linux/init/localhost.cpp b/src/linux/init/localhost.cpp index e0e7c6182..906810f0b 100644 --- a/src/linux/init/localhost.cpp +++ b/src/linux/init/localhost.cpp @@ -45,8 +45,14 @@ struct RelayDirection bool srcEof; bool done; - size_t Pending() const { return tail - head; } - size_t Available() const { return buf.size() - tail; } + size_t Pending() const + { + return tail - head; + } + size_t Available() const + { + return buf.size() - tail; + } void Compact() { diff --git a/src/windows/common/relay.cpp b/src/windows/common/relay.cpp index 8cf3430ba..c045284b1 100644 --- a/src/windows/common/relay.cpp +++ b/src/windows/common/relay.cpp @@ -77,8 +77,14 @@ struct RelayDirection bool Done = false; bool DstIsSocket = false; - size_t Pending() const { return Tail - Head; } - size_t Available() const { return Buffer.size() - Tail; } + size_t Pending() const + { + return Tail - Head; + } + size_t Available() const + { + return Buffer.size() - Tail; + } }; void TryIssueRead(RelayDirection& d)