Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@

https://github.com/oxidecomputer/dropshot/compare/v0.13.0\...HEAD[Full list of commits]

=== Breaking changes

* The `request_body_max_bytes` config has been renamed to `default_request_body_max_bytes`. This is to make its semantics clear with respect to per-endpoint request limits.
+
Defining the old config option will produce an error, guiding you to perform the rename.

== 0.13.0 (released 2024-11-13)

https://github.com/oxidecomputer/dropshot/compare/v0.12.0\...v0.13.0[Full list of commits]
Expand Down
2 changes: 1 addition & 1 deletion README.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ include:
|No
|Specifies that the server should bind to the given IP address and TCP port. In general, servers can bind to more than one IP address and port, but this is not (yet?) supported. Defaults to "127.0.0.1:0".

|`request_body_max_bytes`
|`default_request_body_max_bytes`
|`4096`
|No
|Specifies the maximum number of bytes allowed in a request body. Larger requests will receive a 400 error. Defaults to 1024.
Expand Down
115 changes: 111 additions & 4 deletions dropshot/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub type RawTlsConfig = rustls::ServerConfig;
/// r##"
/// [http_api_server]
/// bind_address = "127.0.0.1:12345"
/// request_body_max_bytes = 1024
/// default_request_body_max_bytes = 1024
/// ## ... (other app-specific config)
/// "##
/// ).map_err(|error| format!("parsing config: {}", error))?;
Expand All @@ -43,12 +43,15 @@ pub type RawTlsConfig = rustls::ServerConfig;
/// }
/// ```
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(default)]
#[serde(
from = "DeserializedConfigDropshot",
into = "DeserializedConfigDropshot"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of the into?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is possible to derive Serialize on ConfigDropshot directly, and it would work today, but in case there's more divergence between the in-memory implementation and the serializable implementation over time, that may no longer hold true. I've seen this be a recurring source of bugs in other places, and making the types that directly interact with serde be the same for serialization and deserialization has always been helpful.

)]
pub struct ConfigDropshot {
/// IP address and TCP port to which to bind for accepting connections
pub bind_address: SocketAddr,
/// maximum allowed size of a request body, defaults to 1024
pub request_body_max_bytes: usize,
pub default_request_body_max_bytes: usize,
/// Default behavior for HTTP handler functions with respect to clients
/// disconnecting early.
pub default_handler_task_mode: HandlerTaskMode,
Expand Down Expand Up @@ -113,9 +116,113 @@ impl Default for ConfigDropshot {
fn default() -> Self {
ConfigDropshot {
bind_address: "127.0.0.1:0".parse().unwrap(),
request_body_max_bytes: 1024,
default_request_body_max_bytes: 1024,
default_handler_task_mode: HandlerTaskMode::Detached,
log_headers: Default::default(),
}
}
}

#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to be Serialize?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is coupled with the into comment.

#[serde(default)]
struct DeserializedConfigDropshot {
bind_address: SocketAddr,
default_request_body_max_bytes: usize,
// Previous name for default_request_body_max_bytes, in Dropshot < 0.14.
// Present only to guide users to the new name.
#[serde(
deserialize_with = "deserialize_invalid_request_body_max_bytes",
skip_serializing
)]
request_body_max_bytes: Option<InvalidConfig>,
default_handler_task_mode: HandlerTaskMode,
log_headers: Vec<String>,
}

impl From<DeserializedConfigDropshot> for ConfigDropshot {
fn from(v: DeserializedConfigDropshot) -> Self {
ConfigDropshot {
bind_address: v.bind_address,
default_request_body_max_bytes: v.default_request_body_max_bytes,
default_handler_task_mode: v.default_handler_task_mode,
log_headers: v.log_headers,
}
}
}

impl From<ConfigDropshot> for DeserializedConfigDropshot {
fn from(v: ConfigDropshot) -> Self {
DeserializedConfigDropshot {
bind_address: v.bind_address,
default_request_body_max_bytes: v.default_request_body_max_bytes,
request_body_max_bytes: None,
default_handler_task_mode: v.default_handler_task_mode,
log_headers: v.log_headers,
}
}
}

impl Default for DeserializedConfigDropshot {
fn default() -> Self {
ConfigDropshot::default().into()
}
}

/// A marker type to indicate that the configuration is invalid.
///
/// This type can never be constructed, which means that for any valid config,
/// `Option<InvalidConfig>` is always none.
#[derive(Clone, Debug, PartialEq)]
pub enum InvalidConfig {}

// We prefer having a deserialize function over `impl Deserialize for
// InvalidConfig` for two reasons:
//
// 1. This returns an `Option<InvalidConfig>`, not an `InvalidConfig`.
// 2. This way, the deserializer has a custom message associated with it.
fn deserialize_invalid_request_body_max_bytes<'de, D>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we prefer this to impl Serialize for InvalidConfig?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern allows a custom error message to be passed in -- InvalidConfig is a general marker value that shouldn't necessarily be tied with the specific error message.

I'll add a comment to this note.

Copy link
Contributor Author

@sunshowers sunshowers Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh also this returns an Option<InvalidConfig> which wouldn't be possible to do via a type implementation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh also this returns an Option<InvalidConfig> which wouldn't be possible to do via a type implementation.

that's true... but I think impl Serialize for InvalidConfig would probably suffice since the implementation of Serialize for Option will call it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, InvalidConfig doesn't implement Serialize at all (the skip_serializing makes that moot). This code only does deserialization.

deserializer: D,
) -> Result<Option<InvalidConfig>, D::Error>
where
D: serde::Deserializer<'de>,
{
deserialize_invalid(
deserializer,
"request_body_max_bytes has been renamed to \
default_request_body_max_bytes",
)
}

fn deserialize_invalid<'de, D>(
deserializer: D,
msg: &'static str,
) -> Result<Option<InvalidConfig>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;

struct V {
msg: &'static str,
}

impl<'de> serde::de::Visitor<'de> for V {
type Value = Option<InvalidConfig>;

fn expecting(
&self,
formatter: &mut std::fmt::Formatter,
) -> std::fmt::Result {
write!(formatter, "the field to be absent ({})", self.msg)
}

fn visit_some<D>(self, _: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'de>,
{
Err(D::Error::custom(self.msg))
}
}

deserializer.deserialize_any(V { msg })
}
10 changes: 3 additions & 7 deletions dropshot/src/extractor/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ async fn http_request_load_body<Context: ServerContext, BodyType>(
where
BodyType: JsonSchema + DeserializeOwned + Send + Sync,
{
let server = &rqctx.server;
let (parts, body) = request.into_parts();
let body = StreamingBody::new(body, server.config.request_body_max_bytes)
let body = StreamingBody::new(body, rqctx.request_body_max_bytes())
.into_bytes_mut()
.await?;

Expand Down Expand Up @@ -262,10 +261,9 @@ impl ExclusiveExtractor for UntypedBody {
rqctx: &RequestContext<Context>,
request: hyper::Request<crate::Body>,
) -> Result<UntypedBody, HttpError> {
let server = &rqctx.server;
let body = request.into_body();
let body_bytes =
StreamingBody::new(body, server.config.request_body_max_bytes)
StreamingBody::new(body, rqctx.request_body_max_bytes())
.into_bytes_mut()
.await?;
Ok(UntypedBody { content: body_bytes.freeze() })
Expand Down Expand Up @@ -425,11 +423,9 @@ impl ExclusiveExtractor for StreamingBody {
rqctx: &RequestContext<Context>,
request: hyper::Request<crate::Body>,
) -> Result<Self, HttpError> {
let server = &rqctx.server;

Ok(Self {
body: request.into_body(),
cap: server.config.request_body_max_bytes,
cap: rqctx.request_body_max_bytes(),
})
}

Expand Down
5 changes: 5 additions & 0 deletions dropshot/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ impl<Context: ServerContext> RequestContext<Context> {
&self.server.private
}

/// Returns the maximum request body size.
pub fn request_body_max_bytes(&self) -> usize {
self.server.config.default_request_body_max_bytes
}

/// Returns the appropriate count of items to return for a paginated request
///
/// This first looks at any client-requested limit and clamps it based on the
Expand Down
5 changes: 3 additions & 2 deletions dropshot/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ impl<C: ServerContext> DropshotState<C> {
#[derive(Debug)]
pub struct ServerConfig {
/// maximum allowed size of a request body
pub request_body_max_bytes: usize,
pub default_request_body_max_bytes: usize,
/// maximum size of any page of results
pub page_max_nitems: NonZeroU32,
/// default size for a page of results
Expand Down Expand Up @@ -182,7 +182,8 @@ impl<C: ServerContext> HttpServerStarter<C> {

let server_config = ServerConfig {
// We start aggressively to ensure test coverage.
request_body_max_bytes: config.request_body_max_bytes,
default_request_body_max_bytes: config
.default_request_body_max_bytes,
page_max_nitems: NonZeroU32::new(10000).unwrap(),
page_default_nitems: NonZeroU32::new(100).unwrap(),
default_handler_task_mode: config.default_handler_task_mode,
Expand Down
2 changes: 1 addition & 1 deletion dropshot/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ mod tests {
server: Arc::new(DropshotState {
private: (),
config: ServerConfig {
request_body_max_bytes: 0,
default_request_body_max_bytes: 0,
page_max_nitems: NonZeroU32::new(1).unwrap(),
page_default_nitems: NonZeroU32::new(1).unwrap(),
default_handler_task_mode:
Expand Down
26 changes: 21 additions & 5 deletions dropshot/tests/integration-tests/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn test_valid_config_all_settings() {
"valid_config_basic",
r#"
bind_address = "127.0.0.1:12345"
request_body_max_bytes = 1048576
default_request_body_max_bytes = 1048576
default_handler_task_mode = "cancel-on-disconnect"
log_headers = ["X-Forwarded-For"]
"#,
Expand All @@ -63,7 +63,7 @@ fn test_valid_config_all_settings() {
parsed,
ConfigDropshot {
bind_address: "127.0.0.1:12345".parse().unwrap(),
request_body_max_bytes: 1048576,
default_request_body_max_bytes: 1048576,
default_handler_task_mode: HandlerTaskMode::CancelOnDisconnect,
log_headers: vec!["X-Forwarded-For".to_string()],
},
Expand Down Expand Up @@ -114,7 +114,7 @@ fn test_config_bad_bind_address_garbage() {
fn test_config_bad_request_body_max_bytes_negative() {
let error = read_config::<ConfigDropshot>(
"bad_request_body_max_bytes_negative",
"request_body_max_bytes = -1024",
"default_request_body_max_bytes = -1024",
)
.unwrap_err()
.to_string();
Expand All @@ -126,14 +126,30 @@ fn test_config_bad_request_body_max_bytes_negative() {
fn test_config_bad_request_body_max_bytes_too_large() {
let error = read_config::<ConfigDropshot>(
"bad_request_body_max_bytes_too_large",
"request_body_max_bytes = 999999999999999999999999999999",
"default_request_body_max_bytes = 999999999999999999999999999999",
)
.unwrap_err()
.to_string();
println!("found error: {}", error);
assert!(error.starts_with(""));
}

#[test]
fn test_config_deprecated_request_body_max_bytes() {
let error = read_config::<ConfigDropshot>(
"deprecated_request_body_max_bytes",
"request_body_max_bytes = 1024",
)
.unwrap_err();
assert_eq!(
error.message(),
"invalid type: integer `1024`, \
expected the field to be absent \
(request_body_max_bytes has been renamed to \
default_request_body_max_bytes)",
);
}

fn make_server<T: Send + Sync + 'static>(
context: T,
config: &ConfigDropshot,
Expand Down Expand Up @@ -162,7 +178,7 @@ fn make_config(
std::net::IpAddr::from_str(bind_ip_str).unwrap(),
bind_port,
),
request_body_max_bytes: 1024,
default_request_body_max_bytes: 1024,
default_handler_task_mode,
log_headers: Default::default(),
}
Expand Down
4 changes: 2 additions & 2 deletions dropshot/tests/integration-tests/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ fn make_server(
) -> HttpServer<i32> {
let config = ConfigDropshot {
bind_address: "127.0.0.1:0".parse().unwrap(),
request_body_max_bytes: 1024,
default_request_body_max_bytes: 1024,
default_handler_task_mode: HandlerTaskMode::CancelOnDisconnect,
log_headers: Default::default(),
};
Expand Down Expand Up @@ -429,7 +429,7 @@ async fn test_server_is_https() {

let config = ConfigDropshot {
bind_address: "127.0.0.1:0".parse().unwrap(),
request_body_max_bytes: 1024,
default_request_body_max_bytes: 1024,
default_handler_task_mode: HandlerTaskMode::CancelOnDisconnect,
log_headers: Default::default(),
};
Expand Down