From 2c0c63a967c4f80c3742a061b18d0678283e4120 Mon Sep 17 00:00:00 2001
From: Louis Thiery
Date: Tue, 2 Dec 2025 13:02:52 -0800
Subject: [PATCH 1/2] test(http1): setup test fixture for dispatch loop

---
 Cargo.toml                    |  12 +++
 tests/h1_server/fixture.rs    | 136 +++++++++++++++++++++++++++++++++
 tests/h1_server/mod.rs        |  97 ++++++++++++++++++++++++
 tests/ready_on_poll_stream.rs | 139 ++++++++++++++++++++++++++++++++++
 tests/unbuffered_stream.rs    | 126 ++++++++++++++++++++++++++++++
 5 files changed, 510 insertions(+)
 create mode 100644 tests/h1_server/fixture.rs
 create mode 100644 tests/h1_server/mod.rs
 create mode 100644 tests/ready_on_poll_stream.rs
 create mode 100644 tests/unbuffered_stream.rs

diff --git a/Cargo.toml b/Cargo.toml
index 4441bdcdea..d3808624f5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,6 +52,8 @@ pin-project-lite = "0.2.4"
 spmc = "0.3"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
+tracing = "0.1"
+tracing-subscriber = "0.3"
 tokio = { version = "1", features = [
     "fs",
     "macros",
@@ -239,6 +241,16 @@ name = "integration"
 path = "tests/integration.rs"
 required-features = ["full"]
 
+[[test]]
+name = "ready_on_poll_stream"
+path = "tests/ready_on_poll_stream.rs"
+required-features = ["full"]
+
+[[test]]
+name = "unbuffered_stream"
+path = "tests/unbuffered_stream.rs"
+required-features = ["full"]
+
 [[test]]
 name = "server"
 path = "tests/server.rs"
diff --git a/tests/h1_server/fixture.rs b/tests/h1_server/fixture.rs
new file mode 100644
index 0000000000..6b31efecbc
--- /dev/null
+++ b/tests/h1_server/fixture.rs
@@ -0,0 +1,136 @@
+use http_body_util::StreamBody;
+use hyper::body::Bytes;
+use hyper::body::Frame;
+use hyper::server::conn::http1;
+use hyper::service::service_fn;
+use hyper::{Response, StatusCode};
+use std::convert::Infallible;
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+use tracing::{error, info};
+
+pub struct TestConfig {
+    pub total_chunks: usize,
+    pub chunk_size: usize,
+    pub chunk_timeout: Duration,
+}
+
+impl TestConfig {
+    pub fn with_timeout(chunk_timeout: Duration) -> Self {
+        Self {
+            total_chunks: 16,
+            chunk_size: 64 * 1024,
+            chunk_timeout,
+        }
+    }
+}
+
+pub struct Client {
+    pub rx: mpsc::UnboundedReceiver<Vec<u8>>,
+    pub tx: mpsc::UnboundedSender<Vec<u8>>,
+}
+
+pub async fn run<S>(server: S, mut client: Client, config: TestConfig)
+where
+    S: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
+{
+    let mut http_builder = http1::Builder::new();
+    http_builder.max_buf_size(config.chunk_size);
+
+    let total_chunks = config.total_chunks;
+    let chunk_size = config.chunk_size;
+
+    let service = service_fn(move |_| {
+        let total_chunks = total_chunks;
+        let chunk_size = chunk_size;
+        async move {
+            info!(
+                "Creating payload of {} chunks of {} KiB each ({} MiB total)...",
+                total_chunks,
+                chunk_size / 1024,
+                total_chunks * chunk_size / (1024 * 1024)
+            );
+            let bytes = Bytes::from(vec![0; chunk_size]);
+            let data = vec![bytes.clone(); total_chunks];
+            let stream = futures_util::stream::iter(
+                data.into_iter()
+                    .map(|b| Ok::<_, Infallible>(Frame::data(b))),
+            );
+            let body = StreamBody::new(stream);
+            info!("Server: Sending data response...");
+            Ok::<_, hyper::Error>(
+                Response::builder()
+                    .status(StatusCode::OK)
+                    .header("content-type", "application/octet-stream")
+                    .header("content-length", (total_chunks * chunk_size).to_string())
+                    .body(body)
+                    .unwrap(),
+            )
+        }
+    });
+
+    let server_task = tokio::spawn(async move {
+        let conn = http_builder.serve_connection(Box::pin(server), service);
+        let conn_result = conn.await;
+        if let Err(e) = &conn_result {
+            error!("Server connection error: {}", e);
+        }
+        conn_result
+    });
+
+    let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
+    client
+        .tx
+        .send(get_request.as_bytes().to_vec())
+        .map_err(|e| {
+            Box::new(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                format!("Failed to send request: {}", e),
+            ))
+        })
+        .unwrap();
+
+    info!("Client is reading response...");
+    let mut bytes_received = 0;
+    let mut all_data = Vec::new();
+    loop {
+        match timeout(config.chunk_timeout, client.rx.recv()).await {
+            Ok(Some(chunk)) => {
+                bytes_received += chunk.len();
+                all_data.extend_from_slice(&chunk);
+            }
+            Ok(None) => break,
+            Err(_) => {
+                panic!(
+                    "Chunk timeout: chunk took longer than {:?}",
+                    config.chunk_timeout
+                );
+            }
+        }
+    }
+
+    // Clean up
+    let result = server_task.await.unwrap();
+    result.unwrap();
+
+    // Parse HTTP response to find body start
+    // HTTP response format: "HTTP/1.1 200 OK\r\n...headers...\r\n\r\n"
+    let body_start = all_data
+        .windows(4)
+        .position(|w| w == b"\r\n\r\n")
+        .map(|pos| pos + 4)
+        .unwrap_or(0);
+
+    let body_bytes = bytes_received - body_start;
+    assert_eq!(
+        body_bytes,
+        config.total_chunks * config.chunk_size,
+        "Expected {} body bytes, got {} (total received: {}, headers: {})",
+        config.total_chunks * config.chunk_size,
+        body_bytes,
+        bytes_received,
+        body_start
+    );
+    info!(bytes_received, body_bytes, "Client done receiving bytes");
+}
diff --git a/tests/h1_server/mod.rs b/tests/h1_server/mod.rs
new file mode 100644
index 0000000000..7b2ee0a350
--- /dev/null
+++ b/tests/h1_server/mod.rs
@@ -0,0 +1,97 @@
+pub mod fixture;
+
+use hyper::rt::{Read, ReadBufCursor};
+use pin_project_lite::pin_project;
+use std::io;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+use tokio::sync::mpsc;
+
+// Common read half shared by both stream types
+pin_project! {
+    #[derive(Debug)]
+    pub struct StreamReadHalf {
+        #[pin]
+        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
+        read_buffer: Vec<u8>,
+    }
+}
+
+impl StreamReadHalf {
+    pub fn new(read_rx: mpsc::UnboundedReceiver<Vec<u8>>) -> Self {
+        Self {
+            read_rx,
+            read_buffer: Vec::new(),
+        }
+    }
+}
+
+impl Read for StreamReadHalf {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: ReadBufCursor<'_>,
+    ) -> Poll<io::Result<()>> {
+        let mut this = self.as_mut().project();
+
+        // First, try to satisfy the read request from the internal buffer
+        if !this.read_buffer.is_empty() {
+            let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
+            // Copy data from internal buffer to the read buffer
+            buf.put_slice(&this.read_buffer[..to_read]);
+            // Remove the consumed data from the internal buffer
+            this.read_buffer.drain(..to_read);
+            return Poll::Ready(Ok(()));
+        }
+
+        // If internal buffer is empty, try to get data from the channel
+        match this.read_rx.as_mut().get_mut().try_recv() {
+            Ok(data) => {
+                // Copy as much data as we can fit in the buffer
+                let to_read = std::cmp::min(data.len(), buf.remaining());
+                buf.put_slice(&data[..to_read]);
+
+                // Store any remaining data in the internal buffer for next time
+                if to_read < data.len() {
+                    let remaining = &data[to_read..];
+                    this.read_buffer.extend_from_slice(remaining);
+                }
+                Poll::Ready(Ok(()))
+            }
+            Err(mpsc::error::TryRecvError::Empty) => {
+                match ready!(this.read_rx.poll_recv(cx)) {
+                    Some(data) => {
+                        // Copy as much data as we can fit in the buffer
+                        let to_read = std::cmp::min(data.len(), buf.remaining());
+                        buf.put_slice(&data[..to_read]);
+
+                        // Store any remaining data in the internal buffer for next time
+                        if to_read < data.len() {
+                            let remaining = &data[to_read..];
+                            this.read_buffer.extend_from_slice(remaining);
+                        }
+                        Poll::Ready(Ok(()))
+                    }
+                    None => Poll::Ready(Ok(())),
+                }
+            }
+            Err(mpsc::error::TryRecvError::Disconnected) => {
+                // Channel closed, return EOF
+                Poll::Ready(Ok(()))
+            }
+        }
+    }
+}
+
+pub fn init_tracing() {
+    use std::sync::Once;
+    static INIT: Once = Once::new();
+    INIT.call_once(|| {
+        tracing_subscriber::fmt()
+            .with_max_level(tracing::Level::INFO)
+            .with_target(true)
+            .with_thread_ids(true)
+            .with_thread_names(true)
+            .init();
+    });
+}
diff --git a/tests/ready_on_poll_stream.rs b/tests/ready_on_poll_stream.rs
new file mode 100644
index 0000000000..f393d31c0e
--- /dev/null
+++ b/tests/ready_on_poll_stream.rs
@@ -0,0 +1,139 @@
+#[path = "h1_server/mod.rs"]
+mod h1_server;
+
+use h1_server::{fixture, init_tracing, StreamReadHalf};
+use hyper::rt::{Read, ReadBufCursor, Write};
+use pin_project_lite::pin_project;
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::time::Sleep;
+use tracing::error;
+
+pin_project! {
+    #[derive(Debug)]
+    pub struct ReadyOnPollStream {
+        #[pin]
+        read_half: StreamReadHalf,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+        #[pin]
+        pending_write: Option<Pin<Box<Sleep>>>,
+        poll_since_write: bool,
+        flush_count: usize,
+    }
+}
+
+impl ReadyOnPollStream {
+    fn new(
+        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+    ) -> Self {
+        Self {
+            read_half: StreamReadHalf::new(read_rx),
+            write_tx,
+            poll_since_write: true,
+            flush_count: 0,
+            pending_write: None,
+        }
+    }
+
+    /// Create a new server stream and client pair.
+    /// Returns a server stream (Read+Write) and a client (rx/tx channels).
+    pub fn new_pair() -> (Self, fixture::Client) {
+        let (client_tx, server_rx) = mpsc::unbounded_channel();
+        let (server_tx, client_rx) = mpsc::unbounded_channel();
+        let server = Self::new(server_rx, server_tx);
+        let client = fixture::Client {
+            rx: client_rx,
+            tx: client_tx,
+        };
+        (server, client)
+    }
+}
+
+impl Read for ReadyOnPollStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: ReadBufCursor<'_>,
+    ) -> Poll<io::Result<()>> {
+        self.as_mut().project().read_half.poll_read(cx, buf)
+    }
+}
+
+const WRITE_DELAY: Duration = Duration::from_millis(100);
+
+impl Write for ReadyOnPollStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        if let Some(sleep) = self.pending_write.as_mut() {
+            let sleep = sleep.as_mut();
+            ready!(Future::poll(sleep, cx));
+        }
+        {
+            let mut this = self.as_mut().project();
+            this.pending_write
+                .set(Some(Box::pin(tokio::time::sleep(WRITE_DELAY))));
+        }
+        let Some(sleep) = self.pending_write.as_mut() else {
+            panic!("Sleep should have just been set");
+        };
+        // poll the future so that we can be woken later
+        let sleep = sleep.as_mut();
+        let Poll::Pending = Future::poll(sleep, cx) else {
+            panic!("Sleep should always be pending on first poll")
+        };
+
+        let this = self.project();
+        let buf = Vec::from(&buf[..buf.len()]);
+        let len = buf.len();
+
+        // Send data through the channel - this should always be ready for unbounded channels
+        match this.write_tx.send(buf) {
+            Ok(_) => Poll::Ready(Ok(len)),
+            Err(_) => {
+                error!("ReadyOnPollStream::poll_write failed - channel closed");
+                Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "Write channel closed",
+                )))
+            }
+        }
+    }
+
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.flush_count += 1;
+        // We require two flushes to complete each chunk, simulating a success at the end of the old
+        // poll loop. After all chunks are written, we always succeed on flush so the response can finish.
+        const TOTAL_CHUNKS: usize = 16;
+        if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
+            if let Some(sleep) = self.pending_write.as_mut() {
+                let sleep = sleep.as_mut();
+                ready!(Future::poll(sleep, cx));
+            } else {
+                return Poll::Pending;
+            }
+        }
+        let mut this = self.as_mut().project();
+        this.pending_write.set(None);
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn body_test() {
+    init_tracing();
+    let (server, client) = ReadyOnPollStream::new_pair();
+    let config = fixture::TestConfig::with_timeout(WRITE_DELAY * 2);
+    fixture::run(server, client, config).await;
+}
diff --git a/tests/unbuffered_stream.rs b/tests/unbuffered_stream.rs
new file mode 100644
index 0000000000..9cea28d5f8
--- /dev/null
+++ b/tests/unbuffered_stream.rs
@@ -0,0 +1,126 @@
+#[path = "h1_server/mod.rs"]
+mod h1_server;
+
+use h1_server::{fixture, init_tracing, StreamReadHalf};
+use hyper::rt::{Read, ReadBufCursor, Write};
+use pin_project_lite::pin_project;
+use std::future::Future;
+use std::io;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio::time::Sleep;
+use tracing::error;
+
+pin_project! {
+    #[derive(Debug)]
+    pub struct UnbufferedStream {
+        #[pin]
+        read_half: StreamReadHalf,
+        #[pin]
+        pending_write: Option<Pin<Box<Sleep>>>,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+        poll_cnt: usize,
+    }
+}
+
+impl UnbufferedStream {
+    fn new(
+        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
+        write_tx: mpsc::UnboundedSender<Vec<u8>>,
+    ) -> Self {
+        Self {
+            read_half: StreamReadHalf::new(read_rx),
+            write_tx,
+            pending_write: None,
+            poll_cnt: 0,
+        }
+    }
+
+    /// Create a new server stream and client pair.
+    /// Returns a server stream (Read+Write) and a client (rx/tx channels).
+    pub fn new_pair() -> (Self, fixture::Client) {
+        let (client_tx, server_rx) = mpsc::unbounded_channel();
+        let (server_tx, client_rx) = mpsc::unbounded_channel();
+        let server = Self::new(server_rx, server_tx);
+        let client = fixture::Client {
+            rx: client_rx,
+            tx: client_tx,
+        };
+        (server, client)
+    }
+}
+
+const WRITE_DELAY: Duration = Duration::from_millis(100);
+
+impl Read for UnbufferedStream {
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: ReadBufCursor<'_>,
+    ) -> Poll<io::Result<()>> {
+        let response = self.as_mut().project().read_half.poll_read(cx, buf);
+        response
+    }
+}
+
+impl Write for UnbufferedStream {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.poll_cnt += 1;
+        let poll_cnt = self.poll_cnt;
+        if let Some(sleep) = self.pending_write.as_mut() {
+            let sleep = sleep.as_mut();
+            if poll_cnt > 4 {
+                return Poll::Ready(Err(io::Error::other("We are being hot polled!")));
+            }
+            ready!(Future::poll(sleep, cx));
+            let mut this = self.as_mut().project();
+            this.pending_write.set(None);
+            *this.poll_cnt = 0;
+        }
+        let len = buf.len();
+        {
+            let mut this = self.as_mut().project();
+            let buf = Vec::from(&buf[..buf.len()]);
+            // Send data through the channel - this should always be ready for unbounded channels
+            let Ok(_) = this.write_tx.send(buf) else {
+                error!("UnbufferedStream::poll_write failed - channel closed");
+                return Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::BrokenPipe,
+                    "Write channel closed",
+                )));
+            };
+            this.pending_write
+                .set(Some(Box::pin(tokio::time::sleep(WRITE_DELAY))))
+        }
+        let Some(sleep) = self.pending_write.as_mut() else {
+            panic!("Sleep should have just been set");
+        };
+        let sleep = sleep.as_mut();
+        let Poll::Pending = Future::poll(sleep, cx) else {
+            panic!("Sleep should always be pending on first poll")
+        };
+        Poll::Ready(Ok(len))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn body_test() {
+    init_tracing();
+    let (server, client) = UnbufferedStream::new_pair();
+    let config = fixture::TestConfig::with_timeout(WRITE_DELAY * 2);
+    fixture::run(server, client, config).await;
+}

From a416aa8be05e36767830df2180f9aa78f6b412e7 Mon Sep 17 00:00:00 2001
From: Louis Thiery
Date: Tue, 2 Dec 2025 13:03:12 -0800
Subject: [PATCH 2/2] fix(http1): fix rare missed write wakeup on connections v2

---
 src/proto/h1/dispatch.rs | 40 ++++++++++++++++++++++++++++++++++------
 1 file changed, 34 insertions(+), 6 deletions(-)

diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs
index 5daeb5ebf6..91defd3234 100644
--- a/src/proto/h1/dispatch.rs
+++ b/src/proto/h1/dispatch.rs
@@ -170,8 +170,14 @@ where
         // benchmarks often use. Perhaps it should be a config option instead.
         for _ in 0..16 {
             let _ = self.poll_read(cx)?;
-            let _ = self.poll_write(cx)?;
-            let _ = self.poll_flush(cx)?;
+            let write_ready = self.poll_write(cx)?.is_ready();
+            let flush_ready = self.poll_flush(cx)?.is_ready();
+
+            // If we can write more body and the connection is ready, we should
+            // write again. If we return `Ready(Ok(()))` here, we will yield
+            // without a guaranteed wake-up from the write side of the connection.
+            // This would lead to a deadlock if we also don't expect reads.
+            let wants_write_again = self.can_write_again() && (write_ready || flush_ready);
 
             // This could happen if reading paused before blocking on IO,
             // such as getting to the end of a framed message, but then
@@ -181,14 +187,31 @@ where
             //
             // Using this instead of task::current() and notify() inside
             // the Conn is noticeably faster in pipelined benchmarks.
-            if !self.conn.wants_read_again() {
-                //break;
+            let wants_read_again = self.conn.wants_read_again();
+
+            // If we cannot write or read again, we yield and rely on the
+            // wake-up from the connection futures.
+            if !(wants_write_again || wants_read_again) {
                 return Poll::Ready(Ok(()));
             }
-        }
 
+            // If we are continuing only because `wants_write_again`, check if write is ready.
+            if !wants_read_again && wants_write_again {
+                // If write was ready, just proceed with the loop
+                if write_ready {
+                    continue;
+                }
+                // Write was previously pending, but may have become ready since polling flush, so
+                // we need to check it again. If we simply proceeded, the case of an unbuffered
+                // writer where flush is always ready would cause us to hot loop.
+                if self.poll_write(cx)?.is_pending() {
+                    // write is pending, so it is safe to yield and rely on wake-up from connection
+                    // futures.
+                    return Poll::Ready(Ok(()));
+                }
+            }
+        }
         trace!("poll_loop yielding (self = {:p})", self);
-
         task::yield_now(cx).map(|never| match never {})
     }
 
@@ -433,6 +456,11 @@ where
         self.conn.close_write();
     }
 
+    /// If there is pending data in body_rx, we can make progress writing if the connection is ready.
+    fn can_write_again(&mut self) -> bool {
+        self.body_rx.is_some()
+    }
+
     fn is_done(&self) -> bool {
         if self.is_closing {
             return true;