12 changes: 12 additions & 0 deletions Cargo.toml
@@ -52,6 +52,8 @@ pin-project-lite = "0.2.4"
spmc = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tracing = "0.1"
tracing-subscriber = "0.3"
tokio = { version = "1", features = [
"fs",
"macros",
@@ -239,6 +241,16 @@ name = "integration"
path = "tests/integration.rs"
required-features = ["full"]

[[test]]
name = "ready_on_poll_stream"
path = "tests/ready_on_poll_stream.rs"
required-features = ["full"]

[[test]]
name = "unbuffered_stream"
path = "tests/unbuffered_stream.rs"
required-features = ["full"]

[[test]]
name = "server"
path = "tests/server.rs"
40 changes: 34 additions & 6 deletions src/proto/h1/dispatch.rs
@@ -170,8 +170,14 @@ where
// benchmarks often use. Perhaps it should be a config option instead.
for _ in 0..16 {
let _ = self.poll_read(cx)?;
let _ = self.poll_write(cx)?;
let _ = self.poll_flush(cx)?;
let write_ready = self.poll_write(cx)?.is_ready();
let flush_ready = self.poll_flush(cx)?.is_ready();

// If we can write more body and the connection is ready, we should
// write again. If we return `Ready(Ok(()))` here, we will yield
// without a guaranteed wake-up from the write side of the connection.
// This would lead to a deadlock if we also don't expect reads.
let wants_write_again = self.can_write_again() && (write_ready || flush_ready);

// This could happen if reading paused before blocking on IO,
// such as getting to the end of a framed message, but then
@@ -181,14 +187,31 @@
//
// Using this instead of task::current() and notify() inside
// the Conn is noticeably faster in pipelined benchmarks.
if !self.conn.wants_read_again() {
//break;
let wants_read_again = self.conn.wants_read_again();

// If we cannot write or read again, we yield and rely on the
// wake-up from the connection futures.
if !(wants_write_again || wants_read_again) {
return Poll::Ready(Ok(()));
}
}

// If we are continuing only because `wants_write_again`, check if write is ready.
if !wants_read_again && wants_write_again {
// If write was ready, just proceed with the loop
if write_ready {
continue;
}
// Write was previously pending, but may have become ready since polling flush, so
// we need to check it again. If we simply proceeded, the case of an unbuffered
// writer where flush is always ready would cause us to hot loop.
if self.poll_write(cx)?.is_pending() {
// write is pending, so it is safe to yield and rely on wake-up from connection
// futures.
return Poll::Ready(Ok(()));
}
}
}
trace!("poll_loop yielding (self = {:p})", self);

task::yield_now(cx).map(|never| match never {})
}

@@ -433,6 +456,11 @@ where
self.conn.close_write();
}

/// If there is pending body data in `body_rx`, we can make write progress once the connection is ready.
fn can_write_again(&mut self) -> bool {
self.body_rx.is_some()
}

fn is_done(&self) -> bool {
if self.is_closing {
return true;
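For review context, here is a minimal standalone sketch of the per-iteration decision the new comments in this hunk describe. `LoopState` and `keep_looping` are illustrative stand-ins, not code from this PR:

/// Stand-in for what one iteration of the dispatch loop observed.
struct LoopState {
    write_ready: bool,      // poll_write returned Ready this iteration
    flush_ready: bool,      // poll_flush returned Ready this iteration
    has_queued_body: bool,  // body_rx.is_some(): more body frames may arrive
    wants_read_again: bool, // conn has another buffered request to read
}

/// Whether the loop should run another iteration instead of returning.
/// Mirrors the wants_write_again / wants_read_again logic in the diff.
fn keep_looping(s: &LoopState) -> bool {
    // Write progress is only plausible if body remains AND the connection
    // made some write-side progress this iteration.
    let wants_write_again = s.has_queued_body && (s.write_ready || s.flush_ready);

    // If neither side can progress, returning is safe: any pending poll has
    // already registered a waker that will re-schedule the task.
    wants_write_again || s.wants_read_again
}

fn main() {
    // Unbuffered writer: flush is trivially Ready while write stays Pending.
    let s = LoopState {
        write_ready: false,
        flush_ready: true,
        has_queued_body: true,
        wants_read_again: false,
    };
    // This predicate alone would keep spinning here, which is why the diff
    // re-polls write before committing to another iteration.
    assert!(keep_looping(&s));
}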
136 changes: 136 additions & 0 deletions tests/h1_server/fixture.rs
@@ -0,0 +1,136 @@
use http_body_util::StreamBody;
use hyper::body::Bytes;
use hyper::body::Frame;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Response, StatusCode};
use std::convert::Infallible;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::{error, info};

pub struct TestConfig {
pub total_chunks: usize,
pub chunk_size: usize,
pub chunk_timeout: Duration,
}

impl TestConfig {
pub fn with_timeout(chunk_timeout: Duration) -> Self {
Self {
total_chunks: 16,
chunk_size: 64 * 1024,
chunk_timeout,
}
}
}

pub struct Client {
pub rx: mpsc::UnboundedReceiver<Vec<u8>>,
pub tx: mpsc::UnboundedSender<Vec<u8>>,
}

pub async fn run<S>(server: S, mut client: Client, config: TestConfig)
where
S: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
{
let mut http_builder = http1::Builder::new();
http_builder.max_buf_size(config.chunk_size);

let total_chunks = config.total_chunks;
let chunk_size = config.chunk_size;

let service = service_fn(move |_| {
let total_chunks = total_chunks;
let chunk_size = chunk_size;
async move {
info!(
"Creating payload of {} chunks of {} KiB each ({} MiB total)...",
total_chunks,
chunk_size / 1024,
total_chunks * chunk_size / (1024 * 1024)
);
let bytes = Bytes::from(vec![0; chunk_size]);
let data = vec![bytes.clone(); total_chunks];
let stream = futures_util::stream::iter(
data.into_iter()
.map(|b| Ok::<_, Infallible>(Frame::data(b))),
);
let body = StreamBody::new(stream);
info!("Server: Sending data response...");
Ok::<_, hyper::Error>(
Response::builder()
.status(StatusCode::OK)
.header("content-type", "application/octet-stream")
.header("content-length", (total_chunks * chunk_size).to_string())
.body(body)
.unwrap(),
)
}
});

let server_task = tokio::spawn(async move {
let conn = http_builder.serve_connection(Box::pin(server), service);
let conn_result = conn.await;
if let Err(e) = &conn_result {
error!("Server connection error: {}", e);
}
conn_result
});

let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
client
.tx
.send(get_request.as_bytes().to_vec())
.expect("failed to send request");

info!("Client is reading response...");
let mut bytes_received = 0;
let mut all_data = Vec::new();
loop {
match timeout(config.chunk_timeout, client.rx.recv()).await {
Ok(Some(chunk)) => {
bytes_received += chunk.len();
all_data.extend_from_slice(&chunk);
}
Ok(None) => break,
Err(_) => {
panic!(
"Chunk timeout: chunk took longer than {:?}",
config.chunk_timeout
);
}
}
}

// Clean up
let result = server_task.await.unwrap();
result.unwrap();

// Parse HTTP response to find body start
// HTTP response format: "HTTP/1.1 200 OK\r\n...headers...\r\n\r\n<body>"
let body_start = all_data
.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|pos| pos + 4)
.unwrap_or(0);

let body_bytes = bytes_received - body_start;
assert_eq!(
body_bytes,
config.total_chunks * config.chunk_size,
"Expected {} body bytes, got {} (total received: {}, headers: {})",
config.total_chunks * config.chunk_size,
body_bytes,
bytes_received,
body_start
);
info!(bytes_received, body_bytes, "Client done receiving bytes");
}
97 changes: 97 additions & 0 deletions tests/h1_server/mod.rs
@@ -0,0 +1,97 @@
pub mod fixture;

use hyper::rt::{Read, ReadBufCursor};
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::sync::mpsc;

// Common read half shared by both stream types
pin_project! {
#[derive(Debug)]
pub struct StreamReadHalf {
#[pin]
read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
read_buffer: Vec<u8>,
}
}

impl StreamReadHalf {
pub fn new(read_rx: mpsc::UnboundedReceiver<Vec<u8>>) -> Self {
Self {
read_rx,
read_buffer: Vec::new(),
}
}
}

impl Read for StreamReadHalf {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
let mut this = self.as_mut().project();

// First, try to satisfy the read request from the internal buffer
if !this.read_buffer.is_empty() {
let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
// Copy data from internal buffer to the read buffer
buf.put_slice(&this.read_buffer[..to_read]);
// Remove the consumed data from the internal buffer
this.read_buffer.drain(..to_read);
return Poll::Ready(Ok(()));
}

// If internal buffer is empty, try to get data from the channel
match this.read_rx.as_mut().get_mut().try_recv() {
Ok(data) => {
// Copy as much data as we can fit in the buffer
let to_read = std::cmp::min(data.len(), buf.remaining());
buf.put_slice(&data[..to_read]);

// Store any remaining data in the internal buffer for next time
if to_read < data.len() {
let remaining = &data[to_read..];
this.read_buffer.extend_from_slice(remaining);
}
Poll::Ready(Ok(()))
}
Err(mpsc::error::TryRecvError::Empty) => {
match ready!(this.read_rx.poll_recv(cx)) {
Some(data) => {
// Copy as much data as we can fit in the buffer
let to_read = std::cmp::min(data.len(), buf.remaining());
buf.put_slice(&data[..to_read]);

// Store any remaining data in the internal buffer for next time
if to_read < data.len() {
let remaining = &data[to_read..];
this.read_buffer.extend_from_slice(remaining);
}
Poll::Ready(Ok(()))
}
// Sender dropped while we were waiting: treat as EOF
None => Poll::Ready(Ok(())),
}
}
Err(mpsc::error::TryRecvError::Disconnected) => {
// Channel closed, return EOF
Poll::Ready(Ok(()))
}
}
}
}

pub fn init_tracing() {
use std::sync::Once;
static INIT: Once = Once::new();
INIT.call_once(|| {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(true)
.with_thread_ids(true)
.with_thread_names(true)
.init();
});
}
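
The matching write half is not included in this diff (the comment above notes the read half is shared by both stream types). A minimal unbuffered counterpart could look like the sketch below; `StreamWriteHalf` and its constructor are assumed names, not the PR's code. Its trivially-Ready `poll_flush` is exactly the case the new dispatch comments call out, where write must be re-polled to avoid a hot loop:

use hyper::rt::Write;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;

#[derive(Debug)]
pub struct StreamWriteHalf {
    write_tx: mpsc::UnboundedSender<Vec<u8>>,
}

impl StreamWriteHalf {
    pub fn new(write_tx: mpsc::UnboundedSender<Vec<u8>>) -> Self {
        Self { write_tx }
    }
}

impl Write for StreamWriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Forward the bytes immediately; nothing is ever buffered here.
        match self.write_tx.send(buf.to_vec()) {
            Ok(()) => Poll::Ready(Ok(buf.len())),
            Err(_) => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "read side of the test channel was dropped",
            ))),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // No internal buffer, so a flush always completes immediately.
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}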