From bb88bc84a9f089201ed996d0b758528389f28ff5 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 5 Jan 2023 12:27:34 -0800 Subject: [PATCH 1/8] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- Cargo.lock | 10 ++++++++++ dropshot/Cargo.toml | 1 + dropshot/src/logging.rs | 9 ++++++--- dropshot/src/test_util.rs | 21 +++++++++++---------- dropshot/tests/common/mod.rs | 2 +- dropshot/tests/fail/bad_endpoint4.stderr | 2 +- dropshot/tests/fail/bad_endpoint5.stderr | 2 +- dropshot/tests/test_pagination.rs | 2 +- 8 files changed, 32 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c9f6c4699..489ab0aad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,6 +132,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" +[[package]] +name = "camino" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ad0e1e3e88dd237a156ab9f571021b8a158caa0ae44b1968a241efb5144c1e" +dependencies = [ + "serde", +] + [[package]] name = "cc" version = "1.0.72" @@ -286,6 +295,7 @@ dependencies = [ "async-trait", "base64 0.20.0", "bytes", + "camino", "chrono", "dropshot_endpoint", "expectorate", diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 0028e4d66..18f2433bd 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -15,6 +15,7 @@ async-stream = "0.3.3" async-trait = "0.1.60" base64 = "0.20.0" bytes = "1" +camino = { version = "1.1.1", features = ["serde1"] } futures = "0.3.25" hostname = "0.3.0" http = "0.2.8" diff --git a/dropshot/src/logging.rs b/dropshot/src/logging.rs index 2ce9b4d47..8a5d50326 100644 --- a/dropshot/src/logging.rs +++ b/dropshot/src/logging.rs @@ -5,12 +5,14 @@ * they're provided because they're commonly wanted by consumers of this crate. */ +use camino::Utf8PathBuf; use serde::Deserialize; use serde::Serialize; use slog::Drain; use slog::Level; use slog::Logger; use std::fs::OpenOptions; +use std::io::LineWriter; use std::{io, path::Path}; /** @@ -25,7 +27,7 @@ pub enum ConfigLogging { /** Bunyan-formatted output to a specified file. */ File { level: ConfigLoggingLevel, - path: String, + path: Utf8PathBuf, if_exists: ConfigLoggingIfExists, }, } @@ -136,12 +138,13 @@ fn log_drain_for_file( open_options: &OpenOptions, path: &Path, log_name: String, -) -> Result>, io::Error> { +) -> Result>>, io::Error> { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent)?; } - let file = open_options.open(path)?; + // Buffer writes to the file around newlines to minimize syscalls. + let file = LineWriter::new(open_options.open(path)?); /* * Record a message to the stderr so that a reader who doesn't already know diff --git a/dropshot/src/test_util.rs b/dropshot/src/test_util.rs index 7271bae77..36b652e82 100644 --- a/dropshot/src/test_util.rs +++ b/dropshot/src/test_util.rs @@ -4,6 +4,7 @@ * and dependents of this crate. */ +use camino::Utf8PathBuf; use chrono::DateTime; use chrono::Utc; use http::method::Method; @@ -19,12 +20,12 @@ use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; use slog::Logger; +use std::convert::TryFrom; use std::fmt::Debug; use std::fs; use std::iter::Iterator; use std::net::SocketAddr; use std::path::Path; -use std::path::PathBuf; use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering; @@ -426,7 +427,7 @@ impl ClientTestContext { pub struct LogContext { /** general-purpose logger */ pub log: Logger, - log_path: Option, + log_path: Option, } impl LogContext { @@ -459,13 +460,12 @@ impl LogContext { each test." ); let new_path = log_file_for_test(test_name); - let new_path_str = new_path.as_path().display().to_string(); - eprintln!("log file: {:?}", new_path_str); + eprintln!("log file: {}", new_path); ( - Some(new_path), + Some(new_path.clone()), ConfigLogging::File { level: level.clone(), - path: new_path_str, + path: new_path, if_exists: if_exists.clone(), }, ) @@ -764,13 +764,14 @@ static TEST_SUITE_LOGGER_ID: AtomicU32 = AtomicU32::new(0); * Returns a unique path name in a temporary directory that includes the given * `test_name`. */ -pub fn log_file_for_test(test_name: &str) -> PathBuf { +pub fn log_file_for_test(test_name: &str) -> Utf8PathBuf { let arg0 = { - let arg0path = std::env::args().next().unwrap(); - Path::new(&arg0path).file_name().unwrap().to_str().unwrap().to_string() + let arg0path = Utf8PathBuf::from(std::env::args().next().unwrap()); + arg0path.file_name().unwrap().to_owned() }; - let mut pathbuf = std::env::temp_dir(); + let mut pathbuf = Utf8PathBuf::try_from(std::env::temp_dir()) + .expect("temp dir is valid UTF-8"); let id = TEST_SUITE_LOGGER_ID.fetch_add(1, Ordering::SeqCst); let pid = std::process::id(); pathbuf.push(format!("{}-{}.{}.{}.log", arg0, test_name, pid, id)); diff --git a/dropshot/tests/common/mod.rs b/dropshot/tests/common/mod.rs index 93e8951d8..b74332b1e 100644 --- a/dropshot/tests/common/mod.rs +++ b/dropshot/tests/common/mod.rs @@ -37,7 +37,7 @@ pub fn test_setup( pub fn create_log_context(test_name: &str) -> LogContext { let log_config = ConfigLogging::File { level: ConfigLoggingLevel::Debug, - path: "UNUSED".to_string(), + path: "UNUSED".into(), if_exists: ConfigLoggingIfExists::Fail, }; LogContext::new(test_name, &log_config) diff --git a/dropshot/tests/fail/bad_endpoint4.stderr b/dropshot/tests/fail/bad_endpoint4.stderr index d497f4030..d99a3626b 100644 --- a/dropshot/tests/fail/bad_endpoint4.stderr +++ b/dropshot/tests/fail/bad_endpoint4.stderr @@ -28,13 +28,13 @@ error[E0277]: the trait bound `for<'de> QueryParams: serde::de::Deserialize<'de> | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] + &'a camino::Utf8Path &'a std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) - (T0, T1, T2, T3, T4) and $N others = note: required for `QueryParams` to implement `serde::de::DeserializeOwned` note: required by a bound in `dropshot::Query` diff --git a/dropshot/tests/fail/bad_endpoint5.stderr b/dropshot/tests/fail/bad_endpoint5.stderr index 044b61537..799f4e82d 100644 --- a/dropshot/tests/fail/bad_endpoint5.stderr +++ b/dropshot/tests/fail/bad_endpoint5.stderr @@ -6,13 +6,13 @@ error[E0277]: the trait bound `for<'de> QueryParams: serde::de::Deserialize<'de> | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] + &'a camino::Utf8Path &'a std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) - (T0, T1, T2, T3, T4) and $N others = note: required for `QueryParams` to implement `serde::de::DeserializeOwned` note: required by a bound in `dropshot::Query` diff --git a/dropshot/tests/test_pagination.rs b/dropshot/tests/test_pagination.rs index 841708a71..83c9379fe 100644 --- a/dropshot/tests/test_pagination.rs +++ b/dropshot/tests/test_pagination.rs @@ -873,7 +873,7 @@ async fn start_example(path: &str, port: u16) -> ExampleContext { path, &ConfigLogging::File { level: ConfigLoggingLevel::Info, - path: "UNUSED".to_string(), + path: "UNUSED".into(), if_exists: ConfigLoggingIfExists::Fail, }, ); From cc65ece949721bd06857b2007b194218c48beb54 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 5 Jan 2023 15:19:39 -0800 Subject: [PATCH 2/8] Use async-stream rather than tokio-stream Created using spr 1.3.4 --- Cargo.lock | 12 ------------ dropshot/Cargo.toml | 1 - dropshot/src/handler.rs | 14 ++++---------- 3 files changed, 4 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 451f7aff0..a32a89095 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -341,7 +341,6 @@ dependencies = [ "tempfile", "tokio", "tokio-rustls", - "tokio-stream", "tokio-tungstenite", "toml", "trybuild", @@ -1668,17 +1667,6 @@ dependencies = [ "webpki", ] -[[package]] -name = "tokio-stream" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - [[package]] name = "tokio-tungstenite" version = "0.18.0" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 15a79b35e..6504b6148 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -35,7 +35,6 @@ slog-bunyan = "2.4.0" slog-json = "2.6.1" slog-term = "2.9.0" tokio-rustls = "0.23.4" -tokio-stream = "0.1.11" toml = "0.5.10" [dependencies.chrono] diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index d60ecba03..273059e09 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -53,6 +53,7 @@ use crate::router::VariableSet; use crate::to_map::to_map; use crate::websocket::WEBSOCKET_PARAM_SENTINEL; +use async_stream::try_stream; use async_trait::async_trait; use buf_list::BufList; use bytes::Bytes; @@ -79,7 +80,6 @@ use std::future::Future; use std::marker::PhantomData; use std::num::NonZeroU32; use std::sync::Arc; -use tokio_stream::wrappers::ReceiverStream; /** * Type alias for the result returned by HTTP handler functions. @@ -1148,19 +1148,13 @@ impl UntypedBody { self, ) -> impl Stream> + Send + Sync + 'static { - let (sender, receiver) = tokio::sync::mpsc::channel(8); - tokio::spawn(async move { + try_stream! { let mut request = self.request.lock().await; let body = request.body_mut(); while let Some(data) = body.data().await { - if let Err(_) = sender.send(data.map_err(Into::into)).await { - // The receiver was dropped -- drop the stream. - break; - } + yield data?; } - }); - - ReceiverStream::new(receiver) + } } } From a654e72f30652ccf7ec4cf03d293f4b311162535 Mon Sep 17 00:00:00 2001 From: Rain Date: Thu, 5 Jan 2023 22:16:31 -0800 Subject: [PATCH 3/8] Read trailers Created using spr 1.3.4 --- dropshot/src/handler.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 273059e09..03904f0f6 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -1154,6 +1154,9 @@ impl UntypedBody { while let Some(data) = body.data().await { yield data?; } + // Read the trailers even though we aren't going to do anything with + // them. + body.trailers().await?; } } } @@ -1194,6 +1197,12 @@ impl Extractor for UntypedBody { } } +struct CappedStream { + inner: St, + max_bytes: usize, + current_bytes: usize, +} + /* * Response Type Conversion * From e953952f8c1ea2af254399bbe3607a7cc0aede60 Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 6 Jan 2023 09:56:11 -0800 Subject: [PATCH 4/8] Simplify implementation a little Created using spr 1.3.4 --- Cargo.lock | 5 +- dropshot/Cargo.toml | 1 + dropshot/src/handler.rs | 204 ++++++++++++++++++++++++++++++------ dropshot/src/http_util.rs | 166 ----------------------------- dropshot/src/lib.rs | 2 + dropshot/tests/test_demo.rs | 2 +- 6 files changed, 178 insertions(+), 202 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a32a89095..253c7a55d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -323,6 +323,7 @@ dependencies = [ "paste", "pem", "percent-encoding", + "pin-project-lite", "proc-macro2", "rcgen", "rustls", @@ -1001,9 +1002,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" [[package]] name = "pin-utils" diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 6504b6148..ef56076a0 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -23,6 +23,7 @@ http = "0.2.8" indexmap = "1.9.2" paste = "1.0.11" percent-encoding = "2.2.0" +pin-project-lite = "0.2.9" proc-macro2 = "1.0.49" rustls = "0.20.7" rustls-pemfile = "1.0.1" diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 03904f0f6..4d3810884 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -35,7 +35,6 @@ use super::error::HttpError; use super::http_util::http_extract_path_params; -use super::http_util::http_read_body_bytes; use super::http_util::CONTENT_TYPE_JSON; use super::http_util::CONTENT_TYPE_OCTET_STREAM; use super::server::DropshotState; @@ -46,25 +45,30 @@ use crate::api_description::ApiEndpointParameterLocation; use crate::api_description::ApiEndpointResponse; use crate::api_description::ApiSchemaGenerator; use crate::api_description::{ApiEndpointBodyContentType, ExtensionMode}; -use crate::http_util::http_read_body_buf_list; use crate::pagination::PaginationParams; use crate::pagination::PAGINATION_PARAM_SENTINEL; use crate::router::VariableSet; use crate::to_map::to_map; use crate::websocket::WEBSOCKET_PARAM_SENTINEL; -use async_stream::try_stream; +use async_stream::stream; use async_trait::async_trait; use buf_list::BufList; +use bytes::BufMut; use bytes::Bytes; +use bytes::BytesMut; use futures::lock::Mutex; +use futures::ready; +use futures::stream::BoxStream; use futures::Stream; +use futures::StreamExt; use http::HeaderMap; use http::StatusCode; use hyper::body::HttpBody; use hyper::Body; use hyper::Request; use hyper::Response; +use pin_project_lite::pin_project; use schemars::schema::InstanceType; use schemars::schema::SchemaObject; use schemars::JsonSchema; @@ -79,7 +83,9 @@ use std::fmt::Result as FmtResult; use std::future::Future; use std::marker::PhantomData; use std::num::NonZeroU32; +use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; /** * Type alias for the result returned by HTTP handler functions. @@ -976,13 +982,15 @@ where BodyType: JsonSchema + DeserializeOwned + Send + Sync, { let server = &rqctx.server; - let mut request = rqctx.request.lock().await; - let body = http_read_body_bytes( - request.body_mut(), - server.config.request_body_max_bytes, - ) + let body = UntypedBody { + request: rqctx.request.clone(), + max_bytes: server.config.request_body_max_bytes, + } + .into_bytes() .await?; + let request = rqctx.request.lock().await; + // RFC 7231 ยง3.1.1.1: media types are case insensitive and may // be followed by whitespace and/or a parameter (e.g., charset), // which we currently ignore. @@ -1093,8 +1101,14 @@ impl UntypedBody { /// /// Errors if the request body is too large, or if the pub async fn into_bytes(self) -> Result { - let mut request = self.request.lock().await; - http_read_body_bytes(request.body_mut(), self.max_bytes).await + let mut stream = self.into_stream(); + let mut bytes = BytesMut::new(); + + while let Some(data) = stream.next().await { + bytes.put(data?); + } + + Ok(bytes.freeze()) } /// Reads the body into a `String`. @@ -1122,7 +1136,7 @@ impl UntypedBody { /// Recommended for larger request bodies. pub async fn into_buf_list(self) -> Result { let max_bytes = self.max_bytes; - self.into_buf_list_with_limit(max_bytes).await + self.into_buf_list_with_cap(max_bytes).await } /// Reads the body into a [`BufList`] with a custom limit for the maximum @@ -1132,32 +1146,50 @@ impl UntypedBody { /// a custom limit is specified. If this method is called, the default /// [`request_body_max_bytes`](ServerConfig::request_body_max_bytes) limit /// is ignored. - pub async fn into_buf_list_with_limit( + pub async fn into_buf_list_with_cap( self, max_bytes: usize, ) -> Result { - let mut request = self.request.lock().await; - http_read_body_buf_list(request.body_mut(), max_bytes).await + let mut stream = + self.into_stream().into_uncapped().into_capped(max_bytes); + let mut buf_list = BufList::new(); + + while let Some(data) = stream.next().await { + buf_list.push_chunk(data?); + } + + Ok(buf_list) } - /// Converts `self` into a [`Stream`] of [`Bytes`] chunks. - /// - /// This method ignores the - /// [`request_body_max_bytes`](ServerConfig::request_body_max_bytes) limit. - pub fn into_stream( - self, - ) -> impl Stream> + Send + Sync + 'static - { - try_stream! { + /// Converts `self` into a [`Stream`] of `Result` chunks. + pub fn into_stream(self) -> CappedBodyStream { + let max_bytes = self.max_bytes; + + let stream = stream! { let mut request = self.request.lock().await; let body = request.body_mut(); - while let Some(data) = body.data().await { - yield data?; + + 'outer: { + while let Some(data) = body.data().await { + match data { + Ok(data) => yield Ok(data), + Err(e) => { + yield Err(HttpError::from(e)); + break 'outer; + } + } + } + + // Read the trailers even though we aren't going to do anything + // with them. + if let Err(e) = body.trailers().await { + yield Err(HttpError::from(e)); + } } - // Read the trailers even though we aren't going to do anything with - // them. - body.trailers().await?; - } + }; + + let uncapped = UncappedBodyStream::new(stream); + CappedBodyStream::new(uncapped, max_bytes) } } @@ -1197,10 +1229,116 @@ impl Extractor for UntypedBody { } } -struct CappedStream { - inner: St, - max_bytes: usize, - current_bytes: usize, +pin_project! { + /// A stream over an HTTP body that sets a limit on the number of bytes that + /// can be read from it. + /// + /// To change the cap read from it, use + /// [`into_uncapped`](Self::into_uncapped), then apply a new cap with + /// `UncappedStream::into_capped`. + pub struct CappedBodyStream { + #[pin] + stream: UncappedBodyStream, + max_bytes: usize, + current_bytes: usize, + } +} + +impl CappedBodyStream { + pub(crate) fn new(stream: UncappedBodyStream, max_bytes: usize) -> Self { + Self { stream, max_bytes, current_bytes: 0 } + } + + /// Returns the maximum number of bytes that can be read from this stream. + pub fn max_bytes(&self) -> usize { + self.max_bytes + } + + /// Returns the current number of bytes read from the stream. + pub fn current_bytes(&self) -> usize { + self.current_bytes + } + + /// Turns this stream into an uncapped one. + pub fn into_uncapped(self) -> UncappedBodyStream { + self.stream + } +} + +impl Stream for CappedBodyStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.as_mut().project(); + let result = ready!(this.stream.poll_next(cx)); + let bytes = match result { + Some(Ok(bytes)) => bytes, + x @ None | x @ Some(Err(_)) => return Poll::Ready(x), + }; + + let max_bytes = *this.max_bytes; + + let is_too_large = + Arc::new(this.current_bytes.checked_add(bytes.len())) + .map_or(true, |x| x > max_bytes); + if is_too_large { + // The request was too large. Drain the rest of the stream. + while let Some(data) = + futures::ready!(self.as_mut().project().stream.poll_next(cx)) + { + if let Err(e) = data { + return Poll::Ready(Some(Err(e))); + } + } + return Poll::Ready(Some(Err(HttpError::for_bad_request( + None, + format!( + "request body exceeded maximum size of {} bytes", + max_bytes + ), + )))); + } + + *this.current_bytes += bytes.len(); + + Poll::Ready(Some(Ok(bytes))) + } +} + +pin_project! { + /// A stream over an HTTP request body that does not cap the bytes. + pub struct UncappedBodyStream { + // TODO: replace with a concrete type once TAIT is stabilized. + #[pin] + stream: BoxStream<'static, Result>, + } +} + +impl UncappedBodyStream { + pub(crate) fn new( + stream: impl Stream> + Send + 'static, + ) -> Self { + Self { stream: stream.boxed() } + } + + /// Adds a cap to this stream. + pub fn into_capped(self, cap: usize) -> CappedBodyStream { + CappedBodyStream::new(self, cap) + } +} + +impl Stream for UncappedBodyStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().stream.poll_next(cx) + } } /* diff --git a/dropshot/src/http_util.rs b/dropshot/src/http_util.rs index dcb75ab92..d50885b81 100644 --- a/dropshot/src/http_util.rs +++ b/dropshot/src/http_util.rs @@ -3,11 +3,6 @@ * General-purpose HTTP-related facilities */ -use buf_list::BufList; -use bytes::BufMut; -use bytes::Bytes; -use bytes::BytesMut; -use hyper::body::HttpBody; use serde::de::DeserializeOwned; use super::error::HttpError; @@ -25,167 +20,6 @@ pub const CONTENT_TYPE_NDJSON: &str = "application/x-ndjson"; /** MIME type for form/urlencoded data */ pub const CONTENT_TYPE_URL_ENCODED: &str = "application/x-www-form-urlencoded"; -/// Reads the rest of the body from the request up to the given number of bytes, as a `Bytes`. -/// -/// This is intended for smaller bodies (e.g. a few kilobytes). -/// -/// # Errors -/// -/// Errors if the body length exceeds the given cap. -pub async fn http_read_body_bytes( - body: &mut T, - cap: usize, -) -> Result -where - T: HttpBody + std::marker::Unpin, -{ - http_read_body::(body, cap).await -} - -/// Reads the rest of the body from the request up to the given number of bytes, as a `BufList`. -/// -/// This is intended for larger bodies (e.g. a megabyte or larger). -/// -/// # Errors -/// -/// Errors if the body length exceeds the given cap. -pub async fn http_read_body_buf_list( - body: &mut T, - cap: usize, -) -> Result -where - T: HttpBody + std::marker::Unpin, -{ - http_read_body::(body, cap).await -} - -trait BufListLike { - type Output; - fn new() -> Self; - fn push_chunk(&mut self, chunk: Bytes); - fn finish(self) -> Self::Output; -} - -impl BufListLike for BytesMut { - type Output = Bytes; - - fn new() -> Self { - BytesMut::new() - } - - fn push_chunk(&mut self, chunk: Bytes) { - self.put(chunk); - } - - fn finish(self) -> Self::Output { - self.freeze() - } -} - -impl BufListLike for BufList { - type Output = BufList; - - fn new() -> Self { - BufList::new() - } - - fn push_chunk(&mut self, chunk: Bytes) { - self.push_chunk(chunk); - } - - fn finish(self) -> Self::Output { - self - } -} - -/** - * Reads the rest of the body from the request up to the given number of bytes. - * If the body fits within the specified cap, a buffer is returned with all the - * bytes read. If not, an error is returned. - */ -async fn http_read_body( - body: &mut T, - cap: usize, -) -> Result -where - T: HttpBody + std::marker::Unpin, - B: BufListLike, -{ - /* - * This looks a lot like the implementation of hyper::body::to_bytes(), but - * applies the requested cap. We've skipped the optimization for the - * 1-buffer case for now, as it seems likely this implementation will change - * anyway. - * TODO should this use some Stream interface instead? - * TODO why does this look so different in type signature (Data=Bytes, - * std::marker::Unpin, &mut T) - * TODO Error type shouldn't have to be hyper Error -- Into should - * work too? - * TODO do we need to use saturating_add() here? - */ - let mut parts = B::new(); - let mut nbytesread: usize = 0; - while let Some(maybebuf) = body.data().await { - let buf = maybebuf?; - let bufsize = buf.len(); - - if nbytesread + bufsize > cap { - http_dump_body(body).await?; - // TODO-correctness check status code - return Err(HttpError::for_bad_request( - None, - format!("request body exceeded maximum size of {} bytes", cap), - )); - } - - nbytesread += bufsize; - parts.push_chunk(buf); - } - - /* - * Read the trailers as well, even though we're not going to do anything - * with them. - */ - body.trailers().await?; - /* - * TODO-correctness why does the is_end_stream() assertion fail and the next - * one panic? - */ - // assert!(body.is_end_stream()); - // assert!(body.data().await.is_none()); - // assert!(body.trailers().await?.is_none()); - Ok(parts.finish()) -} - -/** - * Reads the rest of the body from the request, dropping all the bytes. This is - * useful after encountering error conditions. - */ -pub async fn http_dump_body(body: &mut T) -> Result -where - T: HttpBody + std::marker::Unpin, -{ - /* - * TODO should this use some Stream interface instead? - * TODO-hardening: does this actually cap the amount of data that will be - * read? What if the underlying implementation chooses to wait for a much - * larger number of bytes? - * TODO better understand pin_mut!() - * TODO do we need to use saturating_add() here? - */ - let mut nbytesread: usize = 0; - while let Some(maybebuf) = body.data().await { - let buf = maybebuf?; - nbytesread += buf.len(); - } - - /* - * TODO-correctness why does the is_end_stream() assertion fail? - */ - // assert!(body.is_end_stream()); - Ok(nbytesread) -} - /** * Given a set of variables (most immediately from a RequestContext, likely * generated by the HttpRouter when routing an incoming request), extract them diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index 8c0f2e76e..0837efe74 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -637,6 +637,7 @@ pub use error::HttpErrorResponseBody; pub use handler::http_response_found; pub use handler::http_response_see_other; pub use handler::http_response_temporary_redirect; +pub use handler::CappedBodyStream; pub use handler::Extractor; pub use handler::ExtractorMetadata; pub use handler::FreeformBody; @@ -656,6 +657,7 @@ pub use handler::Path; pub use handler::Query; pub use handler::RequestContext; pub use handler::TypedBody; +pub use handler::UncappedBodyStream; pub use handler::UntypedBody; pub use http_util::CONTENT_TYPE_JSON; pub use http_util::CONTENT_TYPE_NDJSON; diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index b3bbc75f8..2309d4251 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -1107,7 +1107,7 @@ async fn demo_handler_untyped_body( UntypedQueryInto::BufList => { let buf_list = match query.limit { Some(max_bytes) => { - body.into_buf_list_with_limit(max_bytes).await? + body.into_buf_list_with_cap(max_bytes).await? } None => body.into_buf_list().await?, }; From a63ffee3c113f05b1b7095ba14a7658fe088caeb Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 6 Jan 2023 10:03:00 -0800 Subject: [PATCH 5/8] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20ch?= =?UTF-8?q?anges=20introduced=20through=20rebase?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 [skip ci] --- Cargo.lock | 10 ---------- dropshot/Cargo.toml | 1 - dropshot/src/logging.rs | 9 +++------ dropshot/src/test_util.rs | 21 ++++++++++----------- dropshot/tests/common/mod.rs | 2 +- dropshot/tests/fail/bad_endpoint4.stderr | 2 +- dropshot/tests/fail/bad_endpoint5.stderr | 2 +- dropshot/tests/test_pagination.rs | 2 +- 8 files changed, 17 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 489ab0aad..c9f6c4699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,15 +132,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" -[[package]] -name = "camino" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ad0e1e3e88dd237a156ab9f571021b8a158caa0ae44b1968a241efb5144c1e" -dependencies = [ - "serde", -] - [[package]] name = "cc" version = "1.0.72" @@ -295,7 +286,6 @@ dependencies = [ "async-trait", "base64 0.20.0", "bytes", - "camino", "chrono", "dropshot_endpoint", "expectorate", diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 18f2433bd..0028e4d66 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -15,7 +15,6 @@ async-stream = "0.3.3" async-trait = "0.1.60" base64 = "0.20.0" bytes = "1" -camino = { version = "1.1.1", features = ["serde1"] } futures = "0.3.25" hostname = "0.3.0" http = "0.2.8" diff --git a/dropshot/src/logging.rs b/dropshot/src/logging.rs index 8a5d50326..2ce9b4d47 100644 --- a/dropshot/src/logging.rs +++ b/dropshot/src/logging.rs @@ -5,14 +5,12 @@ * they're provided because they're commonly wanted by consumers of this crate. */ -use camino::Utf8PathBuf; use serde::Deserialize; use serde::Serialize; use slog::Drain; use slog::Level; use slog::Logger; use std::fs::OpenOptions; -use std::io::LineWriter; use std::{io, path::Path}; /** @@ -27,7 +25,7 @@ pub enum ConfigLogging { /** Bunyan-formatted output to a specified file. */ File { level: ConfigLoggingLevel, - path: Utf8PathBuf, + path: String, if_exists: ConfigLoggingIfExists, }, } @@ -138,13 +136,12 @@ fn log_drain_for_file( open_options: &OpenOptions, path: &Path, log_name: String, -) -> Result>>, io::Error> { +) -> Result>, io::Error> { if let Some(parent) = path.parent() { std::fs::create_dir_all(parent)?; } - // Buffer writes to the file around newlines to minimize syscalls. - let file = LineWriter::new(open_options.open(path)?); + let file = open_options.open(path)?; /* * Record a message to the stderr so that a reader who doesn't already know diff --git a/dropshot/src/test_util.rs b/dropshot/src/test_util.rs index 36b652e82..7271bae77 100644 --- a/dropshot/src/test_util.rs +++ b/dropshot/src/test_util.rs @@ -4,7 +4,6 @@ * and dependents of this crate. */ -use camino::Utf8PathBuf; use chrono::DateTime; use chrono::Utc; use http::method::Method; @@ -20,12 +19,12 @@ use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; use slog::Logger; -use std::convert::TryFrom; use std::fmt::Debug; use std::fs; use std::iter::Iterator; use std::net::SocketAddr; use std::path::Path; +use std::path::PathBuf; use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering; @@ -427,7 +426,7 @@ impl ClientTestContext { pub struct LogContext { /** general-purpose logger */ pub log: Logger, - log_path: Option, + log_path: Option, } impl LogContext { @@ -460,12 +459,13 @@ impl LogContext { each test." ); let new_path = log_file_for_test(test_name); - eprintln!("log file: {}", new_path); + let new_path_str = new_path.as_path().display().to_string(); + eprintln!("log file: {:?}", new_path_str); ( - Some(new_path.clone()), + Some(new_path), ConfigLogging::File { level: level.clone(), - path: new_path, + path: new_path_str, if_exists: if_exists.clone(), }, ) @@ -764,14 +764,13 @@ static TEST_SUITE_LOGGER_ID: AtomicU32 = AtomicU32::new(0); * Returns a unique path name in a temporary directory that includes the given * `test_name`. */ -pub fn log_file_for_test(test_name: &str) -> Utf8PathBuf { +pub fn log_file_for_test(test_name: &str) -> PathBuf { let arg0 = { - let arg0path = Utf8PathBuf::from(std::env::args().next().unwrap()); - arg0path.file_name().unwrap().to_owned() + let arg0path = std::env::args().next().unwrap(); + Path::new(&arg0path).file_name().unwrap().to_str().unwrap().to_string() }; - let mut pathbuf = Utf8PathBuf::try_from(std::env::temp_dir()) - .expect("temp dir is valid UTF-8"); + let mut pathbuf = std::env::temp_dir(); let id = TEST_SUITE_LOGGER_ID.fetch_add(1, Ordering::SeqCst); let pid = std::process::id(); pathbuf.push(format!("{}-{}.{}.{}.log", arg0, test_name, pid, id)); diff --git a/dropshot/tests/common/mod.rs b/dropshot/tests/common/mod.rs index b74332b1e..93e8951d8 100644 --- a/dropshot/tests/common/mod.rs +++ b/dropshot/tests/common/mod.rs @@ -37,7 +37,7 @@ pub fn test_setup( pub fn create_log_context(test_name: &str) -> LogContext { let log_config = ConfigLogging::File { level: ConfigLoggingLevel::Debug, - path: "UNUSED".into(), + path: "UNUSED".to_string(), if_exists: ConfigLoggingIfExists::Fail, }; LogContext::new(test_name, &log_config) diff --git a/dropshot/tests/fail/bad_endpoint4.stderr b/dropshot/tests/fail/bad_endpoint4.stderr index d99a3626b..d497f4030 100644 --- a/dropshot/tests/fail/bad_endpoint4.stderr +++ b/dropshot/tests/fail/bad_endpoint4.stderr @@ -28,13 +28,13 @@ error[E0277]: the trait bound `for<'de> QueryParams: serde::de::Deserialize<'de> | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] - &'a camino::Utf8Path &'a std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) and $N others = note: required for `QueryParams` to implement `serde::de::DeserializeOwned` note: required by a bound in `dropshot::Query` diff --git a/dropshot/tests/fail/bad_endpoint5.stderr b/dropshot/tests/fail/bad_endpoint5.stderr index 799f4e82d..044b61537 100644 --- a/dropshot/tests/fail/bad_endpoint5.stderr +++ b/dropshot/tests/fail/bad_endpoint5.stderr @@ -6,13 +6,13 @@ error[E0277]: the trait bound `for<'de> QueryParams: serde::de::Deserialize<'de> | = help: the following other types implement trait `serde::de::Deserialize<'de>`: &'a [u8] - &'a camino::Utf8Path &'a std::path::Path &'a str () (T0, T1) (T0, T1, T2) (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) and $N others = note: required for `QueryParams` to implement `serde::de::DeserializeOwned` note: required by a bound in `dropshot::Query` diff --git a/dropshot/tests/test_pagination.rs b/dropshot/tests/test_pagination.rs index 83c9379fe..841708a71 100644 --- a/dropshot/tests/test_pagination.rs +++ b/dropshot/tests/test_pagination.rs @@ -873,7 +873,7 @@ async fn start_example(path: &str, port: u16) -> ExampleContext { path, &ConfigLogging::File { level: ConfigLoggingLevel::Info, - path: "UNUSED".into(), + path: "UNUSED".to_string(), if_exists: ConfigLoggingIfExists::Fail, }, ); From d3d93a4b2b3b67c26d60ae383d81fa0dfb2a6f74 Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 6 Jan 2023 13:13:57 -0800 Subject: [PATCH 6/8] Putting this up for discussion Created using spr 1.3.4 --- Cargo.lock | 4 +- dropshot/Cargo.toml | 2 +- dropshot/src/handler.rs | 81 ++++++++++++++++++++++++++--------------- 3 files changed, 55 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f4a76452b..17c35ae44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "buf-list" -version = "0.1.3" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deb213ab6aa87733e74428d8a33f0bb93181810b0ae09ae96b8cc1521d265283" +checksum = "4851afb0c681f0bf27d675eff71da4d659ea32949363b048470e69b11fdcd02f" dependencies = [ "bytes", ] diff --git a/dropshot/Cargo.toml b/dropshot/Cargo.toml index 1c2b3f3c5..e1f6ed59f 100644 --- a/dropshot/Cargo.toml +++ b/dropshot/Cargo.toml @@ -14,7 +14,7 @@ categories = ["network-programming", "web-programming::http-server"] async-stream = "0.3.3" async-trait = "0.1.60" base64 = "0.20.0" -buf-list = "0.1.3" +buf-list = "1.0.0" bytes = "1" futures = "0.3.25" hostname = "0.3.0" diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 4d3810884..77b2ddd46 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -62,6 +62,7 @@ use futures::ready; use futures::stream::BoxStream; use futures::Stream; use futures::StreamExt; +use futures::TryStreamExt; use http::HeaderMap; use http::StatusCode; use hyper::body::HttpBody; @@ -1080,10 +1081,11 @@ where /// An extractor for raw bytes. /// -/// `UntypedBody` is meant to read in the contents of an HTTP request -/// body as a series of raw bytes. An `UntypedBody` represents a read -/// that hasn't happened yet; a method like `into_bytes()` or -/// `into_stream()` must be called to read the full body. +/// `UntypedBody` is meant to read in the contents of an HTTP request body as a +/// series of raw bytes. Unlike [`TypedBody`], an `UntypedBody` represents a +/// read that hasn't happened yet; a method like +/// [`into_bytes`](Self::into_bytes) or [`into_stream`](Self::into_stream) must +/// be called to read the full body. #[derive(Debug)] pub struct UntypedBody { request: Arc>>, @@ -1091,7 +1093,8 @@ pub struct UntypedBody { } impl UntypedBody { - /// Reads the body into a single, contiguous [`Bytes`] chunk. + /// Reads the body into a single, contiguous [`Bytes`] chunk, up to the + /// server's configured maximum body size. /// /// Recommended for smaller request bodies. Constructing a single `Bytes` /// chunk from larger request bodies might cause excessive copying and @@ -1099,7 +1102,8 @@ impl UntypedBody { /// /// # Errors /// - /// Errors if the request body is too large, or if the + /// Errors if there's an underlying HTTP error. Returns a "400 Bad Request" + /// error if the request body is too large. pub async fn into_bytes(self) -> Result { let mut stream = self.into_stream(); let mut bytes = BytesMut::new(); @@ -1111,12 +1115,13 @@ impl UntypedBody { Ok(bytes.freeze()) } - /// Reads the body into a `String`. + /// Reads the body into a `String`, up to the server's configured maximum + /// body size. /// /// # Errors /// - /// In addition to the usual errors, if the stream is not valid UTF-8, this - /// returns a "Bad Request" error. + /// In addition to the errors returned by [`into_bytes`](Self::into_bytes), + /// returns a "400 Bad Request" error if the if the body is not valid UTF-8. pub async fn into_string(self) -> Result { let v = Vec::from(self.into_bytes().await?); String::from_utf8(v).map_err(|e| { @@ -1127,20 +1132,24 @@ impl UntypedBody { }) } - /// Reads the body into a [`BufList`]. + /// Reads the body into a [`BufList`], up to the server's configured maximum + /// body size. /// /// A `BufList` is a list of [`Bytes`] chunks that implements the /// [`Buf`](bytes::Buf) trait. A `BufList` chunks can be operated on /// as a unit. /// - /// Recommended for larger request bodies. + /// Recommended over [`into_bytes`](Self::into_bytes) or + /// [`into_string`](Self::into_string) for larger request bodies. Like those + /// functions, this function fails if the body exceeds the server's + /// configured maximum body size. pub async fn into_buf_list(self) -> Result { let max_bytes = self.max_bytes; self.into_buf_list_with_cap(max_bytes).await } - /// Reads the body into a [`BufList`] with a custom limit for the maximum - /// size of bytes. + /// Reads the body into a [`BufList`], with a custom limit for the maximum + /// body size. /// /// This method is similar to [`into_buf_list`](Self::into_buf_list), except /// a custom limit is specified. If this method is called, the default @@ -1150,18 +1159,24 @@ impl UntypedBody { self, max_bytes: usize, ) -> Result { - let mut stream = - self.into_stream().into_uncapped().into_capped(max_bytes); - let mut buf_list = BufList::new(); - - while let Some(data) = stream.next().await { - buf_list.push_chunk(data?); - } - - Ok(buf_list) + self.into_stream() + .into_uncapped() + .into_capped(max_bytes) + .try_collect() + .await } /// Converts `self` into a [`Stream`] of `Result` chunks. + /// + /// This stream is limited to the server's configured maximum body size. To + /// set a different maximum body size, call + /// [`into_uncapped`](CappedBodyStream::into_uncapped), followed by + /// [`UncappedBodyStream::into_capped`]. + /// + /// # Errors + /// + /// The stream errors if there's an underlying HTTP error, or if the request + /// body exceeds the cap. pub fn into_stream(self) -> CappedBodyStream { let max_bytes = self.max_bytes; @@ -1233,9 +1248,11 @@ pin_project! { /// A stream over an HTTP body that sets a limit on the number of bytes that /// can be read from it. /// - /// To change the cap read from it, use + /// Returned by [`UntypedBody::into_stream`]. + /// + /// To change the maximum body size, use /// [`into_uncapped`](Self::into_uncapped), then apply a new cap with - /// `UncappedStream::into_capped`. + /// [`UncappedStream::into_capped`]. pub struct CappedBodyStream { #[pin] stream: UncappedBodyStream, @@ -1259,7 +1276,7 @@ impl CappedBodyStream { self.current_bytes } - /// Turns this stream into an uncapped one. + /// Turns this stream into one that does not have a maximum size limit. pub fn into_uncapped(self) -> UncappedBodyStream { self.stream } @@ -1309,7 +1326,10 @@ impl Stream for CappedBodyStream { } pin_project! { - /// A stream over an HTTP request body that does not cap the bytes. + /// A stream over an HTTP request body that does not set a limit on the + /// maximum body size. + /// + /// Returned by [`CappedBodyStream::into_uncapped`]. pub struct UncappedBodyStream { // TODO: replace with a concrete type once TAIT is stabilized. #[pin] @@ -1324,9 +1344,12 @@ impl UncappedBodyStream { Self { stream: stream.boxed() } } - /// Adds a cap to this stream. - pub fn into_capped(self, cap: usize) -> CappedBodyStream { - CappedBodyStream::new(self, cap) + /// Turns this stream into one that sets a limit on the request body size. + /// + /// Note that the `CappedBodyStream`'s count starts from zero: it will only + /// consider bytes read from here on out. + pub fn into_capped(self, max_bytes: usize) -> CappedBodyStream { + CappedBodyStream::new(self, max_bytes) } } From 41667a8e822d5bd50624ea109079ff895a5cceae Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 6 Jan 2023 13:49:32 -0800 Subject: [PATCH 7/8] Simplify into_buf_list Created using spr 1.3.4 --- dropshot/src/handler.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index 77b2ddd46..bc4ef1877 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -1144,8 +1144,7 @@ impl UntypedBody { /// functions, this function fails if the body exceeds the server's /// configured maximum body size. pub async fn into_buf_list(self) -> Result { - let max_bytes = self.max_bytes; - self.into_buf_list_with_cap(max_bytes).await + self.into_stream().try_collect().await } /// Reads the body into a [`BufList`], with a custom limit for the maximum From 3f7740221ef4d561f5827c7671a79b95b1c65dfb Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 6 Jan 2023 14:21:09 -0800 Subject: [PATCH 8/8] Remove UncappedBodyStream, rename Capped to UntypedBodyStream Created using spr 1.3.4 --- dropshot/src/handler.rs | 104 ++++++++++++------------------------ dropshot/src/lib.rs | 3 +- dropshot/tests/test_demo.rs | 2 +- 3 files changed, 35 insertions(+), 74 deletions(-) diff --git a/dropshot/src/handler.rs b/dropshot/src/handler.rs index bc4ef1877..3c23b7b2a 100644 --- a/dropshot/src/handler.rs +++ b/dropshot/src/handler.rs @@ -1151,32 +1151,27 @@ impl UntypedBody { /// body size. /// /// This method is similar to [`into_buf_list`](Self::into_buf_list), except - /// a custom limit is specified. If this method is called, the default + /// it specifies a custom limit. If this method is called, the default /// [`request_body_max_bytes`](ServerConfig::request_body_max_bytes) limit /// is ignored. - pub async fn into_buf_list_with_cap( + pub async fn into_buf_list_with_limit( self, max_bytes: usize, ) -> Result { - self.into_stream() - .into_uncapped() - .into_capped(max_bytes) - .try_collect() - .await + self.into_stream().with_limit(max_bytes).try_collect().await } /// Converts `self` into a [`Stream`] of `Result` chunks. /// - /// This stream is limited to the server's configured maximum body size. To - /// set a different maximum body size, call - /// [`into_uncapped`](CappedBodyStream::into_uncapped), followed by - /// [`UncappedBodyStream::into_capped`]. + /// By default, the stream is limited to the server's configured maximum + /// body size. To set a different maximum body size, call + /// [`with_limit`](UntypedBodyStream::with_limit) on the returned stream. /// /// # Errors /// /// The stream errors if there's an underlying HTTP error, or if the request - /// body exceeds the cap. - pub fn into_stream(self) -> CappedBodyStream { + /// body exceeds the limit. + pub fn into_stream(self) -> UntypedBodyStream { let max_bytes = self.max_bytes; let stream = stream! { @@ -1202,8 +1197,7 @@ impl UntypedBody { } }; - let uncapped = UncappedBodyStream::new(stream); - CappedBodyStream::new(uncapped, max_bytes) + UntypedBodyStream::new(stream, max_bytes) } } @@ -1244,25 +1238,23 @@ impl Extractor for UntypedBody { } pin_project! { - /// A stream over an HTTP body that sets a limit on the number of bytes that - /// can be read from it. - /// - /// Returned by [`UntypedBody::into_stream`]. - /// - /// To change the maximum body size, use - /// [`into_uncapped`](Self::into_uncapped), then apply a new cap with - /// [`UncappedStream::into_capped`]. - pub struct CappedBodyStream { + /// A stream over an HTTP body. This stream that a limit on the number of + /// bytes that can be read from it. + pub struct UntypedBodyStream { + // TODO: replace with concrete type once TAIT is stabilized. #[pin] - stream: UncappedBodyStream, + stream: BoxStream<'static, Result>, max_bytes: usize, current_bytes: usize, } } -impl CappedBodyStream { - pub(crate) fn new(stream: UncappedBodyStream, max_bytes: usize) -> Self { - Self { stream, max_bytes, current_bytes: 0 } +impl UntypedBodyStream { + pub(crate) fn new( + stream: impl Stream> + Send + 'static, + max_bytes: usize, + ) -> Self { + Self { stream: stream.boxed(), max_bytes, current_bytes: 0 } } /// Returns the maximum number of bytes that can be read from this stream. @@ -1275,13 +1267,22 @@ impl CappedBodyStream { self.current_bytes } - /// Turns this stream into one that does not have a maximum size limit. - pub fn into_uncapped(self) -> UncappedBodyStream { - self.stream + /// Sets a new limit on the stream. + /// + /// [`Self::current_bytes`] does not change. + pub fn set_limit(&mut self, new_limit: usize) -> &mut Self { + self.max_bytes = new_limit; + self + } + + /// An owned version of [`set_limit`](Self::set_limit). + pub fn with_limit(mut self, new_limit: usize) -> Self { + self.set_limit(new_limit); + self } } -impl Stream for CappedBodyStream { +impl Stream for UntypedBodyStream { type Item = Result; fn poll_next( @@ -1324,45 +1325,6 @@ impl Stream for CappedBodyStream { } } -pin_project! { - /// A stream over an HTTP request body that does not set a limit on the - /// maximum body size. - /// - /// Returned by [`CappedBodyStream::into_uncapped`]. - pub struct UncappedBodyStream { - // TODO: replace with a concrete type once TAIT is stabilized. - #[pin] - stream: BoxStream<'static, Result>, - } -} - -impl UncappedBodyStream { - pub(crate) fn new( - stream: impl Stream> + Send + 'static, - ) -> Self { - Self { stream: stream.boxed() } - } - - /// Turns this stream into one that sets a limit on the request body size. - /// - /// Note that the `CappedBodyStream`'s count starts from zero: it will only - /// consider bytes read from here on out. - pub fn into_capped(self, max_bytes: usize) -> CappedBodyStream { - CappedBodyStream::new(self, max_bytes) - } -} - -impl Stream for UncappedBodyStream { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - self.project().stream.poll_next(cx) - } -} - /* * Response Type Conversion * diff --git a/dropshot/src/lib.rs b/dropshot/src/lib.rs index 0837efe74..66d014024 100644 --- a/dropshot/src/lib.rs +++ b/dropshot/src/lib.rs @@ -637,7 +637,6 @@ pub use error::HttpErrorResponseBody; pub use handler::http_response_found; pub use handler::http_response_see_other; pub use handler::http_response_temporary_redirect; -pub use handler::CappedBodyStream; pub use handler::Extractor; pub use handler::ExtractorMetadata; pub use handler::FreeformBody; @@ -657,8 +656,8 @@ pub use handler::Path; pub use handler::Query; pub use handler::RequestContext; pub use handler::TypedBody; -pub use handler::UncappedBodyStream; pub use handler::UntypedBody; +pub use handler::UntypedBodyStream; pub use http_util::CONTENT_TYPE_JSON; pub use http_util::CONTENT_TYPE_NDJSON; pub use http_util::CONTENT_TYPE_OCTET_STREAM; diff --git a/dropshot/tests/test_demo.rs b/dropshot/tests/test_demo.rs index 2309d4251..b3bbc75f8 100644 --- a/dropshot/tests/test_demo.rs +++ b/dropshot/tests/test_demo.rs @@ -1107,7 +1107,7 @@ async fn demo_handler_untyped_body( UntypedQueryInto::BufList => { let buf_list = match query.limit { Some(max_bytes) => { - body.into_buf_list_with_cap(max_bytes).await? + body.into_buf_list_with_limit(max_bytes).await? } None => body.into_buf_list().await?, };