Commit fd60a1c

test(http1): setup test fixture for dispatch loop
1 parent 1c70fab commit fd60a1c

File tree

5 files changed: 495 additions, 0 deletions


Cargo.toml

Lines changed: 12 additions & 0 deletions
@@ -52,6 +52,8 @@ pin-project-lite = "0.2.4"
 spmc = "0.3"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
+tracing = "0.1"
+tracing-subscriber = "0.3"
 tokio = { version = "1", features = [
     "fs",
     "macros",
@@ -239,6 +241,16 @@ name = "integration"
 path = "tests/integration.rs"
 required-features = ["full"]
 
+[[test]]
+name = "ready_on_poll_stream"
+path = "tests/ready_on_poll_stream.rs"
+required-features = ["full"]
+
+[[test]]
+name = "unbuffered_stream"
+path = "tests/unbuffered_stream.rs"
+required-features = ["full"]
+
 [[test]]
 name = "server"
 path = "tests/server.rs"

tests/h1_server/fixture.rs

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
use http_body_util::StreamBody;
use hyper::body::Bytes;
use hyper::body::Frame;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Response, StatusCode};
use std::convert::Infallible;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::{error, info};

pub struct TestConfig {
    pub total_chunks: usize,
    pub chunk_size: usize,
    pub chunk_timeout: Duration,
}

impl TestConfig {
    pub fn with_timeout(chunk_timeout: Duration) -> Self {
        Self {
            total_chunks: 16,
            chunk_size: 64 * 1024,
            chunk_timeout,
        }
    }
}

pub struct Client {
    pub rx: mpsc::UnboundedReceiver<Vec<u8>>,
    pub tx: mpsc::UnboundedSender<Vec<u8>>,
}

pub async fn run<S>(server: S, mut client: Client, config: TestConfig)
where
    S: hyper::rt::Read + hyper::rt::Write + Send + Unpin + 'static,
{
    let mut http_builder = http1::Builder::new();
    http_builder.max_buf_size(config.chunk_size);

    let total_chunks = config.total_chunks;
    let chunk_size = config.chunk_size;

    let service = service_fn(move |_| {
        let total_chunks = total_chunks;
        let chunk_size = chunk_size;
        async move {
            info!(
                "Creating payload of {} chunks of {} KiB each ({} MiB total)...",
                total_chunks,
                chunk_size / 1024,
                total_chunks * chunk_size / (1024 * 1024)
            );
            let bytes = Bytes::from(vec![0; chunk_size]);
            let data = vec![bytes.clone(); total_chunks];
            let stream = futures_util::stream::iter(
                data.into_iter()
                    .map(|b| Ok::<_, Infallible>(Frame::data(b))),
            );
            let body = StreamBody::new(stream);
            info!("Server: Sending data response...");
            Ok::<_, hyper::Error>(
                Response::builder()
                    .status(StatusCode::OK)
                    .header("content-type", "application/octet-stream")
                    .header("content-length", (total_chunks * chunk_size).to_string())
                    .body(body)
                    .unwrap(),
            )
        }
    });

    let server_task = tokio::spawn(async move {
        let conn = http_builder.serve_connection(Box::pin(server), service);
        let conn_result = conn.await;
        if let Err(e) = &conn_result {
            error!("Server connection error: {}", e);
        }
        conn_result
    });

    let get_request = "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
    client.tx.send(get_request.as_bytes().to_vec())
        .map_err(|e| Box::new(std::io::Error::new(std::io::ErrorKind::Other, format!("Failed to send request: {}", e))))
        .unwrap();

    info!("Client is reading response...");
    let mut bytes_received = 0;
    let mut all_data = Vec::new();
    loop {
        match timeout(config.chunk_timeout, client.rx.recv()).await {
            Ok(Some(chunk)) => {
                bytes_received += chunk.len();
                all_data.extend_from_slice(&chunk);
            }
            Ok(None) => break,
            Err(_) => {
                panic!("Chunk timeout: chunk took longer than {:?}", config.chunk_timeout);
            }
        }
    }

    // Clean up
    let result = server_task.await.unwrap();
    result.unwrap();

    // Parse HTTP response to find body start
    // HTTP response format: "HTTP/1.1 200 OK\r\n...headers...\r\n\r\n<body>"
    let body_start = all_data.windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|pos| pos + 4)
        .unwrap_or(0);

    let body_bytes = bytes_received - body_start;
    assert_eq!(body_bytes, config.total_chunks * config.chunk_size,
        "Expected {} body bytes, got {} (total received: {}, headers: {})",
        config.total_chunks * config.chunk_size, body_bytes, bytes_received, body_start);
    info!(bytes_received, body_bytes, "Client done receiving bytes");
}

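The fixture accepts any transport that implements hyper's Read and Write traits, plus the channel-backed Client handle. A minimal sketch of how a test target is expected to wire these together; SomeTestStream is a hypothetical stand-in (the concrete stream types appear in the new test files below):

// Sketch only: SomeTestStream is a placeholder for a type implementing
// hyper::rt::Read + hyper::rt::Write, such as ReadyOnPollStream below.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn sketch_body_test() {
    h1_server::init_tracing();
    let (server, client) = SomeTestStream::new_pair();
    // Allow up to 200 ms per received chunk before the fixture panics.
    let config = h1_server::fixture::TestConfig::with_timeout(std::time::Duration::from_millis(200));
    h1_server::fixture::run(server, client, config).await;
}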
tests/h1_server/mod.rs

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
pub mod fixture;

use hyper::rt::{Read, ReadBufCursor};
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::sync::mpsc;

// Common read half shared by both stream types
pin_project! {
    #[derive(Debug)]
    pub struct StreamReadHalf {
        #[pin]
        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
        read_buffer: Vec<u8>,
    }
}

impl StreamReadHalf {
    pub fn new(read_rx: mpsc::UnboundedReceiver<Vec<u8>>) -> Self {
        Self {
            read_rx,
            read_buffer: Vec::new(),
        }
    }
}

impl Read for StreamReadHalf {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        let mut this = self.as_mut().project();

        // First, try to satisfy the read request from the internal buffer
        if !this.read_buffer.is_empty() {
            let to_read = std::cmp::min(this.read_buffer.len(), buf.remaining());
            // Copy data from internal buffer to the read buffer
            buf.put_slice(&this.read_buffer[..to_read]);
            // Remove the consumed data from the internal buffer
            this.read_buffer.drain(..to_read);
            return Poll::Ready(Ok(()));
        }

        // If internal buffer is empty, try to get data from the channel
        match this.read_rx.as_mut().get_mut().try_recv() {
            Ok(data) => {
                // Copy as much data as we can fit in the buffer
                let to_read = std::cmp::min(data.len(), buf.remaining());
                buf.put_slice(&data[..to_read]);

                // Store any remaining data in the internal buffer for next time
                if to_read < data.len() {
                    let remaining = &data[to_read..];
                    this.read_buffer.extend_from_slice(remaining);
                }
                Poll::Ready(Ok(()))
            }
            Err(mpsc::error::TryRecvError::Empty) => {
                match ready!(this.read_rx.poll_recv(cx)) {
                    Some(data) => {
                        // Copy as much data as we can fit in the buffer
                        let to_read = std::cmp::min(data.len(), buf.remaining());
                        buf.put_slice(&data[..to_read]);

                        // Store any remaining data in the internal buffer for next time
                        if to_read < data.len() {
                            let remaining = &data[to_read..];
                            this.read_buffer.extend_from_slice(remaining);
                        }
                        Poll::Ready(Ok(()))
                    }
                    None => Poll::Ready(Ok(())),
                }
            }
            Err(mpsc::error::TryRecvError::Disconnected) => {
                // Channel closed, return EOF
                Poll::Ready(Ok(()))
            }
        }
    }
}

pub fn init_tracing() {
    use std::sync::Once;
    static INIT: Once = Once::new();
    INIT.call_once(|| {
        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::INFO)
            .with_target(true)
            .with_thread_ids(true)
            .with_thread_names(true)
            .init();
    });
}

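For reference, the duplex wiring that the test streams build on top of StreamReadHalf looks roughly like this; a sketch mirroring the new_pair constructors in the test files (sketch_pair itself is not part of the commit):

// Sketch: two unbounded channels form an in-memory duplex pair. The server side
// reads what the client sends via StreamReadHalf and writes responses into
// server_tx; the client keeps the opposite ends as fixture::Client.
fn sketch_pair() -> (StreamReadHalf, tokio::sync::mpsc::UnboundedSender<Vec<u8>>, fixture::Client) {
    let (client_tx, server_rx) = tokio::sync::mpsc::unbounded_channel();
    let (server_tx, client_rx) = tokio::sync::mpsc::unbounded_channel();
    let read_half = StreamReadHalf::new(server_rx);
    let client = fixture::Client { rx: client_rx, tx: client_tx };
    (read_half, server_tx, client)
}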
tests/ready_on_poll_stream.rs

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
#[path = "h1_server/mod.rs"]
mod h1_server;

use std::future::Future;
use h1_server::{init_tracing, fixture, StreamReadHalf};
use hyper::rt::{Read, ReadBufCursor, Write};
use pin_project_lite::pin_project;
use std::io;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::Sleep;
use tracing::error;

pin_project! {
    #[derive(Debug)]
    pub struct ReadyOnPollStream {
        #[pin]
        read_half: StreamReadHalf,
        write_tx: mpsc::UnboundedSender<Vec<u8>>,
        #[pin]
        pending_write: Option<Pin<Box<Sleep>>>,
        poll_since_write: bool,
        flush_count: usize,
    }
}

impl ReadyOnPollStream {
    fn new(
        read_rx: mpsc::UnboundedReceiver<Vec<u8>>,
        write_tx: mpsc::UnboundedSender<Vec<u8>>,
    ) -> Self {
        Self {
            read_half: StreamReadHalf::new(read_rx),
            write_tx,
            poll_since_write: true,
            flush_count: 0,
            pending_write: None,
        }
    }

    /// Create a new server stream and client pair.
    /// Returns a server stream (Read+Write) and a client (rx/tx channels).
    pub fn new_pair() -> (Self, fixture::Client) {
        let (client_tx, server_rx) = mpsc::unbounded_channel();
        let (server_tx, client_rx) = mpsc::unbounded_channel();
        let server = Self::new(server_rx, server_tx);
        let client = fixture::Client {
            rx: client_rx,
            tx: client_tx,
        };
        (server, client)
    }
}

impl Read for ReadyOnPollStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: ReadBufCursor<'_>,
    ) -> Poll<io::Result<()>> {
        self.as_mut().project().read_half.poll_read(cx, buf)
    }
}

const WRITE_DELAY: Duration = Duration::from_millis(100);

impl Write for ReadyOnPollStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(sleep) = self.pending_write.as_mut() {
            let sleep = sleep.as_mut();
            ready!(Future::poll(sleep, cx));
        }
        {
            let mut this = self.as_mut().project();
            this.pending_write.set(Some(Box::pin(tokio::time::sleep(WRITE_DELAY))));
        }
        let Some(sleep) = self.pending_write.as_mut() else {
            panic!("Sleep should have just been set");
        };
        // Poll the future so that we can be woken later
        let sleep = sleep.as_mut();
        let Poll::Pending = Future::poll(sleep, cx) else {
            panic!("Sleep should always be pending on first poll")
        };

        let this = self.project();
        let buf = Vec::from(&buf[..buf.len()]);
        let len = buf.len();

        // Send data through the channel - this should always be ready for unbounded channels
        match this.write_tx.send(buf) {
            Ok(_) => Poll::Ready(Ok(len)),
            Err(_) => {
                error!("ReadyStream::poll_write failed - channel closed");
                Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "Write channel closed",
                )))
            }
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.flush_count += 1;
        // We require two flushes to complete each chunk, simulating a success at the end of the old
        // poll loop. After all chunks are written, we always succeed on flush to allow for finish.
        const TOTAL_CHUNKS: usize = 16;
        if self.flush_count % 2 != 0 && self.flush_count < TOTAL_CHUNKS * 2 {
            if let Some(sleep) = self.pending_write.as_mut() {
                let sleep = sleep.as_mut();
                ready!(Future::poll(sleep, cx));
            } else {
                return Poll::Pending;
            }
        }
        let mut this = self.as_mut().project();
        this.pending_write.set(None);
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn body_test() {
    init_tracing();
    let (server, client) = ReadyOnPollStream::new_pair();
    let config = fixture::TestConfig::with_timeout(WRITE_DELAY * 2);
    fixture::run(server, client, config).await;
}
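
The test's per-chunk timeout is derived from the artificial write delay. Spelled out explicitly, the configuration is equivalent to the following sketch (TestConfig::with_timeout fills in the same 16-chunk, 64 KiB defaults; explicit_config is illustrative only):

// Sketch: an explicit equivalent of TestConfig::with_timeout(WRITE_DELAY * 2).
fn explicit_config() -> fixture::TestConfig {
    fixture::TestConfig {
        total_chunks: 16,               // matches the TOTAL_CHUNKS constant assumed in poll_flush
        chunk_size: 64 * 1024,          // 64 KiB per chunk, 1 MiB body in total
        chunk_timeout: WRITE_DELAY * 2, // each chunk must reach the client within 200 ms
    }
}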
