diff --git a/quic/QuicConstants.h b/quic/QuicConstants.h index cc43a3b81..6bb1bf201 100644 --- a/quic/QuicConstants.h +++ b/quic/QuicConstants.h @@ -805,6 +805,7 @@ enum class NoWriteReason { NO_FRAME, NO_BODY, SOCKET_FAILURE, + WRITER_BACKPRESSURE, // async writer SPSC queue full }; enum class NoReadReason { diff --git a/quic/api/CMakeLists.txt b/quic/api/CMakeLists.txt index 2e2737748..cfb656347 100644 --- a/quic/api/CMakeLists.txt +++ b/quic/api/CMakeLists.txt @@ -5,6 +5,14 @@ # Auto-generated by quic/facebook/generate_cmake.py - DO NOT EDIT MANUALLY +mvfst_add_library(mvfst_api_quic_packet_writer + EXPORTED_DEPS + mvfst_exception + mvfst_constants + Folly::folly_io_iobuf + Folly::folly_network_address +) + mvfst_add_library(mvfst_api_quic_batch_writer SRCS QuicBatchWriter.cpp @@ -112,6 +120,20 @@ mvfst_add_library(mvfst_api_transport_lite Folly::folly_maybe_managed_ptr ) +mvfst_add_library(mvfst_api_shared_threaded_packet_writer + SRCS + SharedThreadedPacketWriter.cpp + DEPS + mvfst_common_mvfst_logging + EXPORTED_DEPS + mvfst_api_quic_packet_writer + mvfst_codec_types + mvfst_common_event_fd_queue + mvfst_common_events_eventbase + Folly::folly + Folly::folly_io_async_async_udp_socket +) + mvfst_add_library(mvfst_api_ack_scheduler SRCS QuicAckScheduler.cpp @@ -139,6 +161,7 @@ mvfst_add_library(mvfst_api_transport_helpers EXPORTED_DEPS mvfst_api_ack_scheduler mvfst_api_quic_batch_writer + mvfst_api_quic_packet_writer mvfst_client_state_and_handshake mvfst_codec mvfst_codec_pktbuilder diff --git a/quic/api/IoBufQuicBatch.h b/quic/api/IoBufQuicBatch.h index 90dc28654..57ee27efc 100644 --- a/quic/api/IoBufQuicBatch.h +++ b/quic/api/IoBufQuicBatch.h @@ -8,16 +8,12 @@ #pragma once #include #include +#include #include #include namespace quic { -struct BufQuicBatchResult { - uint64_t packetsSent{0}; - uint64_t bytesSent{0}; -}; - class IOBufQuicBatch { public: IOBufQuicBatch( diff --git a/quic/api/QuicPacketWriter.h b/quic/api/QuicPacketWriter.h new file mode 100644 index 000000000..dccf7f378 --- /dev/null +++ b/quic/api/QuicPacketWriter.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#include +#include + +namespace quic { + +// Defined here (not in IoBufQuicBatch.h) so StateData.h can include this +// header without pulling in the IoBufQuicBatch → QuicBatchWriter → StateData +// include cycle. +struct BufQuicBatchResult { + uint64_t packetsSent{0}; + uint64_t bytesSent{0}; +}; + +/** + * Abstract interface for sending fully-built, encrypted QUIC packets. + * + * The default path (conn.packetWriter == nullptr) calls IOBufQuicBatch + * directly. When conn.packetWriter is set (ChainedMemory data path only), + * writeConnectionDataToSocket dispatches through this interface instead. + */ +class QuicPacketWriter { + public: + virtual ~QuicPacketWriter() = default; + + // Called on the EventBase thread. Returns false → stop write loop + // (backpressure, not an error). Returns unexpected → close connection. + [[nodiscard]] virtual quic::Expected write( + BufPtr&& buf, + size_t encodedSize, + const folly::SocketAddress& peerAddr) = 0; + + [[nodiscard]] virtual quic::Expected flush() = 0; + + // packetsSent counts packets handed to this writer (enqueued or sent inline). + virtual BufQuicBatchResult getResult() const = 0; + + // Last retriable errno (EAGAIN/ENOBUFS) seen. Always 0 for async writers — + // errno tracking is internal to their drain thread. + virtual int getLastRetryableErrno() const { + return 0; + } +}; + +} // namespace quic diff --git a/quic/api/QuicTransportBaseLite.cpp b/quic/api/QuicTransportBaseLite.cpp index 20a095621..1ff8a3dec 100644 --- a/quic/api/QuicTransportBaseLite.cpp +++ b/quic/api/QuicTransportBaseLite.cpp @@ -1103,6 +1103,15 @@ QuicTransportBaseLite::getStreamFlowControl(StreamId id) const { stream->flowControlState.advertisedMaxOffset); } +void QuicTransportBaseLite::setPacketWriter( + std::unique_ptr writer) { + conn_->packetWriter = std::move(writer); +} + +void QuicTransportBaseLite::scheduleWrite() { + runOnEvbAsyncOp({.type = AsyncOpType::ConnectionWriteReady}); +} + void QuicTransportBaseLite::runOnEvbAsyncOp(AsyncOpData data) { auto evb = getEventBase(); evb->runInLoop( diff --git a/quic/api/QuicTransportBaseLite.h b/quic/api/QuicTransportBaseLite.h index 0eed0981c..f398cd5a7 100644 --- a/quic/api/QuicTransportBaseLite.h +++ b/quic/api/QuicTransportBaseLite.h @@ -307,6 +307,13 @@ class QuicTransportBaseLite : virtual public QuicSocketLite, virtual void cancelAllAppCallbacks(const QuicError& error) noexcept; + // Install a packet writer on the connection. Must be called before accept(). + void setPacketWriter(std::unique_ptr writer); + + // Schedule a write loop iteration on the connection's EventBase. Safe to + // call from any thread; the write will execute on the EventBase thread. + void scheduleWrite(); + void scheduleTimeout( QuicTimerCallback* callback, std::chrono::milliseconds timeout); diff --git a/quic/api/QuicTransportFunctions.cpp b/quic/api/QuicTransportFunctions.cpp index 4ee38bd41..943d51936 100644 --- a/quic/api/QuicTransportFunctions.cpp +++ b/quic/api/QuicTransportFunctions.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -235,10 +236,7 @@ quic::Expected writeQuicDataToSocketImpl( return result; } -void updateErrnoCount( - QuicConnectionStateBase& connection, - IOBufQuicBatch& ioBufBatch) { - int lastErrno = ioBufBatch.getLastRetryableErrno(); +void updateErrnoCount(QuicConnectionStateBase& connection, int lastErrno) { if (lastErrno == EAGAIN || lastErrno == EWOULDBLOCK) { connection.eagainOrEwouldblockCount++; } else if (lastErrno == ENOBUFS) { @@ -400,7 +398,7 @@ continuousMemoryBuildScheduleEncrypt( if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount(connection, ioBufBatch.getLastRetryableErrno()); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_FRAME; } @@ -413,7 +411,7 @@ continuousMemoryBuildScheduleEncrypt( if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount(connection, ioBufBatch.getLastRetryableErrno()); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_BODY; } @@ -494,7 +492,7 @@ continuousMemoryBuildScheduleEncrypt( if (!writeResult.has_value()) { return quic::make_unexpected(writeResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount(connection, ioBufBatch.getLastRetryableErrno()); return DataPathResult::makeWriteResult( writeResult.value(), std::move(result.value()), @@ -514,7 +512,9 @@ iobufChainBasedBuildScheduleEncrypt( IOBufQuicBatch& ioBufBatch, const Aead& aead, const PacketNumberCipher& headerCipher, - TimePoint sendTime) { + TimePoint sendTime, + const folly::SocketAddress& peerAddr, + QuicPacketWriter* packetWriter) { // SCONE: Pre-build SCONE packet and adjust max packet size to avoid overflow std::unique_ptr preBuildSconePacket; uint64_t adjustedMaxPacketSize = connection.udpSendPacketLen; @@ -556,11 +556,14 @@ iobufChainBasedBuildScheduleEncrypt( } auto& packet = result->packet; if (!packet || packet->packet.frames.empty()) { - auto flushResult = ioBufBatch.flush(); + auto flushResult = packetWriter ? packetWriter->flush() : ioBufBatch.flush(); if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount( + connection, + packetWriter ? packetWriter->getLastRetryableErrno() + : ioBufBatch.getLastRetryableErrno()); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_FRAME; } @@ -568,11 +571,14 @@ iobufChainBasedBuildScheduleEncrypt( } if (packet->body.empty()) { // No more space remaining. - auto flushResult = ioBufBatch.flush(); + auto flushResult = packetWriter ? packetWriter->flush() : ioBufBatch.flush(); if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount( + connection, + packetWriter ? packetWriter->getLastRetryableErrno() + : ioBufBatch.getLastRetryableErrno()); if (connection.loopDetectorCallback) { connection.writeDebugState.noWriteReason = NoWriteReason::NO_BODY; } @@ -660,11 +666,16 @@ iobufChainBasedBuildScheduleEncrypt( return DataPathResult::makeWriteResult( true, std::move(result.value()), encodedSize, encodedBodySize); } - auto writeResult = ioBufBatch.write(std::move(packetBuf), encodedSize); + auto writeResult = packetWriter + ? packetWriter->write(std::move(packetBuf), encodedSize, peerAddr) + : ioBufBatch.write(std::move(packetBuf), encodedSize); if (!writeResult.has_value()) { return quic::make_unexpected(writeResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount( + connection, + packetWriter ? packetWriter->getLastRetryableErrno() + : ioBufBatch.getLastRetryableErrno()); return DataPathResult::makeWriteResult( writeResult.value(), std::move(result.value()), @@ -1891,8 +1902,34 @@ quic::Expected writeConnectionDataToSocket( uint64_t bytesWritten = 0; uint64_t shortHeaderPadding = 0; [[maybe_unused]] uint64_t shortHeaderPaddingCount = 0; + // Capture baseline so pktSentFn() counts packets sent *this call*, not + // lifetime total. ConnectionPacketWriter::result_.packetsSent accumulates + // across calls; ioBufBatch is fresh each call so needs no adjustment. + const uint64_t pktSentBaseline = connection.packetWriter + ? connection.packetWriter->getResult().packetsSent + : 0; + auto pktSentFn = [&]() -> uint64_t { + return connection.packetWriter + ? connection.packetWriter->getResult().packetsSent - pktSentBaseline + : static_cast(ioBufBatch.getPktSent()); + }; + MVDCHECK( + !connection.packetWriter || + connection.transportSettings.dataPathType == + DataPathType::ChainedMemory, + "packetWriter requires ChainedMemory data path"); + auto flushFn = [&]() { + return connection.packetWriter ? connection.packetWriter->flush() + : ioBufBatch.flush(); + }; + auto errnoFn = [&]() { + return connection.packetWriter + ? connection.packetWriter->getLastRetryableErrno() + : ioBufBatch.getLastRetryableErrno(); + }; + SCOPE_EXIT { - auto nSent = ioBufBatch.getPktSent(); + auto nSent = pktSentFn(); if (nSent > 0) { QUIC_STATS(connection.statsCallback, onPacketsSent, nSent); QUIC_STATS(connection.statsCallback, onWrite, bytesWritten); @@ -1908,8 +1945,8 @@ quic::Expected writeConnectionDataToSocket( quic::TimePoint sentTime = Clock::now(); - while (scheduler.hasData() && ioBufBatch.getPktSent() < packetLimit && - ((ioBufBatch.getPktSent() < batchSize) || + while (scheduler.hasData() && pktSentFn() < packetLimit && + ((pktSentFn() < batchSize) || writeLoopTimeLimit(writeLoopBeginTime, connection))) { auto packetNum = getNextPacketNum(connection, pnSpace); auto header = builder(srcConnId, dstConnId, packetNum, version, token); @@ -1931,22 +1968,38 @@ quic::Expected writeConnectionDataToSocket( bool useChainedMemory = connection.transportSettings.dataPathType == DataPathType::ChainedMemory; - const auto& dataPlaneFunc = useChainedMemory - ? iobufChainBasedBuildScheduleEncrypt - : continuousMemoryBuildScheduleEncrypt; - auto ret = dataPlaneFunc( - connection, - std::move(header), - pnSpace, - packetNum, - cipherOverhead, - scheduler, - writableBytes, - ioBufBatch, - aead, - headerCipher, - sentTime); + auto ret = [&]() { + if (useChainedMemory) { + return iobufChainBasedBuildScheduleEncrypt( + connection, + std::move(header), + pnSpace, + packetNum, + cipherOverhead, + scheduler, + writableBytes, + ioBufBatch, + aead, + headerCipher, + sentTime, + peerAddress, + connection.packetWriter.get()); + } else { + return continuousMemoryBuildScheduleEncrypt( + connection, + std::move(header), + pnSpace, + packetNum, + cipherOverhead, + scheduler, + writableBytes, + ioBufBatch, + aead, + headerCipher, + sentTime); + } + }(); // This is a fatal error vs. a build error. if (!ret.has_value()) { @@ -1955,12 +2008,12 @@ quic::Expected writeConnectionDataToSocket( if (!ret->buildSuccess) { // If we're returning because we couldn't schedule more packets, // make sure we flush the buffer in this function. - auto flushResult = ioBufBatch.flush(); + auto flushResult = flushFn(); if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); - return WriteQuicDataResult{ioBufBatch.getPktSent(), 0, bytesWritten}; + updateErrnoCount(connection, errnoFn()); + return WriteQuicDataResult{pktSentFn(), 0, bytesWritten}; } // If we build a packet, we updateConnection(), even if write might have // been failed. Because if it builds, a lot of states need to be updated no @@ -1993,14 +2046,15 @@ quic::Expected writeConnectionDataToSocket( connection.streamManager->writeQueue().commitTransaction( std::move(writeQueueTransaction)); - // if ioBufBatch.write returns false - // it is because a flush() call failed + // writeSuccess == false means flush() failed (inline) or SPSC queue full + // (async writer backpressure). if (!ret->writeSuccess) { if (connection.loopDetectorCallback) { - connection.writeDebugState.noWriteReason = - NoWriteReason::SOCKET_FAILURE; + connection.writeDebugState.noWriteReason = connection.packetWriter + ? NoWriteReason::WRITER_BACKPRESSURE + : NoWriteReason::SOCKET_FAILURE; } - return WriteQuicDataResult{ioBufBatch.getPktSent(), 0, bytesWritten}; + return WriteQuicDataResult{pktSentFn(), 0, bytesWritten}; } if ((connection.transportSettings.batchingMode == @@ -2009,21 +2063,24 @@ quic::Expected writeConnectionDataToSocket( connection.transportSettings.maxBatchSize, connection.transportSettings.dataPathType)) { // With SinglePacketInplaceBatchWriter we always write one packet, and so - // ioBufBatch needs a flush. + // ioBufBatch needs a flush. This path requires ContinuousMemory, so + // connection.packetWriter is always null here. auto flushResult = ioBufBatch.flush(); if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount(connection, ioBufBatch.getLastRetryableErrno()); } } - // Ensure that the buffer is flushed before returning - auto flushResult = ioBufBatch.flush(); + // Ensure that the buffer is flushed before returning. + // On the async path this flush() writes the eventfd to wake the drain thread + // — it must not be skipped. + auto flushResult = flushFn(); if (!flushResult.has_value()) { return quic::make_unexpected(flushResult.error()); } - updateErrnoCount(connection, ioBufBatch); + updateErrnoCount(connection, errnoFn()); if (connection.transportSettings.dataPathType == DataPathType::ContinuousMemory) { @@ -2032,7 +2089,7 @@ quic::Expected writeConnectionDataToSocket( connection.bufAccessor->length() == 0 && connection.bufAccessor->headroom() == 0); } - return WriteQuicDataResult{ioBufBatch.getPktSent(), 0, bytesWritten}; + return WriteQuicDataResult{pktSentFn(), 0, bytesWritten}; } quic::Expected writeProbingDataToSocket( diff --git a/quic/api/SharedThreadedPacketWriter.cpp b/quic/api/SharedThreadedPacketWriter.cpp new file mode 100644 index 000000000..85808c367 --- /dev/null +++ b/quic/api/SharedThreadedPacketWriter.cpp @@ -0,0 +1,365 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include +#include + +namespace quic { + +// ── ConnectionPacketWriter ────────────────────────────────────────────────── + +ConnectionPacketWriter::ConnectionPacketWriter( + SharedThreadedPacketWriter* shared, + ConnectionId connId) + : shared_(shared), connId_(std::move(connId)) {} + +quic::Expected ConnectionPacketWriter::write( + BufPtr&& buf, + size_t encodedSize, + const folly::SocketAddress& peerAddr) { + if (!shared_->write(std::move(buf), encodedSize, peerAddr, connId_)) { + shared_->registerBlocked(connId_); + return false; + } + result_.packetsSent++; + result_.bytesSent += encodedSize; + return true; +} + +quic::Expected ConnectionPacketWriter::flush() { + shared_->flush(); + return true; +} + +// ── SharedThreadedPacketWriter ────────────────────────────────────────────── + +SharedThreadedPacketWriter::SharedThreadedPacketWriter( + folly::AsyncUDPSocket& sock, + folly::EventBase* producerEvb, + folly::EventBase* drainEvb, + size_t queueCapacity, + size_t maxSegmentsPerMsg, + size_t maxMsgsPerCall, + size_t maxMsgsBeforeYield) + : queue_(drainEvb, queueCapacity), + maxSegmentsPerMsg_(maxSegmentsPerMsg), + maxMsgsPerCall_(maxMsgsPerCall), + maxMsgsBeforeYield_(maxMsgsBeforeYield), + sock_(sock), + producerEvb_(producerEvb), + drainEvb_(drainEvb) { + queue_.setOnReadable([this] { drainLoop(); }); + + // GSO grouping (UDP_SEGMENT) is Linux-only. On other platforms, disable it so + // writemGSO is always called with null options (gso==0 in every batch entry). +#ifndef FOLLY_HAVE_MSG_ERRQUEUE + maxSegmentsPerMsg_ = 1; +#else + // UDP_MAX_SEGMENTS=128 is the kernel's per-super-packet GSO segment limit. + // Exceeding it causes silent drops after sendmsg() returns success. + maxSegmentsPerMsg_ = std::min(maxSegmentsPerMsg_, size_t(128)); +#endif + + int sockFd = sock_.getNetworkSocket().toFd(); + writableHandler_ = std::make_unique( + this, drainEvb_, sockFd); + + bufs_.reserve(maxMsgsPerCall_); + segCounts_.resize(maxMsgsPerCall_, 0); + connIds_.resize(maxMsgsPerCall_, ConnectionId::createZeroLength()); + addrs_.resize(maxMsgsPerCall_); + opts_.resize(maxMsgsPerCall_); + + drainEvb_->runInEventBaseThread([this] { queue_.startConsuming(); }); +} + +SharedThreadedPacketWriter::~SharedThreadedPacketWriter() { + MVCHECK(!pendingFlush_) << "close() must be called before destroying SharedThreadedPacketWriter"; + bufs_.clear(); +} + +bool SharedThreadedPacketWriter::write( + BufPtr&& buf, + size_t encodedSize, + const folly::SocketAddress& peerAddr, + const ConnectionId& connId) { + if (closed_.load(std::memory_order_relaxed)) { + return false; + } + MVCHECK(peerAddr.isFamilyInet()) << "bad peerAddr family=" + << peerAddr.getFamily() << " connId=" << connId.hex(); + PacketEntry entry{std::move(buf), encodedSize, peerAddr, connId}; + if (!queue_.enqueue(std::move(entry))) { + wasEverFull_.store(true, std::memory_order_release); + FOLLY_SDT(quic, shared_packet_writer_queue_full); + return false; + } + FOLLY_SDT(quic, shared_packet_writer_enqueue, queue_.sizeGuess()); + return true; +} + +void SharedThreadedPacketWriter::flush() { + if (++pendingCount_ >= kEagerFlushThreshold) { + pendingCount_ = 0; + pendingFlush_ = false; + queue_.flush(); + } else if (!pendingFlush_) { + // First pending packet this loop: arm the callback to flush the tail. + pendingFlush_ = true; + producerEvb_->runBeforeLoop(&flushCallback_); + } +} + +void SharedThreadedPacketWriter::doFlush() { + pendingFlush_ = false; + pendingCount_ = 0; + queue_.flush(); +} + +void SharedThreadedPacketWriter::FlushLoopCallback::runLoopCallback() noexcept { + w_->doFlush(); +} + +void SharedThreadedPacketWriter::registerBlocked( + const ConnectionId& connId) { + blockedConnIds_.push_back(connId); +} + +void SharedThreadedPacketWriter::setOnFatalError( + std::function cb) { + onFatalError_ = std::move(cb); +} + +void SharedThreadedPacketWriter::setOnResumeProducer( + std::function&)> cb) { + onResumeProducer_ = std::move(cb); +} + +void SharedThreadedPacketWriter::close() { + MVCHECK(producerEvb_->isInEventBaseThread()); + flushCallback_.cancelLoopCallback(); + doFlush(); + closed_.store(true, std::memory_order_relaxed); +} + +void SharedThreadedPacketWriter::drainLoop() { + // If a retry is pending, don't drain — let the queue fill naturally so the + // producer pauses. drainQueue() will be called by retryAndDrain(). + if (!bufs_.empty()) { + return; + } + drainQueue(); +} + +void SharedThreadedPacketWriter::retryAndDrain() { + MVDCHECK(!bufs_.empty()); + // bufs_ holds the unsent slots (compacted to front for partial sends, or + // the full batch for EAGAIN). Call sendBatch() directly without rebuilding. + if (sendBatch() < 0) { + return; + } + drainQueue(); +} + +bool SharedThreadedPacketWriter::assembleNextBatch() { + MVDCHECK(bufs_.empty()); // cleared by sendBatch on success/fatal + // connIds_, addrs_, opts_, segCounts_ are sized to maxMsgsPerCall_ and may + // hold stale data from the previous batch; overwrite from index 0 — no resize needed. + needsGso_ = false; + totalSegsInBatch_ = 0; + size_t n = 0; // slots written this call + + size_t prevSize = 0; + size_t gso = 0; + size_t segsInChain = 0; // segments in the current mmsg slot + folly::SocketAddress curAddr; + ConnectionId curConnId = ConnectionId::createZeroLength(); + bool hasCurrentChain = false; + + PacketEntry entry{ + nullptr, 0, folly::SocketAddress{}, ConnectionId::createZeroLength()}; + while (n < maxMsgsPerCall_ && queue_.dequeue(entry)) { + size_t size = entry.encodedSize; + // GSO grouping: same connection, same peer address, non-increasing packet + // size, and chain not yet at the segment limit. A single smaller tail + // segment is valid: the kernel sets gso_size from the cmsg value (the + // opener's size) and the last segment is allowed to be shorter. + bool canAppend = hasCurrentChain && size <= prevSize && + (gso == 0 || gso == prevSize) && + entry.peerAddr == curAddr && + entry.connId == curConnId && segsInChain < maxSegmentsPerMsg_; + + if (canAppend) { + // Append to current chain. + bufs_.back()->appendToChain(std::move(entry.buf)); + gso = prevSize; // gso is the uniform segment size + prevSize = size; + segsInChain++; + } else { + // Finalize the previous chain's gso and segment count in opts_/segCounts_. + if (hasCurrentChain) { + // n >= 1: hasCurrentChain is only set after the first push_back. + opts_[n - 1] = folly::AsyncUDPSocket::WriteOptions( + static_cast(gso), /*zerocopy=*/false); + segCounts_[n - 1] = segsInChain; + totalSegsInBatch_ += segsInChain; + if (gso > 0) { + needsGso_ = true; + } + } + // addrs_/opts_/connIds_ at index n are pre-allocated; bufs_ grows via push_back. + bufs_.push_back(std::move(entry.buf)); + addrs_[n] = entry.peerAddr; + opts_[n] = folly::AsyncUDPSocket::WriteOptions(0, false); + connIds_[n] = entry.connId; + n++; + prevSize = size; + gso = 0; + segsInChain = 1; + curAddr = entry.peerAddr; + curConnId = entry.connId; + hasCurrentChain = true; + } + } + // Finalize the last chain. + if (hasCurrentChain) { + opts_[n - 1] = folly::AsyncUDPSocket::WriteOptions( + static_cast(gso), /*zerocopy=*/false); + segCounts_[n - 1] = segsInChain; + totalSegsInBatch_ += segsInChain; + if (gso > 0) { + needsGso_ = true; + } + } + return n < maxMsgsPerCall_; // true if queue ran dry +} + +ssize_t SharedThreadedPacketWriter::sendBatch() { + size_t n = bufs_.size(); + int ret = sock_.writemGSO( + folly::Range(addrs_.data(), n), + bufs_.data(), + n, + needsGso_ ? opts_.data() : nullptr); + MVVLOG(3) << "writemGSO batch=" << n << " ret=" << ret; + + if (ret < 0) { + int err = errno; + if (err == EAGAIN || err == EWOULDBLOCK || err == ENOBUFS) { + // TX buffer full. bufs_ stays intact — retryAndDrain will resend. + writableHandler_->registerHandler( + folly::EventHandler::WRITE | folly::EventHandler::PERSIST); + return -1; + } + // Fatal error. + auto quicErr = QuicError( + QuicErrorCode(LocalErrorCode::CONNECTION_ABANDONED), + std::string("SharedThreadedPacketWriter: fatal write error")); + dispatchErrors(n, quicErr); + bufs_.clear(); + if (!blockedConnIds_.empty()) { + producerEvb_->runInEventBaseThread([this] { resumeProducer(); }); + } + return -1; + } + + FOLLY_SDT(quic, shared_packet_writer_batch_result, n, static_cast(ret)); + size_t sent = static_cast(ret); + + if (sent < n) { + // Partial send: TX buffer is filling. Compact the unsent slots to the + // front and park on EPOLLOUT — retryAndDrain will resend them. + FOLLY_SDT(quic, shared_packet_writer_partial_send, n, ret); + bufs_.erase(bufs_.begin(), bufs_.begin() + sent); + std::move(addrs_.begin() + sent, addrs_.begin() + n, addrs_.begin()); + std::move(opts_.begin() + sent, opts_.begin() + n, opts_.begin()); + std::move(connIds_.begin() + sent, connIds_.begin() + n, connIds_.begin()); + std::move(segCounts_.begin() + sent, segCounts_.begin() + n, segCounts_.begin()); + writableHandler_->registerHandler( + folly::EventHandler::WRITE | folly::EventHandler::PERSIST); + return -1; + } + + // All sent: release IOBufs. addrs_/opts_/connIds_/segCounts_ stay at + // maxMsgsPerCall_ size so assembleNextBatch always has pre-initialized slots. + bufs_.clear(); + return static_cast(ret); +} + +void SharedThreadedPacketWriter::drainQueue() { + MVVLOG(3) << "drainQueue"; + size_t totalMsgsSent = 0; + + while (true) { + bool hitEnd = assembleNextBatch(); + + // Low-water mark: re-arm producers when the queue runs dry. We only load + // wasEverFull_ on hitEnd to avoid the atomic read on every iteration. + if (hitEnd && wasEverFull_.load(std::memory_order_acquire)) { + wasEverFull_.store(false, std::memory_order_relaxed); // cleared on drain thread only + producerEvb_->runInEventBaseThread([this] { resumeProducer(); }); + } + + if (bufs_.empty()) { + // hitEnd=true and queue was already empty; nothing to send. + FOLLY_SDT(quic, shared_packet_writer_drain_done, totalMsgsSent); + return; + } + + ssize_t sent = sendBatch(); + if (sent < 0) { + return; // parked on EPOLLOUT or fatal + } + totalMsgsSent += static_cast(sent); + + if (totalMsgsSent >= maxMsgsBeforeYield_) { + // Yield to other handlers. If the queue was empty when we last assembled, + // there's nothing left to drain — skip the wakeup. Any packets the + // producer adds after this point will fire the eventfd via flush(). + FOLLY_SDT( + quic, + shared_packet_writer_yield, + totalMsgsSent, + hitEnd ? 0 : 1 /*rescheduled*/); + if (!hitEnd) { + // Use drainLoop (not drainQueue) so the bufs_.empty() guard fires if + // a concurrent partial-send left bufs_ non-empty before this callback runs. + drainEvb_->runInEventBaseThread([this] { drainLoop(); }); + } + return; + } + } +} + +void SharedThreadedPacketWriter::dispatchErrors(size_t n, const QuicError& err) { + if (!onFatalError_) { + return; + } + auto cb = onFatalError_; + for (size_t i = 0; i < n; i++) { + producerEvb_->runInEventBaseThread([cb, connId = connIds_[i], err]() { + cb(connId, err); + }); + } +} + +void SharedThreadedPacketWriter::resumeProducer() { + // Called on producer EVB thread. Re-arm all connections blocked on a full queue. + if (onResumeProducer_ && !blockedConnIds_.empty()) { + onResumeProducer_(blockedConnIds_); + } + blockedConnIds_.clear(); +} + +} // namespace quic diff --git a/quic/api/SharedThreadedPacketWriter.h b/quic/api/SharedThreadedPacketWriter.h new file mode 100644 index 000000000..5b1712e7d --- /dev/null +++ b/quic/api/SharedThreadedPacketWriter.h @@ -0,0 +1,219 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include + +namespace quic { + +struct PacketEntry { + BufPtr buf; + size_t encodedSize{0}; + folly::SocketAddress peerAddr; + ConnectionId connId; +}; + +class SharedThreadedPacketWriter; + +/** + * Per-connection adaptor. Lives in conn.packetWriter. Forwards write() calls + * to the shared writer with the connection's connId attached. + */ +class ConnectionPacketWriter : public QuicPacketWriter { + public: + ConnectionPacketWriter( + SharedThreadedPacketWriter* shared, + ConnectionId connId); + + [[nodiscard]] quic::Expected write( + BufPtr&& buf, + size_t encodedSize, + const folly::SocketAddress& peerAddr) override; + + [[nodiscard]] quic::Expected flush() override; + + BufQuicBatchResult getResult() const override { + return result_; + } + + private: + SharedThreadedPacketWriter* shared_; // non-owning; shared_ outlives this + ConnectionId connId_; + BufQuicBatchResult result_; +}; + +/** + * One instance per socket. All connections sharing a producer EventBase use + * the same SharedThreadedPacketWriter. Drain runs on a caller-supplied + * folly::EventBase (which may be shared across multiple sockets/servers). + * + * Requires DataPathType::ChainedMemory. See QuicServer::setDrainEventBase(). + * + * GSO grouping (UDP_SEGMENT) is restricted to same-connection packets, so + * connIds_ stores only one entry per mmsg slot (the chain-opener's connId) + * rather than one per segment. Cross-connection coalescing is intentionally + * excluded; see assembleNextBatch. + * + * maxMsgsBeforeYield controls drain-thread fairness: lower values yield more + * often to other EventBase handlers at the cost of peak throughput. The + * default of 10 is tuned for server fairness, not maximum batch efficiency. + * + * Thread safety: + * - write() / flush() / registerBlocked(): called on producer EVB thread only + * - drainLoop() / retryAndDrain() / drainQueue(): drain EVB thread only + * - closed_: atomic; set by producer, checked by drain + */ +class SharedThreadedPacketWriter { + public: + explicit SharedThreadedPacketWriter( + folly::AsyncUDPSocket& sock, + folly::EventBase* producerEvb, + folly::EventBase* drainEvb, + size_t queueCapacity = 4096, + size_t maxSegmentsPerMsg = 16, + size_t maxMsgsPerCall = 64, + size_t maxMsgsBeforeYield = 10); + + ~SharedThreadedPacketWriter(); + + // Producer EVB thread. Returns false if queue is full (backpressure). + [[nodiscard]] bool write( + BufPtr&& buf, + size_t encodedSize, + const folly::SocketAddress& peerAddr, + const ConnectionId& connId); + + // Producer EVB thread. Writes eventfd to wake drain thread. + void flush(); + + // Producer EVB thread. Register connId for write re-arm when queue drains. + void registerBlocked(const ConnectionId& connId); + + // Called by QuicServer::shutdown(). After this, enqueues return false. + void close(); + + // Set callbacks invoked on the producer EVB when errors occur or the queue + // drains after backpressure. Both are called on the producer EVB thread. + // onFatalError is called once per affected connection; onResumeProducer is + // called with the full list of previously-blocked connection IDs. + void setOnFatalError( + std::function cb); + void setOnResumeProducer( + std::function&)> cb); + + private: + void drainLoop(); // eventfd handler + void retryAndDrain(); // EPOLLOUT handler + void drainQueue(); // pull from queue in chunks + + // Fill bufs_, connIds_, addrs_, opts_, segCounts_, needsGso_ from the queue + // (up to maxMsgsPerCall_ entries). GSO chains are assembled in a single pass. + // Returns true if the queue ran dry (end reached before maxMsgsPerCall_). + bool assembleNextBatch(); + + // Call writemGSO with the current bufs_/addrs_/opts_/needsGso_. + // Returns messages sent (>=0) on success. + // Returns -1 if parked on EPOLLOUT (EAGAIN: arrays intact; partial: compacted) + // or on fatal error (bufs_/connIds_ cleared, errors dispatched). + ssize_t sendBatch(); + + // Dispatch fatal write errors to the first n connections in connIds_. + void dispatchErrors(size_t n, const QuicError& err); + + // Called on producer EVB thread to re-arm blocked connections. + void resumeProducer(); + + class SocketWritableHandler : public folly::EventHandler { + public: + SocketWritableHandler( + SharedThreadedPacketWriter* writer, + folly::EventBase* evb, + int fd) + : folly::EventHandler(evb, folly::NetworkSocket::fromFd(fd)), + writer_(writer) {} + + void handlerReady(uint16_t /*events*/) noexcept override { + unregisterHandler(); + writer_->retryAndDrain(); + } + + private: + SharedThreadedPacketWriter* writer_; + }; + + // Coalesced flush: producer calls flush() per-connection, but we only write + // to the eventfd once per EVB loop (or when pendingCount_ hits the threshold). + // All fields are producer-EVB-thread-only; no atomics needed. + static constexpr size_t kEagerFlushThreshold = 16; + + class FlushLoopCallback : public folly::EventBase::LoopCallback { + public: + explicit FlushLoopCallback(SharedThreadedPacketWriter* w) : w_(w) {} + void runLoopCallback() noexcept override; + + private: + SharedThreadedPacketWriter* w_; + }; + + // Write the eventfd immediately and reset deferred-flush state. + // Must be called from the producer EVB thread. + void doFlush(); + + std::atomic closed_{false}; + EventFdQueue queue_; + size_t maxSegmentsPerMsg_; + size_t maxMsgsPerCall_; + size_t maxMsgsBeforeYield_; + + folly::AsyncUDPSocket& sock_; + folly::EventBase* producerEvb_; + folly::EventBase* drainEvb_; + std::unique_ptr writableHandler_; + + // Drain-thread batch state; preallocated to maxMsgsPerCall_. Non-empty while + // a send is pending; preserved across EAGAIN/partial-send so retryAndDrain + // resends without rebuilding. + std::vector bufs_; + // One entry per mmsg slot: segment count (packets chained into this slot). + std::vector segCounts_; + size_t totalSegsInBatch_{0}; // sum of segCounts_; valid after assembleNextBatch() + // One entry per mmsg slot: the chain-opener's connId. GSO chains are + // restricted to a single connection so one connId per slot is sufficient. + std::vector connIds_; + std::vector addrs_; + std::vector opts_; + bool needsGso_{false}; + + // Producer EVB thread only: + std::vector blockedConnIds_; + bool pendingFlush_{false}; + size_t pendingCount_{0}; + FlushLoopCallback flushCallback_{this}; + + // Producer stores release, drain loads acquire: required for visibility on + // non-TSO architectures (ARM/POWER) where stores can reorder past the eventfd write. + std::atomic wasEverFull_{false}; + + // Callbacks — set once before the first write(); not thread-safe with the + // drain thread. Setting them after write() has been called is a data race. + std::function onFatalError_; + std::function&)> onResumeProducer_; +}; + +} // namespace quic diff --git a/quic/api/test/CMakeLists.txt b/quic/api/test/CMakeLists.txt index ff1fae641..339eef2d7 100644 --- a/quic/api/test/CMakeLists.txt +++ b/quic/api/test/CMakeLists.txt @@ -100,6 +100,16 @@ quic_add_test(TARGET QuicBatchWriterTest mvfst_test_utils ) +quic_add_test(TARGET SharedThreadedPacketWriterTest + SOURCES + SharedThreadedPacketWriterTest.cpp + DEPENDS + Folly::folly + mvfst_api_shared_threaded_packet_writer + mvfst_common_events_folly_eventbase + ${LIBGMOCK_LIBRARY} +) + quic_add_test(TARGET QuicStreamAsyncTransportTest SOURCES QuicStreamAsyncTransportTest.cpp diff --git a/quic/api/test/SharedThreadedPacketWriterTest.cpp b/quic/api/test/SharedThreadedPacketWriterTest.cpp new file mode 100644 index 000000000..9dc91d118 --- /dev/null +++ b/quic/api/test/SharedThreadedPacketWriterTest.cpp @@ -0,0 +1,1048 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +using namespace quic; +using ::testing::_; +using ::testing::Invoke; +using ::testing::NiceMock; +using ::testing::Return; + +// Minimal mock of folly::AsyncUDPSocket for STPW tests. +// Only writemGSO needs to be mocked — STPW calls no other socket methods +// (getNetworkSocket() is handled by setFD in SetUp). +class MockUDPSocket : public folly::AsyncUDPSocket { + public: + explicit MockUDPSocket(folly::EventBase* evb) : folly::AsyncUDPSocket(evb) {} + MOCK_METHOD( + int, + writemGSO, + (folly::Range, + const std::unique_ptr*, + size_t, + const folly::AsyncUDPSocket::WriteOptions*), + (override)); +}; + +namespace { + +ConnectionId makeConnId(uint8_t byte) { + return ConnectionId::createAndMaybeCrash({byte, 0, 0, 0}); +} + +BufPtr makeBuf(size_t size) { + auto buf = folly::IOBuf::create(size); + buf->append(size); + return buf; +} + +BufPtr makeBuf(size_t size, uint8_t fill) { + auto buf = folly::IOBuf::create(size); + buf->append(size); + ::memset(buf->writableData(), fill, size); + return buf; +} + +} // namespace + +class SharedThreadedPacketWriterTest : public ::testing::Test { + protected: + void SetUp() override { + ASSERT_EQ(0, ::socketpair(AF_UNIX, SOCK_DGRAM, 0, fds_)); + sock_ = std::make_unique>( + producerThread_.getEventBase()); + sock_->setFD( + folly::NetworkSocket::fromFd(fds_[0]), + folly::AsyncUDPSocket::FDOwnership::SHARED); + } + + void TearDown() override { + if (writer_) { + // close() cancels flushCallback_ and must run on the producer EVB. + onProducer([&] { writer_->close(); }); + // Destroy on drain thread so SocketWritableHandler::unregisterHandler() + // runs on the correct EventBase. + folly::Baton<> destroyed; + drainThread_.getEventBase()->runInEventBaseThread( + [w = std::move(writer_), &destroyed]() mutable { + w.reset(); + destroyed.post(); + }); + destroyed.wait(); + } + ::close(fds_[0]); + ::close(fds_[1]); + fds_[0] = fds_[1] = -1; + } + + void makeWriter( + size_t capacity = 64, + size_t maxSegmentsPerMsg = 64, + size_t maxMsgsPerCall = 64, + size_t maxMsgsBeforeYield = 256) { + writer_ = std::make_unique( + *sock_, + producerThread_.getEventBase(), + drainThread_.getEventBase(), + capacity, + maxSegmentsPerMsg, + maxMsgsPerCall, + maxMsgsBeforeYield); + } + + // Run fn on the producer EventBase thread and block until it completes. + template + void onProducer(Fn&& fn) { + folly::Baton<> done; + producerThread_.getEventBase()->runInEventBaseThread( + [f = std::forward(fn), &done]() mutable { + f(); + done.post(); + }); + done.wait(); + } + + // Post a no-op to the drain EventBase and wait for it, ensuring all + // previously-enqueued drain callbacks have completed. + void awaitDrain() { + folly::Baton<> idle; + drainThread_.getEventBase()->runInEventBaseThread( + [&idle] { idle.post(); }); + idle.wait(); + } + + // Member declaration order controls destructor order (C++ destroys in reverse + // declaration order). drainThread_ must outlive writer_ so EventHandlers can + // unregister from the drain EventBase in TearDown. sock_ must outlive writer_ + // since STPW holds a reference to it. + folly::SocketAddress peer_{"127.0.0.1", 1234}; + int fds_[2]{-1, -1}; + folly::ScopedEventBaseThread producerThread_; + folly::ScopedEventBaseThread drainThread_; + std::unique_ptr> sock_; + std::unique_ptr writer_; +}; + +// Packets enqueued on producer thread reach writemGSO on drain thread. +TEST_F(SharedThreadedPacketWriterTest, NormalPath) { + makeWriter(); + + folly::Baton<> done; + std::atomic totalSegments{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + for (size_t i = 0; i < count; i++) { + totalSegments.fetch_add( + static_cast(bufs[i]->countChainElements())); + } + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + EXPECT_EQ(totalSegments.load(), 3); +} + +// GSO grouping tests require Linux (UDP_SEGMENT / MSG_ERRQUEUE). +#ifdef FOLLY_HAVE_MSG_ERRQUEUE + +// Same peer + same encoded size → one GSO mmsg entry with multiple segments. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_SameAddrSameSize) { + makeWriter(); + + folly::Baton<> done; + std::atomic msgCount{0}; + std::atomic segCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + msgCount.fetch_add(count); + for (size_t i = 0; i < count; i++) { + segCount.fetch_add(bufs[i]->countChainElements()); + } + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + // 3 same-size same-addr packets should assemble into 1 GSO chain. + EXPECT_EQ(msgCount.load(), 1u); + EXPECT_EQ(segCount.load(), 3u); +} + +#endif // FOLLY_HAVE_MSG_ERRQUEUE + +// Different peer addresses produce separate mmsg entries. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_DifferentAddr) { + makeWriter(); + + folly::Baton<> done; + std::atomic msgCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + msgCount.fetch_add(count); + done.post(); + return static_cast(count); + })); + + folly::SocketAddress peer2{"127.0.0.1", 5678}; + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer2, makeConnId(2))); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + EXPECT_EQ(msgCount.load(), 2u); +} + +#ifdef FOLLY_HAVE_MSG_ERRQUEUE + +// A smaller last segment stays in the same GSO chain (valid GSO). The packet +// AFTER the smaller one starts a new chain because gso != newPrevSize. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_TerminalSmallerSegment) { + makeWriter(); + + folly::Baton<> done; + std::atomic msgCount{0}; + std::atomic segCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + msgCount.fetch_add(count); + for (size_t i = 0; i < count; i++) { + segCount.fetch_add(bufs[i]->countChainElements()); + } + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + // {100,100,50} → one GSO chain (50 is valid smaller last segment). + // The following 50 breaks the gso==prevSize invariant → new chain. + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(50), 50, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(50), 50, peer_, connId)); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + // Chain 1: {100,100,50} gso=100; Chain 2: {50} gso=0 + EXPECT_EQ(msgCount.load(), 2u); + EXPECT_EQ(segCount.load(), 4u); +} + +#endif // FOLLY_HAVE_MSG_ERRQUEUE + +// Queue full → write() returns false (backpressure). +TEST_F(SharedThreadedPacketWriterTest, Backpressure_QueueFull) { + makeWriter(/*capacity=*/2); + + auto connId = makeConnId(1); + bool third = true; + // Do not flush — avoid drain-thread interaction in this test. + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + third = writer_->write(makeBuf(100), 100, peer_, connId); + }); + + EXPECT_FALSE(third); +} + +// EAGAIN on first writemGSO → EPOLLOUT fires → retryAndDrain → second call +// succeeds. The retry must deliver the same data to the same peer. +TEST_F(SharedThreadedPacketWriterTest, EAGAIN_RetryViaPollout) { + makeWriter(); + + folly::Baton<> retryDone; + std::atomic callCount{0}; + // Captured on drain thread; safe to read after retryDone + awaitDrain(). + size_t firstCount{0}; + size_t retryCount{0}; + std::string firstContent; + std::string retryContent; + folly::SocketAddress firstAddr; + folly::SocketAddress retryAddr; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range addrs, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + int n = callCount.fetch_add(1); + if (n == 0) { + firstCount = count; + firstContent = std::string( + reinterpret_cast(bufs[0]->data()), + bufs[0]->length()); + firstAddr = addrs[0]; + errno = EAGAIN; + return -1; + } + retryCount = count; + retryContent = std::string( + reinterpret_cast(bufs[0]->data()), + bufs[0]->length()); + retryAddr = addrs[0]; + retryDone.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + retryDone.wait(); + awaitDrain(); + EXPECT_EQ(retryCount, firstCount); + EXPECT_EQ(retryContent, firstContent); + EXPECT_EQ(retryAddr, firstAddr); +} + +// After close(), write() returns false immediately. +TEST_F(SharedThreadedPacketWriterTest, Closed_WritesRejected) { + makeWriter(); + onProducer([&] { writer_->close(); }); + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)).Times(0); + + auto connId = makeConnId(1); + bool result = true; + // No flush — writer is closed, no drain expected. + onProducer([&] { + result = writer_->write(makeBuf(100), 100, peer_, connId); + }); + EXPECT_FALSE(result); +} + +// Packets enqueued with a deferred flush (pendingFlush_=true, not yet written +// to the eventfd) must still reach the drain thread when close() is called. +// close() cancels flushCallback_ and calls queue_.flush() directly so the +// drain thread is not stranded waiting for an eventfd that never fires. +TEST_F(SharedThreadedPacketWriterTest, Close_FlushesDeferredPackets) { + makeWriter(); + + folly::Baton<> done; + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillOnce(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + done.post(); + return static_cast(count); + })); + + // write() + flush() + close() in one EVB turn: flushCallback_ is armed but + // has not run yet when close() fires, exercising the cancel+flush path. + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + writer_->flush(); + writer_->close(); + }); + + done.wait(); + awaitDrain(); +} + +#ifdef FOLLY_HAVE_MSG_ERRQUEUE + +// maxSegmentsPerMsg_ cap: once a GSO chain reaches the limit the next segment +// starts a new mmsg entry. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_MaxSegmentsPerMsg) { + makeWriter(/*capacity=*/64, /*maxSegmentsPerMsg=*/2); + + folly::Baton<> done; + std::atomic msgCount{0}; + std::atomic segCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + msgCount.fetch_add(count); + for (size_t i = 0; i < count; i++) { + segCount.fetch_add(bufs[i]->countChainElements()); + } + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + // 3 same-size same-addr packets; cap is 2 segments → {100,100} + {100}. + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + EXPECT_EQ(msgCount.load(), 2u); + EXPECT_EQ(segCount.load(), 3u); +} + +// Packets from two different connections to the same peer with the same size +// are NOT grouped — cross-connection coalescing is excluded to keep connIds_ +// as a flat per-slot vector instead of a per-slot vector. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_CrossConnection) { + makeWriter(); + + folly::Baton<> done; + std::atomic msgCount{0}; + std::atomic segCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + msgCount.fetch_add(count); + for (size_t i = 0; i < count; i++) { + segCount.fetch_add(bufs[i]->countChainElements()); + } + if (msgCount.load() >= 2) { + done.post(); + } + return static_cast(count); + })); + + onProducer([&] { + // Different connIds, same peer, same size → NOT grouped; connId mismatch + // breaks the GSO chain so each packet becomes its own mmsg entry. + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(2))); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + EXPECT_EQ(msgCount.load(), 2u); + EXPECT_EQ(segCount.load(), 2u); +} + +// A single connection sends 1200, 1200, 700, 700. +// Expected GSO grouping: +// Slot 0: {1200, 1200, 700} = 3100 bytes gso=1200 (700 is valid smaller terminal) +// Slot 1: {700} = 700 bytes gso=0 (breaks chain: gso=1200 != prevSize=700) +// Verify GSO chain byte ordering. Packets are filled with distinct bytes +// (0x01, 0x02, 0x03, 0x04) so that the coalesce order is observable. +// Expected batching: slot0=[pkt1(1200,0x01), pkt2(1200,0x02), pkt3(700,0x03)] +// gso=1200; slot1=[pkt4(700,0x04)] gso=0. +TEST_F(SharedThreadedPacketWriterTest, GSOGrouping_1200_1200_700_700) { + makeWriter(); + + folly::Baton<> done; + struct SlotInfo { + std::vector firstBytePerSeg; // first byte of each chain element + std::vector segLengths; + int gso{0}; + }; + std::vector slots; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions* opts) -> int { + for (size_t i = 0; i < count; i++) { + SlotInfo s; + s.gso = opts ? opts[i].gso : 0; + for (const auto& seg : *bufs[i]) { + s.firstBytePerSeg.push_back(seg.data()[0]); + s.segLengths.push_back(seg.size()); + } + slots.push_back(std::move(s)); + } + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(1200, 0x01), 1200, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(1200, 0x02), 1200, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(700, 0x03), 700, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(700, 0x04), 700, peer_, connId)); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + + ASSERT_EQ(slots.size(), 2u); + // Slot 0: [pkt1, pkt2, pkt3] in order, gso stride 1200. + ASSERT_EQ(slots[0].firstBytePerSeg.size(), 3u); + EXPECT_EQ(slots[0].firstBytePerSeg[0], 0x01); + EXPECT_EQ(slots[0].firstBytePerSeg[1], 0x02); + EXPECT_EQ(slots[0].firstBytePerSeg[2], 0x03); + EXPECT_EQ(slots[0].segLengths[0], 1200u); + EXPECT_EQ(slots[0].segLengths[1], 1200u); + EXPECT_EQ(slots[0].segLengths[2], 700u); + EXPECT_EQ(slots[0].gso, 1200); + // Slot 1: pkt4 alone, no GSO needed. + ASSERT_EQ(slots[1].firstBytePerSeg.size(), 1u); + EXPECT_EQ(slots[1].firstBytePerSeg[0], 0x04); + EXPECT_EQ(slots[1].segLengths[0], 700u); + EXPECT_EQ(slots[1].gso, 0); +} + +#endif // FOLLY_HAVE_MSG_ERRQUEUE + +// assembleNextBatch fills at most maxMsgsPerCall_ slots. A second batch that +// uses fewer slots must pass bufs_.size() (not the preallocated array size) to +// writemGSO so stale entries from the first batch are excluded. +TEST_F(SharedThreadedPacketWriterTest, VariableBatchSizes) { + // maxMsgsPerCall=2 forces the 3-packet queue to split across two batches. + // maxMsgsBeforeYield=256 keeps both batches in the same drainQueue() call. + makeWriter( + /*capacity=*/64, + /*maxSegmentsPerMsg=*/64, + /*maxMsgsPerCall=*/2, + /*maxMsgsBeforeYield=*/256); + + folly::Baton<> done; + // Written only on the drain thread; safe to read after awaitDrain(). + std::vector counts; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + counts.push_back(count); + if (counts.size() == 2) { + done.post(); + } + return static_cast(count); + })); + + // Three different peers → 3 separate msgs (no GSO grouping across peers). + // maxMsgsPerCall=2 splits them into first batch of 2 and second batch of 1. + onProducer([&] { + EXPECT_TRUE(writer_->write( + makeBuf(100), 100, folly::SocketAddress{"127.0.0.1", 1001}, makeConnId(1))); + EXPECT_TRUE(writer_->write( + makeBuf(100), 100, folly::SocketAddress{"127.0.0.1", 1002}, makeConnId(2))); + EXPECT_TRUE(writer_->write( + makeBuf(100), 100, folly::SocketAddress{"127.0.0.1", 1003}, makeConnId(3))); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + ASSERT_EQ(counts.size(), 2u); + EXPECT_EQ(counts[0], 2u); // first batch: maxMsgsPerCall_ entries + EXPECT_EQ(counts[1], 1u); // second batch: one remaining entry, not maxMsgsPerCall_ +} + +// ConnectionPacketWriter accumulates packetsSent / bytesSent in getResult(). +TEST_F(SharedThreadedPacketWriterTest, ConnectionPacketWriter_AccumulatesResult) { + makeWriter(); + + folly::Baton<> done; + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + done.post(); + return static_cast(count); + })); + + auto connId = makeConnId(1); + ConnectionPacketWriter cpw(writer_.get(), connId); + + onProducer([&] { + EXPECT_TRUE(*cpw.write(makeBuf(100), 100, peer_)); + EXPECT_TRUE(*cpw.write(makeBuf(200), 200, peer_)); + EXPECT_TRUE(*cpw.flush()); + }); + + done.wait(); + awaitDrain(); + auto result = cpw.getResult(); + EXPECT_EQ(result.packetsSent, 2u); + EXPECT_EQ(result.bytesSent, 300u); +} + +// getResult() accumulates lifetime totals; callers compute per-cycle deltas by +// capturing a baseline before each write cycle. This verifies the delta is +// correct across two cycles (cycle 2 delta must be 2, not the lifetime total 3). +TEST_F(SharedThreadedPacketWriterTest, ConnectionPacketWriter_PerCycleDelta) { + makeWriter(); + + folly::Baton<> done1, done2; + int call = 0; + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + if (++call == 1) { + done1.post(); + } else { + done2.post(); + } + return static_cast(count); + })); + + auto connId = makeConnId(1); + ConnectionPacketWriter cpw(writer_.get(), connId); + + // Cycle 1: write 1 packet; delta should be 1. + uint64_t base = cpw.getResult().packetsSent; + onProducer([&] { + EXPECT_TRUE(*cpw.write(makeBuf(100), 100, peer_)); + EXPECT_TRUE(*cpw.flush()); + }); + done1.wait(); + awaitDrain(); + EXPECT_EQ(cpw.getResult().packetsSent - base, 1u); + + // Cycle 2: write 2 packets. Delta must be 2, not 3 (lifetime total). + base = cpw.getResult().packetsSent; + onProducer([&] { + EXPECT_TRUE(*cpw.write(makeBuf(100), 100, peer_)); + EXPECT_TRUE(*cpw.write(makeBuf(100), 100, peer_)); + EXPECT_TRUE(*cpw.flush()); + }); + done2.wait(); + awaitDrain(); + EXPECT_EQ(cpw.getResult().packetsSent - base, 2u); + EXPECT_EQ(cpw.getResult().packetsSent, 3u); // lifetime total +} + +// Non-EAGAIN writemGSO failure (fatal) → no retry, drainQueue returns. +TEST_F(SharedThreadedPacketWriterTest, FatalWriteError_NoRetry) { + makeWriter(); + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .Times(1) + .WillOnce(Invoke( + [](folly::Range, + const BufPtr*, + size_t, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + errno = EIO; + return -1; + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + // Drain completes without crash; Times(1) enforces no retry attempt. + awaitDrain(); +} + +// drainQueue() yields to the EventBase after maxMsgsBeforeYield msgs, then +// re-enters to drain the remainder. +TEST_F(SharedThreadedPacketWriterTest, YieldAndResumeDrain) { + // maxMsgsPerCall=2 so each drainQueue iteration sends at most 2 msgs. + // maxMsgsBeforeYield=2 so the yield triggers after every call. + makeWriter( + /*capacity=*/64, + /*maxSegmentsPerMsg=*/64, + /*maxMsgsPerCall=*/2, + /*maxMsgsBeforeYield=*/2); + + folly::Baton<> done; + std::atomic totalSent{0}; + std::atomic callCount{0}; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + callCount.fetch_add(1); + if (totalSent.fetch_add(count) + count == 4) { + done.post(); + } + return static_cast(count); + })); + + // 4 different peers → 4 separate msgs (no GSO grouping across peers). + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, {"127.0.0.1", 1001}, makeConnId(1))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, {"127.0.0.1", 1002}, makeConnId(2))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, {"127.0.0.1", 1003}, makeConnId(3))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, {"127.0.0.1", 1004}, makeConnId(4))); + writer_->flush(); + }); + + done.wait(); + awaitDrain(); + EXPECT_EQ(totalSent.load(), 4u); + // maxMsgsPerCall=2 forces >=2 writemGSO calls, confirming the yield path. + EXPECT_GE(callCount.load(), 2); +} + +// When the queue runs dry on the same assembleNextBatch() call that hits the +// yield threshold (hitEnd == true), no spurious drainQueue callback is +// scheduled. writemGSO is called exactly once. +TEST_F(SharedThreadedPacketWriterTest, YieldSuppressed_WhenQueueEmpty) { + makeWriter( + /*capacity=*/64, + /*maxSegmentsPerMsg=*/64, + /*maxMsgsPerCall=*/64, + /*maxMsgsBeforeYield=*/1); + + folly::Baton<> done; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .Times(1) + .WillOnce(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + done.post(); + return static_cast(count); + })); + + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + writer_->flush(); + }); + + done.wait(); + // Extra awaitDrain() turn to catch any spurious drainQueue callback that + // would violate Times(1) above. + awaitDrain(); +} + +// Queue full → write() returns false; after drain drops below capacity/2, +// onResumeProducer fires on the producer EVB with the blocked connId. +TEST_F(SharedThreadedPacketWriterTest, Backpressure_ResumeProducer) { + makeWriter(/*capacity=*/4); + + folly::Baton<> resumed; + std::vector resumedIds; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + writer_->setOnResumeProducer([&](const std::vector& ids) { + resumedIds = ids; + resumed.post(); + }); + // Fill queue completely (capacity=4). + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + // 5th write fails (queue full); register this connection as blocked. + EXPECT_FALSE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->registerBlocked(connId); + writer_->flush(); + }); + + resumed.wait(); + ASSERT_EQ(resumedIds.size(), 1u); + EXPECT_EQ(resumedIds[0], connId); +} + +// writemGSO returns 0 < ret < batch.size() (partial send). The unsent entries +// are retried; all packets must eventually be delivered with correct content. +TEST_F(SharedThreadedPacketWriterTest, PartialSend_RetryViaPollout) { + makeWriter(); + + folly::Baton<> retryDone; + std::atomic callCount{0}; + size_t totalDelivered{0}; + std::string skippedContent; + std::string retryContent; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr* bufs, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + int n = callCount.fetch_add(1); + if (n == 0) { + skippedContent = std::string( + reinterpret_cast(bufs[1]->data()), + bufs[1]->length()); + totalDelivered += 1; + return 1; + } + retryContent = std::string( + reinterpret_cast(bufs[0]->data()), + bufs[0]->length()); + totalDelivered += static_cast(count); + retryDone.post(); + return static_cast(count); + })); + + // Two different peers produce two separate batch entries. + onProducer([&] { + EXPECT_TRUE(writer_->write( + makeBuf(100), 100, folly::SocketAddress{"127.0.0.1", 1001}, makeConnId(1))); + EXPECT_TRUE(writer_->write( + makeBuf(100), 100, folly::SocketAddress{"127.0.0.1", 1002}, makeConnId(2))); + writer_->flush(); + }); + + retryDone.wait(); + awaitDrain(); + EXPECT_EQ(totalDelivered, 2u); + EXPECT_EQ(retryContent, skippedContent); +} + +// Fatal write error fires onFatalError on the producer EVB. If the connection +// is already gone by then, the callback handles "not found" gracefully. +TEST_F(SharedThreadedPacketWriterTest, ConnectionClosedWhileInQueue) { + makeWriter(); + + folly::Baton<> callbackRan; + auto connId = makeConnId(1); + bool seenExpectedId = false; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillOnce(Invoke( + [](folly::Range, + const BufPtr*, + size_t, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + errno = EIO; + return -1; + })); + + onProducer([&] { + writer_->setOnFatalError([&](const ConnectionId& id, const QuicError&) { + // Simulates a server worker that finds the connection already gone: + // just record that the right connId was dispatched and drop it. + seenExpectedId = (id == connId); + callbackRan.post(); + }); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + callbackRan.wait(); + EXPECT_TRUE(seenExpectedId); +} + +// close() called while the drain thread is inside writemGSO. The in-flight +// batch completes; subsequent writes on the producer are rejected. +TEST_F(SharedThreadedPacketWriterTest, ShutdownRace) { + makeWriter(); + + folly::Baton<> writemGSOEntered; + folly::Baton<> resumeDrain; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + writemGSOEntered.post(); + resumeDrain.wait(); // park until test thread signals + return static_cast(count); + })); + + auto connId = makeConnId(1); + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, connId)); + writer_->flush(); + }); + + writemGSOEntered.wait(); // drain is now inside writemGSO + + // close() on the producer EVB while drain is mid-send. + onProducer([&] { writer_->close(); }); + resumeDrain.post(); // let drain finish the in-flight send + + awaitDrain(); + + // Writes after close() are rejected. + bool afterClose = true; + onProducer([&] { + afterClose = writer_->write(makeBuf(100), 100, peer_, connId); + }); + EXPECT_FALSE(afterClose); +} + +// Two independent SharedThreadedPacketWriters sharing one drain EventBase both +// drain their queues correctly. +TEST_F(SharedThreadedPacketWriterTest, MultipleSocketsSharedDrainEvb) { + int fds2[2]{-1, -1}; + ASSERT_EQ(0, ::socketpair(AF_UNIX, SOCK_DGRAM, 0, fds2)); + auto sock2 = std::make_unique>( + producerThread_.getEventBase()); + sock2->setFD( + folly::NetworkSocket::fromFd(fds2[0]), + folly::AsyncUDPSocket::FDOwnership::SHARED); + + makeWriter(); + auto writer2 = std::make_unique( + *sock2, + producerThread_.getEventBase(), + drainThread_.getEventBase()); // shared drain EVB + + folly::Baton<> done1, done2; + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + done1.post(); + return static_cast(count); + })); + + EXPECT_CALL(*sock2, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [&](folly::Range, + const BufPtr*, + size_t count, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + done2.post(); + return static_cast(count); + })); + + onProducer([&] { + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + writer_->flush(); + EXPECT_TRUE(writer2->write(makeBuf(100), 100, peer_, makeConnId(2))); + writer2->flush(); + }); + + done1.wait(); + done2.wait(); + + // Tear down writer2: close on producer EVB, destroy on drain EVB. + onProducer([&] { writer2->close(); }); + folly::Baton<> destroyed; + drainThread_.getEventBase()->runInEventBaseThread( + [w = std::move(writer2), &destroyed]() mutable { + w.reset(); + destroyed.post(); + }); + destroyed.wait(); + ::close(fds2[0]); + ::close(fds2[1]); +} + +// A connection that filled the queue (registerBlocked) must not hang forever +// when a fatal write error occurs. The bug: when the queue has more items than +// maxMsgsPerCall, assembleNextBatch sets hitEnd=false, suppressing the +// wasEverFull_ low-watermark check. A fatal sendBatch() then returns early +// without scheduling resumeProducer, leaving blockedId hung indefinitely. +// Without the fix this test hangs indefinitely on resumed.wait(). +TEST_F(SharedThreadedPacketWriterTest, BlockedConnection_WokenAfterFatalError) { + // maxMsgsPerCall=2 with capacity=4: first batch assembles 2 items, hitEnd=false, + // so the wasEverFull_ low-watermark path does NOT fire before the fatal error. + makeWriter(/*capacity=*/4, /*maxSegmentsPerMsg=*/64, /*maxMsgsPerCall=*/2); + + folly::Baton<> resumed; + auto blockedId = makeConnId(99); + + EXPECT_CALL(*sock_, writemGSO(_, _, _, _)) + .WillRepeatedly(Invoke( + [](folly::Range, + const BufPtr*, + size_t, + const folly::AsyncUDPSocket::WriteOptions*) -> int { + errno = EIO; + return -1; + })); + + onProducer([&] { + writer_->setOnFatalError([](const ConnectionId&, const QuicError&) {}); + writer_->setOnResumeProducer([&](const std::vector&) { + resumed.post(); + }); + // Fill the queue (4 items) then register blockedId as waiting for resume. + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(1))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(2))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(3))); + EXPECT_TRUE(writer_->write(makeBuf(100), 100, peer_, makeConnId(4))); + EXPECT_FALSE(writer_->write(makeBuf(100), 100, peer_, blockedId)); + writer_->registerBlocked(blockedId); + writer_->flush(); + }); + + resumed.wait(); // hangs forever without the fix +} diff --git a/quic/common/CMakeLists.txt b/quic/common/CMakeLists.txt index 44dfea018..4973c9dfa 100644 --- a/quic/common/CMakeLists.txt +++ b/quic/common/CMakeLists.txt @@ -151,6 +151,12 @@ mvfst_add_library(mvfst_common_quic_iobuf_queue mvfst_common_quic_buffer ) +mvfst_add_library(mvfst_common_event_fd_queue + EXPORTED_DEPS + Folly::folly + mvfst_common_events_eventbase +) + add_subdirectory(events) add_subdirectory(testutil) add_subdirectory(third-party) diff --git a/quic/common/EventFdQueue.h b/quic/common/EventFdQueue.h new file mode 100644 index 000000000..d221a0a5f --- /dev/null +++ b/quic/common/EventFdQueue.h @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#if __has_include() +#include +#define QUIC_HAS_EVENTFD 1 +#else +#define QUIC_HAS_EVENTFD 0 +#endif + +#include +#include +#include +#include + +#include + +namespace quic { + +/** + * SPSC queue with asynchronous wakeup. + * + * The producer calls enqueue() from one thread and flush() once per batch. + * The consumer is driven by a folly::EventBase; setOnReadable() registers + * the callback fired when items are ready. + * + * Wakeup uses eventfd on Linux (EFD_NONBLOCK; counter semantics coalesce + * multiple flush() calls into one wakeup) and a non-blocking self-pipe on + * platforms without eventfd (e.g. macOS). + */ +template +class EventFdQueue { + public: + EventFdQueue(folly::EventBase* consumerEvb, size_t capacity) + // ProducerConsumerQueue with size N holds N-1 items; add 1 so the + // external-facing capacity is exact. + : queue_(static_cast(capacity + 1)) { + initNotifyFds(); + handler_ = + std::make_unique(this, consumerEvb, readFd()); + } + + // Begin consuming. Must be called from the consumer EventBase thread. + // Calling registerHandler() (which calls event_add/kevent) from an off-thread + // is not safe on macOS kqueue — the filter may not be visible to the waiting + // kevent() call. Modelled on folly::NotificationQueue::Consumer::startConsuming(). + void startConsuming() { + handler_->registerHandler(folly::EventHandler::READ | folly::EventHandler::PERSIST); + } + + ~EventFdQueue() { + handler_->unregisterHandler(); +#if QUIC_HAS_EVENTFD + ::close(eventfd_); +#else + ::close(pipeFds_[0]); + ::close(pipeFds_[1]); +#endif + } + + // Producer thread. Returns false if queue is full. + bool enqueue(T item) { + if (!queue_.write(std::move(item))) { + return false; + } + pendingFlush_ = true; + return true; + } + + // Producer thread. Signal the consumer once if anything was enqueued since + // last flush. Multiple enqueues coalesce into one wakeup. + void flush() { + if (!pendingFlush_) { + return; + } + pendingFlush_ = false; +#if QUIC_HAS_EVENTFD + uint64_t one = 1; + auto ret = ::write(eventfd_, &one, sizeof(one)); + PCHECK(ret == (ssize_t)sizeof(one) || errno == EAGAIN || errno == EWOULDBLOCK); +#else + char one = 1; + auto ret = ::write(pipeFds_[1], &one, 1); + PCHECK(ret == 1 || errno == EAGAIN || errno == EWOULDBLOCK); +#endif + } + + // Consumer setup. Must be called before events start firing. + void setOnReadable(folly::Function cb) { + onReadable_ = std::move(cb); + } + + // Consumer thread. Returns false if queue is empty. + bool dequeue(T& out) { + return queue_.read(out); + } + + // Approximate number of items currently in the queue. Safe to call from any + // thread; uses the same relaxed loads as folly::ProducerConsumerQueue. + size_t sizeGuess() const { + return queue_.sizeGuess(); + } + + private: + void initNotifyFds() { +#if QUIC_HAS_EVENTFD + eventfd_ = ::eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + PCHECK(eventfd_ >= 0) << "eventfd() failed"; +#else + PCHECK(::pipe(pipeFds_) == 0) << "pipe() failed"; + PCHECK(::fcntl(pipeFds_[0], F_SETFL, O_NONBLOCK) != -1); + PCHECK(::fcntl(pipeFds_[1], F_SETFL, O_NONBLOCK) != -1); +#endif + } + + int readFd() const { +#if QUIC_HAS_EVENTFD + return eventfd_; +#else + return pipeFds_[0]; +#endif + } + + // Drain the wakeup fd so it rearms for the next flush(). + void drainWakeupFd() { +#if QUIC_HAS_EVENTFD + uint64_t val; + while (::read(eventfd_, &val, sizeof(val)) > 0) { + } +#else + char buf[64]; + while (::read(pipeFds_[0], buf, sizeof(buf)) > 0) { + } +#endif + } + + class DrainHandler : public folly::EventHandler { + public: + DrainHandler(EventFdQueue* q, folly::EventBase* evb, int fd) + : folly::EventHandler(evb, folly::NetworkSocket::fromFd(fd)), q_(q) {} + + void handlerReady(uint16_t /*events*/) noexcept override { + q_->drainWakeupFd(); + if (q_->onReadable_) { + q_->onReadable_(); + } + } + + private: + EventFdQueue* q_; + }; + + folly::ProducerConsumerQueue queue_; +#if QUIC_HAS_EVENTFD + int eventfd_{-1}; +#else + int pipeFds_[2]{-1, -1}; +#endif + bool pendingFlush_{false}; + std::unique_ptr handler_; + folly::Function onReadable_; +}; + +} // namespace quic diff --git a/quic/common/test/CMakeLists.txt b/quic/common/test/CMakeLists.txt index 60e1c65c4..1a067d52f 100644 --- a/quic/common/test/CMakeLists.txt +++ b/quic/common/test/CMakeLists.txt @@ -84,3 +84,11 @@ quic_add_test(TARGET QuicCommonUtilTest SOURCES mvfst_test_utils ${BOOST_LIBRARIES} ) + +quic_add_test(TARGET EventFdQueueTest SOURCES + EventFdQueueTest.cpp + DEPENDS + Folly::folly + mvfst_common_event_fd_queue + mvfst_common_events_folly_eventbase +) diff --git a/quic/common/test/EventFdQueueTest.cpp b/quic/common/test/EventFdQueueTest.cpp new file mode 100644 index 000000000..8c9182f14 --- /dev/null +++ b/quic/common/test/EventFdQueueTest.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include + +using namespace quic; + +// Items flushed before startConsuming() fires are still delivered because +// eventfd/pipe is level-triggered: once the handler is registered it sees the +// fd is already readable and fires immediately. +TEST(EventFdQueueTest, HandlerFiresAfterStartConsuming) { + folly::ScopedEventBaseThread consumer; + EventFdQueue queue(consumer.getEventBase(), 8); + + folly::Baton<> done; + queue.setOnReadable([&] { + int v; + while (queue.dequeue(v)) { + } + done.post(); + }); + consumer.getEventBase()->runInEventBaseThread([&] { queue.startConsuming(); }); + + queue.enqueue(42); + queue.flush(); + + done.wait(); +} + +TEST(EventFdQueueTest, BasicFifo) { + folly::ScopedEventBaseThread consumer; + EventFdQueue queue(consumer.getEventBase(), 8); + + std::vector received; + folly::Baton<> done; + + consumer.getEventBase()->runInEventBaseThread([&] { + queue.setOnReadable([&] { + int v; + while (queue.dequeue(v)) { + received.push_back(v); + } + if (received.size() == 4) { + done.post(); + } + }); + queue.startConsuming(); + }); + + // Give consumer thread time to register handler + /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + EXPECT_TRUE(queue.enqueue(1)); + EXPECT_TRUE(queue.enqueue(2)); + EXPECT_TRUE(queue.enqueue(3)); + EXPECT_TRUE(queue.enqueue(4)); + queue.flush(); + + done.wait(); + + ASSERT_EQ(received.size(), 4u); + EXPECT_EQ(received[0], 1); + EXPECT_EQ(received[1], 2); + EXPECT_EQ(received[2], 3); + EXPECT_EQ(received[3], 4); +} + +// Backpressure is a property of the underlying SPSC queue; no notification +// path (startConsuming / flush) is needed to test it. +TEST(EventFdQueueTest, Backpressure) { + folly::ScopedEventBaseThread consumer; + EventFdQueue queue(consumer.getEventBase(), 4); + + // Fill the queue + EXPECT_TRUE(queue.enqueue(1)); + EXPECT_TRUE(queue.enqueue(2)); + EXPECT_TRUE(queue.enqueue(3)); + EXPECT_TRUE(queue.enqueue(4)); + // Queue is now full + EXPECT_FALSE(queue.enqueue(5)); + + // Drain one slot and verify we can enqueue again + int v; + EXPECT_TRUE(queue.dequeue(v)); + EXPECT_EQ(v, 1); + EXPECT_TRUE(queue.enqueue(5)); +} + +TEST(EventFdQueueTest, Coalescing) { + // Multiple flush() calls while consumer is busy should coalesce into one + // wakeup that drains all items. + folly::ScopedEventBaseThread consumer; + EventFdQueue queue(consumer.getEventBase(), 64); + + std::atomic wakeups{0}; + std::atomic totalDrained{0}; + folly::Baton<> done; + + consumer.getEventBase()->runInEventBaseThread([&] { + queue.setOnReadable([&] { + wakeups.fetch_add(1); + int v; + while (queue.dequeue(v)) { + totalDrained.fetch_add(1); + } + if (totalDrained.load() == 30) { + done.post(); + } + }); + queue.startConsuming(); + }); + + /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Enqueue 3 batches of 10, flushing after each + for (int batch = 0; batch < 3; batch++) { + for (int i = 0; i < 10; i++) { + EXPECT_TRUE(queue.enqueue(batch * 10 + i)); + } + queue.flush(); + } + + done.wait(); + EXPECT_EQ(totalDrained.load(), 30); + // May be 1, 2, or 3 wakeups depending on scheduling; just verify all items + // arrived and count is at least 1. + EXPECT_GE(wakeups.load(), 1); +} + +TEST(EventFdQueueTest, NoItemLoss) { + folly::ScopedEventBaseThread consumer; + const int total = 200; + EventFdQueue queue(consumer.getEventBase(), 64); + + std::atomic received{0}; + folly::Baton<> done; + + consumer.getEventBase()->runInEventBaseThread([&] { + queue.setOnReadable([&] { + int v; + while (queue.dequeue(v)) { + if (received.fetch_add(1) + 1 == total) { + done.post(); + } + } + }); + queue.startConsuming(); + }); + + /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Producer sends items in small batches; queue capacity may fill and require + // the consumer to drain before we can continue. + for (int i = 0; i < total; i++) { + while (!queue.enqueue(i)) { + // Queue full — yield and retry + /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + if ((i + 1) % 10 == 0) { + queue.flush(); + } + } + queue.flush(); + + done.wait(); + EXPECT_EQ(received.load(), total); +} diff --git a/quic/server/CMakeLists.txt b/quic/server/CMakeLists.txt index c7bff8965..28020621d 100644 --- a/quic/server/CMakeLists.txt +++ b/quic/server/CMakeLists.txt @@ -53,6 +53,7 @@ mvfst_add_library(mvfst_server_server Folly::folly_token_bucket fmt::fmt EXPORTED_DEPS + mvfst_api_shared_threaded_packet_writer mvfst_api_transport mvfst_api_transport_helpers mvfst_codec_types diff --git a/quic/server/QuicServer.cpp b/quic/server/QuicServer.cpp index 39052454d..fdf4a6446 100644 --- a/quic/server/QuicServer.cpp +++ b/quic/server/QuicServer.cpp @@ -134,6 +134,11 @@ bool QuicServer::isInitialized() const noexcept { return initialized_; } +void QuicServer::setDrainEventBase(folly::EventBase* drainEvb) { + checkRunningInThread(mainThreadId_); + drainEvb_ = drainEvb; +} + void QuicServer::start(const folly::SocketAddress& address, size_t maxWorkers) { checkRunningInThread(mainThreadId_); MVCHECK(ctx_, "Must set a TLS context for the Quic server"); @@ -198,6 +203,11 @@ void QuicServer::initializeImpl( MVCHECK(shutdown_); shutdown_ = false; + MVCHECK( + !drainEvb_ || + transportSettings_.dataPathType == DataPathType::ChainedMemory, + "setDrainEventBase requires DataPathType::ChainedMemory"); + // it the connid algo factory is not set, use default impl if (!connIdAlgoFactory_) { connIdAlgoFactory_ = std::make_unique(); @@ -263,6 +273,9 @@ std::unique_ptr QuicServer::newWorkerWithoutSocket() { worker->setShouldRegisterKnobParamHandlerFn( shouldRegisterKnobParamHandlerFn_); worker->setQuicExperimentHandlerFn(quicExperimentHandlerFn_); + if (drainEvb_) { + worker->setDrainEventBase(drainEvb_); + } return worker; } diff --git a/quic/server/QuicServer.h b/quic/server/QuicServer.h index ee71683a0..e9084d7d2 100644 --- a/quic/server/QuicServer.h +++ b/quic/server/QuicServer.h @@ -324,6 +324,15 @@ class QuicServer : public QuicServerWorker::WorkerCallback, void setConnectionIdAlgoFactory( std::unique_ptr connIdAlgoFactory); + /** + * Enable offloading sendmsg calls to the given EventBase thread. + * The caller owns the EventBase and must ensure it outlives this server. + * The same EventBase may be shared across multiple QuicServer instances. + * Requires DataPathType::ChainedMemory — fails at start() otherwise. + * Must be called before start() or initialize(). + */ + void setDrainEventBase(folly::EventBase* drainEvb); + /** * Returns vector of running eventbases. * This is useful if QuicServer is initialized with a 'default' mode by just @@ -492,6 +501,9 @@ class QuicServer : public QuicServerWorker::WorkerCallback, std::function unfinishedHandshakeLimitFn_{[]() { return 1048576; }}; + // Drain EventBase for SharedThreadedPacketWriter; null = inline path. + folly::EventBase* drainEvb_{nullptr}; + // Options to AsyncUDPSocket::bind, only controls IPV6_ONLY currently. FollyAsyncUDPSocketAlias::BindOptions bindOptions_; diff --git a/quic/server/QuicServerWorker.cpp b/quic/server/QuicServerWorker.cpp index 7b54304b4..dd6ff655d 100644 --- a/quic/server/QuicServerWorker.cpp +++ b/quic/server/QuicServerWorker.cpp @@ -26,6 +26,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -84,14 +87,14 @@ void QuicServerWorker::setSocket( std::unique_ptr socket) { socket_ = std::move(socket); evb_ = folly::ExecutorKeepAlive(socket_->getEventBase()); + workerQuicEvb_ = std::make_shared(evb_.get()); } void QuicServerWorker::bind( const folly::SocketAddress& address, FollyAsyncUDPSocketAlias::BindOptions bindOptions) { // TODO get rid of the temporary wrapper - FollyQuicAsyncUDPSocket tmpSock( - std::make_shared(evb_.get()), *socket_); + FollyQuicAsyncUDPSocket tmpSock(workerQuicEvb_, *socket_); MVDCHECK(!supportedVersions_.empty()); MVCHECK(socket_); // TODO this totally doesn't work, we can't apply socket options before @@ -137,8 +140,7 @@ void QuicServerWorker::bind( void QuicServerWorker::applyAllSocketOptions() { MVCHECK(socket_); // TODO get rid of the temporary wrapper - FollyQuicAsyncUDPSocket tmpSock( - std::make_shared(evb_.get()), *socket_); + FollyQuicAsyncUDPSocket tmpSock(workerQuicEvb_, *socket_); if (socketOptions_) { (void)applySocketOptions( tmpSock, @@ -200,12 +202,49 @@ void QuicServerWorker::setUnfinishedHandshakeLimit( unfinishedHandshakeLimitFn_ = std::move(limitFn); } +void QuicServerWorker::setDrainEventBase(folly::EventBase* drainEvb) { + drainEvb_ = drainEvb; +} + void QuicServerWorker::start() { MVCHECK(socket_); if (!pacingTimer_) { pacingTimer_ = std::make_unique( evb_.get(), transportSettings_.pacingTimerResolution); } + if (drainEvb_) { + sharedWriter_ = std::make_unique( + *socket_, + evb_.get(), + drainEvb_, + /*queueCapacity=*/4096, + /*maxSegmentsPerMsg=*/( + transportSettings_.batchingMode == + QuicBatchingMode::BATCHING_MODE_GSO || + transportSettings_.batchingMode == + QuicBatchingMode::BATCHING_MODE_SENDMMSG_GSO + ? transportSettings_.maxBatchSize + : 1), + /*maxMsgsPerCall=*/16); + sharedWriter_->setOnFatalError( + [this](const ConnectionId& connId, const QuicError& err) { + // Runs on producer EVB thread. + auto it = connectionIdMap_.find(connId); + if (it != connectionIdMap_.end()) { + it->second->close(err); + } + }); + sharedWriter_->setOnResumeProducer( + [this](const std::vector& ids) { + // Runs on producer EVB thread. + for (const auto& connId : ids) { + auto it = connectionIdMap_.find(connId); + if (it != connectionIdMap_.end()) { + it->second->scheduleWrite(); + } + } + }); + } socket_->resumeRead(this); MVVLOG(10) << fmt::format( "Registered read on worker={}, thread={}, processId={}", @@ -691,6 +730,10 @@ QuicServerTransport::Ptr QuicServerWorker::makeTransport( // parameters to create server chosen connection id trans->setServerConnectionIdParams(ServerConnectionIdParams( cidVersion_, hostId_, static_cast(processId_), workerId_)); + if (sharedWriter_) { + trans->setPacketWriter( + std::make_unique(sharedWriter_.get(), dstConnId)); + } trans->accept(quicVersion); auto result = sourceAddressMap_.emplace( std::make_pair(std::make_pair(client, dstConnId), trans)); @@ -1415,6 +1458,9 @@ void QuicServerWorker::shutdownAllConnections(LocalErrorCode error) { return; } shutdown_ = true; + if (sharedWriter_) { + sharedWriter_->close(); + } if (socket_) { socket_->pauseRead(); } diff --git a/quic/server/QuicServerWorker.h b/quic/server/QuicServerWorker.h index f76ed7ae6..da6923d7b 100644 --- a/quic/server/QuicServerWorker.h +++ b/quic/server/QuicServerWorker.h @@ -32,6 +32,8 @@ namespace quic { class AcceptObserver; +class FollyQuicEventBase; +class SharedThreadedPacketWriter; class QuicServerWorker : public FollyAsyncUDPSocketAlias::ReadCallback, public QuicServerTransport::RoutingCallback, @@ -268,6 +270,11 @@ class QuicServerWorker : public FollyAsyncUDPSocketAlias::ReadCallback, void setCongestionControllerFactory( std::shared_ptr factory); + // Enable the SharedThreadedPacketWriter for this worker. Must be called + // before start(). The caller owns drainEvb and must ensure it outlives this + // worker. + void setDrainEventBase(folly::EventBase* drainEvb); + /** * Set the rate limiter which will be used to rate limit new connections. */ @@ -502,6 +509,13 @@ class QuicServerWorker : public FollyAsyncUDPSocketAlias::ReadCallback, std::unique_ptr socket_; folly::SocketOptionMap* socketOptions_{nullptr}; + // Wrapper around evb_; initialized in setSocket(). Shared with + // listenerQuicSock_ and used wherever a FollyQuicEventBase* is needed. + std::shared_ptr workerQuicEvb_; + + // Threaded packet writer — set when drainEvb_ is provided. + folly::EventBase* drainEvb_{nullptr}; + std::unique_ptr sharedWriter_; std::shared_ptr callback_; folly::Executor::KeepAlive evb_; diff --git a/quic/server/test/QuicClientServerIntegrationTest.cpp b/quic/server/test/QuicClientServerIntegrationTest.cpp index 4b9d699f7..9bcb9248c 100644 --- a/quic/server/test/QuicClientServerIntegrationTest.cpp +++ b/quic/server/test/QuicClientServerIntegrationTest.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -228,4 +229,172 @@ TEST_F(ServerTransportParameters, disableMigrationParam) { EXPECT_NE(it, serverTransportParams->parameters.end()); } +// ── ThreadedPacketWriter integration ──────────────────────────────────────── + +namespace { + +constexpr size_t kServerPayloadSize = 1000; + +// Reads all bytes from a single unidirectional stream, then terminates the +// given EventBase loop when the FIN arrives. +class StreamFinReadCallback : public StreamReadCallback { + public: + StreamFinReadCallback(QuicSocketLite* sock, StreamId id, folly::EventBase* evb) + : sock_(sock), id_(id), evb_(evb) {} + + void readAvailable(StreamId) noexcept override { + auto res = sock_->read(id_, 0); + if (res.hasError()) { + evb_->terminateLoopSoon(); + return; + } + auto& [buf, fin] = res.value(); + if (buf) { + received_ += buf->computeChainDataLength(); + } + if (fin) { + evb_->terminateLoopSoon(); + } + } + + void readError(StreamId, QuicError) noexcept override { + evb_->terminateLoopSoon(); + } + + size_t received_{0}; + + private: + QuicSocketLite* sock_; + StreamId id_; + folly::EventBase* evb_; +}; + +// Server-side callback: on full handshake done, opens one unidirectional stream +// and writes kServerPayloadSize bytes with FIN. +class DataSendingServerCallback : public MockConnectionSetupCallback, + public MockConnectionCallback { + public: + explicit DataSendingServerCallback(QuicSocketLite* sock) : sock_(sock) {} + + void onFullHandshakeDone() noexcept override { + auto streamId = sock_->createUnidirectionalStream(); + if (streamId.hasError()) { + return; + } + auto buf = folly::IOBuf::create(kServerPayloadSize); + buf->append(kServerPayloadSize); + memset(buf->writableData(), 'x', kServerPayloadSize); + sock_->writeChain(*streamId, std::move(buf), /*eof=*/true); + } + + private: + QuicSocketLite* sock_; +}; + +class DataSendingTransportFactory : public QuicServerTransportFactory { + QuicServerTransport::Ptr make( + folly::EventBase* evb, + std::unique_ptr socket, + const folly::SocketAddress&, + QuicVersion quicVersion, + std::shared_ptr ctx) noexcept + override { + auto trans = + QuicServerTransport::make(evb, std::move(socket), nullptr, nullptr, ctx); + auto cb = std::make_shared(trans.get()); + cbs_.push_back(cb); + EXPECT_CALL(*cb, onConnectionEnd()).Times(AtMost(1)); + EXPECT_CALL(*cb, onConnectionError(_)).Times(AtMost(1)); + trans->setConnectionSetupCallback(cb.get()); + trans->setConnectionCallback(cb.get()); + return trans; + } + + std::vector> cbs_; +}; + +} // namespace + +class ThreadedPacketWriterIntegrationTest : public testing::Test { + public: + void SetUp() override { + qEvb_ = std::make_shared(&evb_); + } + + void TearDown() override { + if (client_) { + client_->close(std::nullopt); + } + if (server_) { + server_->shutdown(); + } + // Per shutdown contract: stop drain EVB before destroying writers/workers. + drainThread_.reset(); + // Drain any pending producer-EVB callbacks (e.g. onFatalError). + evb_.loop(); + } + + void startServer() { + serverTs_.statelessResetTokenSecret = getRandSecret(); + serverTs_.dataPathType = DataPathType::ChainedMemory; + server_ = QuicServer::createQuicServer(serverTs_); + server_->setFizzContext(quic::test::createServerCtx()); + server_->setQuicServerTransportFactory( + std::make_unique()); + drainThread_ = std::make_unique(); + server_->setDrainEventBase(drainThread_->getEventBase()); + server_->start(folly::SocketAddress("::1", 0), 1); + server_->waitUntilInitialized(); + } + + std::shared_ptr createQuicClient() { + CHECK(server_); + auto fizzCtx = FizzClientQuicHandshakeContext::Builder() + .setFizzClientContext(quic::test::createClientCtx()) + .setCertificateVerifier(createTestCertificateVerifier()) + .build(); + auto client = std::make_shared( + qEvb_, + std::make_unique(qEvb_), + std::move(fizzCtx)); + client->addNewPeerAddress(server_->getAddress()); + client->setHostname("::1"); + client->setSupportedVersions({QuicVersion::MVFST}); + return client; + } + + std::shared_ptr client_; + std::shared_ptr server_; + TransportSettings serverTs_{}; + folly::EventBase evb_; + std::shared_ptr qEvb_; + std::unique_ptr drainThread_; + MockConnectionSetupCallback setupCb_; + MockConnectionCallback connCb_; + std::unique_ptr readCb_; +}; + +// Start a real QuicServer with SharedThreadedPacketWriter enabled, connect a +// real client, and verify that the server can send application data through +// the full path: ConnectionPacketWriter → SPSC queue → drain thread → +// writemGSO → kernel → client. +TEST_F(ThreadedPacketWriterIntegrationTest, ServerSendsDataThroughThreadedWriter) { + startServer(); + client_ = createQuicClient(); + + EXPECT_CALL(setupCb_, onReplaySafe()); + EXPECT_CALL(connCb_, onNewUnidirectionalStream(_)) + .WillOnce(Invoke([&](StreamId id) { + readCb_ = + std::make_unique(client_.get(), id, &evb_); + client_->setReadCallback(id, readCb_.get()); + })); + + client_->start(&setupCb_, &connCb_); + evb_.loopForever(); // terminates when StreamFinReadCallback receives FIN + + ASSERT_NE(readCb_, nullptr); + EXPECT_EQ(readCb_->received_, kServerPayloadSize); +} + } // namespace quic::test diff --git a/quic/state/CMakeLists.txt b/quic/state/CMakeLists.txt index 7f38c8524..f78e86d93 100644 --- a/quic/state/CMakeLists.txt +++ b/quic/state/CMakeLists.txt @@ -64,6 +64,7 @@ mvfst_add_library(mvfst_state_quic_state_machine QuicStreamManager.cpp StateData.cpp DEPS + mvfst_api_quic_packet_writer mvfst_logging_qlogger_macros mvfst_priority_http_priority_queue EXPORTED_DEPS diff --git a/quic/state/StateData.h b/quic/state/StateData.h index 18bc1145e..34a2880bd 100644 --- a/quic/state/StateData.h +++ b/quic/state/StateData.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -786,6 +787,10 @@ struct QuicConnectionStateBase : public folly::DelayedDestruction { }; Optional scone; + + // If set, encrypted packets are handed to this writer instead of the inline + // IOBufQuicBatch. Only valid when DataPathType::ChainedMemory is in use. + std::unique_ptr packetWriter; }; std::ostream& operator<<(std::ostream& os, const QuicConnectionStateBase& st);