From 404d6a2071d73af8f6c07c9e9c740d1771599845 Mon Sep 17 00:00:00 2001 From: Snow Pettersen Date: Tue, 17 Mar 2026 15:36:19 -0700 Subject: [PATCH 1/4] implement response mutation hook in bd-grpc --- bd-grpc/src/grpc_test.rs | 46 +++++++++++++++++++++++++-- bd-grpc/src/lib.rs | 69 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 8 deletions(-) diff --git a/bd-grpc/src/grpc_test.rs b/bd-grpc/src/grpc_test.rs index a6cd1ff6..a77f98d9 100644 --- a/bd-grpc/src/grpc_test.rs +++ b/bd-grpc/src/grpc_test.rs @@ -25,7 +25,7 @@ use crate::{ StreamingApi, StreamingApiSender, make_server_streaming_router, - make_unary_router, + make_unary_router_with_response_mutator, new_grpc_response, }; use assert_matches::assert_matches; @@ -68,12 +68,22 @@ async fn make_unary_server( error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static, endpoint_stats: Option<&EndpointStats>, ) -> SocketAddr { - let router = make_unary_router( + make_unary_server_with_response_mutator(handler, error_handler, endpoint_stats, None).await +} + +async fn make_unary_server_with_response_mutator( + handler: Arc>, + error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static, + endpoint_stats: Option<&EndpointStats>, + response_mutator: Option>, +) -> SocketAddr { + let router = make_unary_router_with_response_mutator( &service_method(), handler, - error_handler, endpoint_stats, true, + response_mutator, + error_handler, ) .layer(ConnectSafeCompressionLayer::new()); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -1048,6 +1058,36 @@ async fn unary_json_transcoding() { assert_eq!(body, serde_json::json!({ "echo": "json_echo" })); } +#[tokio::test] +async fn unary_json_transcoding_applies_response_mutator() { + let local_address = make_unary_server_with_response_mutator( + Arc::new(EchoHandler::default()), + |_| {}, + None, + Some(Arc::new(|response: &mut EchoResponse| { + response.echo = "mutated".to_string(); + Ok(()) + })), + ) + .await; + let client = reqwest::Client::builder().deflate(false).build().unwrap(); + let address = AddressHelper::new(format!("http://{local_address}")).unwrap(); + let response = client + .post(address.build(&service_method()).to_string()) + .header(CONTENT_TYPE, CONTENT_TYPE_JSON) + .body("{\"echo\":\"json_echo\"}") + .send() + .await + .unwrap(); + assert_eq!(response.status(), 200); + assert_eq!( + response.headers().get(CONTENT_TYPE).unwrap(), + CONTENT_TYPE_JSON + ); + let body: serde_json::Value = serde_json::from_slice(&response.bytes().await.unwrap()).unwrap(); + assert_eq!(body, serde_json::json!({ "echo": "mutated" })); +} + #[tokio::test] async fn server_streaming_json_request() { let local_address = make_server_streaming_server(Arc::new(EchoHandler::default()), |_| {}) diff --git a/bd-grpc/src/lib.rs b/bd-grpc/src/lib.rs index fddecdc7..59c4fdb1 100644 --- a/bd-grpc/src/lib.rs +++ b/bd-grpc/src/lib.rs @@ -72,6 +72,8 @@ const TRANSFER_ENCODING_TRAILERS: &str = "trailers"; const CONNECT_PROTOCOL_VERSION: &str = "connect-protocol-version"; pub type BodySender = mpsc::Sender, BoxError>>; +pub type UnaryResponseMutator = + Arc Result<()> + Send + Sync>; // // StreamingApi @@ -534,8 +536,12 @@ async fn unary_connect_handler>, + response_mutator: Option>, ) -> Result { - let response = handler.handle(headers, extensions, message).await?; + let mut response = handler.handle(headers, extensions, message).await?; + if let Some(response_mutator) = response_mutator { + response_mutator(&mut response)?; + } Ok(new_grpc_response( response.write_to_bytes().unwrap().into(), None, @@ -549,13 +555,14 @@ pub async fn unary_handler handler: Arc>, validate_request: bool, connect_protocol_type: Option, + response_mutator: Option>, ) -> Result { let (headers, extensions, message) = decode_request::(request, validate_request, connect_protocol_type).await?; let json_transcoding = connect_protocol_type.is_none() && is_json_request_content_type(&headers); if matches!(connect_protocol_type, Some(ConnectProtocolType::Unary)) { - return unary_connect_handler(headers, extensions, message, handler).await; + return unary_connect_handler(headers, extensions, message, handler, response_mutator).await; } let compression = finalize_response_compression( @@ -563,7 +570,10 @@ pub async fn unary_handler &headers, ); - let response = handler.handle(headers, extensions, message).await?; + let mut response = handler.handle(headers, extensions, message).await?; + if let Some(response_mutator) = response_mutator { + response_mutator(&mut response)?; + } if json_transcoding { let json = @@ -750,13 +760,36 @@ pub fn make_unary_router( endpoint_stats: Option<&EndpointStats>, validate_request: bool, ) -> Router { - make_unary_router_at_path( + make_unary_router_with_response_mutator( service_method, - &service_method.full_path(), handler, + endpoint_stats, + validate_request, + None, error_handler, + ) +} + +// Create an axum router for a unary request and a handler with an optional response mutator. +pub fn make_unary_router_with_response_mutator< + OutgoingType: MessageFull, + IncomingType: MessageFull, +>( + service_method: &ServiceMethod, + handler: Arc>, + endpoint_stats: Option<&EndpointStats>, + validate_request: bool, + response_mutator: Option>, + error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static, +) -> Router { + make_unary_router_at_path_with_response_mutator( + service_method, + &service_method.full_path(), + handler, endpoint_stats, validate_request, + response_mutator, + error_handler, ) } @@ -768,6 +801,31 @@ pub fn make_unary_router_at_path, validate_request: bool, +) -> Router { + make_unary_router_at_path_with_response_mutator( + service_method, + full_path, + handler, + endpoint_stats, + validate_request, + None, + error_handler, + ) +} + +// Create an axum router for a unary request and a handler at a provided path with an optional +// response mutator. +pub fn make_unary_router_at_path_with_response_mutator< + OutgoingType: MessageFull, + IncomingType: MessageFull, +>( + service_method: &ServiceMethod, + full_path: &str, + handler: Arc>, + endpoint_stats: Option<&EndpointStats>, + validate_request: bool, + response_mutator: Option>, + error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static, ) -> Router { let warn_tracker = Arc::new(WarnTracker::default()); let full_path = Arc::new(full_path.to_string()); @@ -783,6 +841,7 @@ pub fn make_unary_router_at_path Date: Tue, 17 Mar 2026 15:37:08 -0700 Subject: [PATCH 2/4] fmt grpc --- bd-grpc/Cargo.toml | 70 +++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/bd-grpc/Cargo.toml b/bd-grpc/Cargo.toml index 9073a2c1..3a14be14 100644 --- a/bd-grpc/Cargo.toml +++ b/bd-grpc/Cargo.toml @@ -9,42 +9,42 @@ version = "1.0.0" doctest = false [dependencies] -anyhow.workspace = true -assert_matches.workspace = true -async-trait.workspace = true -axum.workspace = true -base64ct.workspace = true -bd-grpc-codec.path = "../bd-grpc-codec" -bd-log.path = "../bd-log" -bd-pgv.path = "../bd-pgv" -bd-server-stats.path = "../bd-server-stats" -bd-shutdown.path = "../bd-shutdown" -bd-stats-common.path = "../bd-stats-common" -bd-time.path = "../bd-time" -bd-workspace-hack.workspace = true -bytes.workspace = true -futures.workspace = true -http.workspace = true -http-body.workspace = true -http-body-util.workspace = true -hyper.workspace = true -hyper-util.workspace = true -log.workspace = true -mockall = { workspace = true, optional = true } -prometheus.workspace = true -protobuf.workspace = true +anyhow.workspace = true +assert_matches.workspace = true +async-trait.workspace = true +axum.workspace = true +base64ct.workspace = true +bd-grpc-codec.path = "../bd-grpc-codec" +bd-log.path = "../bd-log" +bd-pgv.path = "../bd-pgv" +bd-server-stats.path = "../bd-server-stats" +bd-shutdown.path = "../bd-shutdown" +bd-stats-common.path = "../bd-stats-common" +bd-time.path = "../bd-time" +bd-workspace-hack.workspace = true +bytes.workspace = true +futures.workspace = true +http.workspace = true +http-body.workspace = true +http-body-util.workspace = true +hyper.workspace = true +hyper-util.workspace = true +log.workspace = true +mockall = { workspace = true, optional = true } +prometheus.workspace = true +protobuf.workspace = true protobuf-json-mapping.workspace = true -serde.workspace = true -serde_json.workspace = true -snap.workspace = true -thiserror.workspace = true -time.workspace = true -tokio.workspace = true -tokio-stream.workspace = true -tower.workspace = true -tower-http.workspace = true -unwrap-infallible.workspace = true -urlencoding.workspace = true +serde.workspace = true +serde_json.workspace = true +snap.workspace = true +thiserror.workspace = true +time.workspace = true +tokio.workspace = true +tokio-stream.workspace = true +tower.workspace = true +tower-http.workspace = true +unwrap-infallible.workspace = true +urlencoding.workspace = true [dev-dependencies] assert_matches.workspace = true From b1c9e308533e7071ada10220b44ebaa0e37fb043 Mon Sep 17 00:00:00 2001 From: Snow Pettersen Date: Tue, 17 Mar 2026 16:07:13 -0700 Subject: [PATCH 3/4] better handling of default fields --- bd-grpc/src/generated/proto/test.rs | 131 +++++++++++++++++++++++++++- bd-grpc/src/grpc_test.rs | 56 +++++++++++- bd-grpc/src/lib.rs | 1 + bd-grpc/src/proto/test.proto | 5 ++ 4 files changed, 189 insertions(+), 4 deletions(-) diff --git a/bd-grpc/src/generated/proto/test.rs b/bd-grpc/src/generated/proto/test.rs index 0ce97a97..a6019333 100644 --- a/bd-grpc/src/generated/proto/test.rs +++ b/bd-grpc/src/generated/proto/test.rs @@ -268,12 +268,136 @@ impl ::protobuf::reflect::ProtobufValue for EchoResponse { type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; } +// @@protoc_insertion_point(message:test.EchoRepeatedResponse) +#[derive(PartialEq,Clone,Default,Debug)] +pub struct EchoRepeatedResponse { + // message fields + // @@protoc_insertion_point(field:test.EchoRepeatedResponse.values) + pub values: ::std::vec::Vec<::std::string::String>, + // special fields + // @@protoc_insertion_point(special_field:test.EchoRepeatedResponse.special_fields) + pub special_fields: ::protobuf::SpecialFields, +} + +impl<'a> ::std::default::Default for &'a EchoRepeatedResponse { + fn default() -> &'a EchoRepeatedResponse { + ::default_instance() + } +} + +impl EchoRepeatedResponse { + pub fn new() -> EchoRepeatedResponse { + ::std::default::Default::default() + } + + fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { + let mut fields = ::std::vec::Vec::with_capacity(1); + let mut oneofs = ::std::vec::Vec::with_capacity(0); + fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>( + "values", + |m: &EchoRepeatedResponse| { &m.values }, + |m: &mut EchoRepeatedResponse| { &mut m.values }, + )); + ::protobuf::reflect::GeneratedMessageDescriptorData::new_2::( + "EchoRepeatedResponse", + fields, + oneofs, + ) + } +} + +impl ::protobuf::Message for EchoRepeatedResponse { + const NAME: &'static str = "EchoRepeatedResponse"; + + fn is_initialized(&self) -> bool { + true + } + + fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> { + while let Some(tag) = is.read_raw_tag_or_eof()? { + match tag { + 10 => { + self.values.push(is.read_string()?); + }, + tag => { + ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; + }, + }; + } + ::std::result::Result::Ok(()) + } + + // Compute sizes of nested messages + #[allow(unused_variables)] + fn compute_size(&self) -> u64 { + let mut my_size = 0; + for value in &self.values { + my_size += ::protobuf::rt::string_size(1, &value); + }; + my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); + self.special_fields.cached_size().set(my_size as u32); + my_size + } + + fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { + for v in &self.values { + os.write_string(1, &v)?; + }; + os.write_unknown_fields(self.special_fields.unknown_fields())?; + ::std::result::Result::Ok(()) + } + + fn special_fields(&self) -> &::protobuf::SpecialFields { + &self.special_fields + } + + fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields { + &mut self.special_fields + } + + fn new() -> EchoRepeatedResponse { + EchoRepeatedResponse::new() + } + + fn clear(&mut self) { + self.values.clear(); + self.special_fields.clear(); + } + + fn default_instance() -> &'static EchoRepeatedResponse { + static instance: EchoRepeatedResponse = EchoRepeatedResponse { + values: ::std::vec::Vec::new(), + special_fields: ::protobuf::SpecialFields::new(), + }; + &instance + } +} + +impl ::protobuf::MessageFull for EchoRepeatedResponse { + fn descriptor() -> ::protobuf::reflect::MessageDescriptor { + static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new(); + descriptor.get(|| file_descriptor().message_by_package_relative_name("EchoRepeatedResponse").unwrap()).clone() + } +} + +impl ::std::fmt::Display for EchoRepeatedResponse { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::protobuf::text_format::fmt(self, f) + } +} + +impl ::protobuf::reflect::ProtobufValue for EchoRepeatedResponse { + type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage; +} + static file_descriptor_proto_data: &'static [u8] = b"\ \n\ntest.proto\x12\x04test\x1a\x17validate/validate.proto\"*\n\x0bEchoRe\ quest\x12\x1b\n\x04echo\x18\x01\x20\x01(\tR\x04echoB\x07\xfaB\x04r\x02\ \x10\x01\"\"\n\x0cEchoResponse\x12\x12\n\x04echo\x18\x01\x20\x01(\tR\x04\ - echo25\n\x04Test\x12-\n\x04Echo\x12\x11.test.EchoRequest\x1a\x12.test.Ec\ - hoResponseb\x06proto3\ + echo\".\n\x14EchoRepeatedResponse\x12\x16\n\x06values\x18\x01\x20\x03(\t\ + R\x06values2t\n\x04Test\x12-\n\x04Echo\x12\x11.test.EchoRequest\x1a\x12.\ + test.EchoResponse\x12=\n\x0cEchoRepeated\x12\x11.test.EchoRequest\x1a\ + \x1a.test.EchoRepeatedResponseb\x06proto3\ "; /// `FileDescriptorProto` object which was a source for this generated file @@ -292,9 +416,10 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor { let generated_file_descriptor = generated_file_descriptor_lazy.get(|| { let mut deps = ::std::vec::Vec::with_capacity(1); deps.push(super::validate::file_descriptor().clone()); - let mut messages = ::std::vec::Vec::with_capacity(2); + let mut messages = ::std::vec::Vec::with_capacity(3); messages.push(EchoRequest::generated_message_descriptor_data()); messages.push(EchoResponse::generated_message_descriptor_data()); + messages.push(EchoRepeatedResponse::generated_message_descriptor_data()); let mut enums = ::std::vec::Vec::with_capacity(0); ::protobuf::reflect::GeneratedFileDescriptor::new_generated( file_descriptor_proto(), diff --git a/bd-grpc/src/grpc_test.rs b/bd-grpc/src/grpc_test.rs index a77f98d9..42c043fd 100644 --- a/bd-grpc/src/grpc_test.rs +++ b/bd-grpc/src/grpc_test.rs @@ -8,7 +8,7 @@ use crate::client::{AddressHelper, Client}; use crate::compression::{Compression, ConnectSafeCompressionLayer}; use crate::connect_protocol::ConnectProtocolType; -use crate::generated::proto::test::{EchoRequest, EchoResponse}; +use crate::generated::proto::test::{EchoRepeatedResponse, EchoRequest, EchoResponse}; use crate::stats::EndpointStats; use crate::{ CONNECT_PROTOCOL_VERSION, @@ -63,6 +63,10 @@ fn service_method() -> ServiceMethod { ServiceMethod::::new("Test", "Echo") } +fn repeated_service_method() -> ServiceMethod { + ServiceMethod::::new("Test", "EchoRepeated") +} + async fn make_unary_server( handler: Arc>, error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static, @@ -161,6 +165,10 @@ struct EchoHandler { do_sleep: bool, streaming_event_sender: Mutex>>, } + +#[derive(Default)] +struct EchoRepeatedHandler; + enum StreamingTestEvent { Message(EchoResponse), EndStreamOk, @@ -186,6 +194,18 @@ impl Handler for EchoHandler { } } +#[async_trait] +impl Handler for EchoRepeatedHandler { + async fn handle( + &self, + _headers: HeaderMap, + _extensions: Extensions, + _request: EchoRequest, + ) -> Result { + Ok(EchoRepeatedResponse::default()) + } +} + #[async_trait] impl ServerStreamingHandler for EchoHandler { async fn stream( @@ -1088,6 +1108,40 @@ async fn unary_json_transcoding_applies_response_mutator() { assert_eq!(body, serde_json::json!({ "echo": "mutated" })); } +#[tokio::test] +async fn unary_json_transcoding_omits_empty_repeated_fields() { + let router = make_unary_router_with_response_mutator( + &repeated_service_method(), + Arc::new(EchoRepeatedHandler), + None, + true, + None, + |_| {}, + ) + .layer(ConnectSafeCompressionLayer::new()); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let local_address = listener.local_addr().unwrap(); + let server = axum::serve(listener, router.into_make_service()); + tokio::spawn(async { server.await.unwrap() }); + + let client = reqwest::Client::builder().deflate(false).build().unwrap(); + let address = AddressHelper::new(format!("http://{local_address}")).unwrap(); + let response = client + .post(address.build(&repeated_service_method()).to_string()) + .header(CONTENT_TYPE, CONTENT_TYPE_JSON) + .body("{\"echo\":\"json_echo\"}") + .send() + .await + .unwrap(); + assert_eq!(response.status(), 200); + assert_eq!( + response.headers().get(CONTENT_TYPE).unwrap(), + CONTENT_TYPE_JSON + ); + let body: serde_json::Value = serde_json::from_slice(&response.bytes().await.unwrap()).unwrap(); + assert_eq!(body, serde_json::json!({})); +} + #[tokio::test] async fn server_streaming_json_request() { let local_address = make_server_streaming_server(Arc::new(EchoHandler::default()), |_| {}) diff --git a/bd-grpc/src/lib.rs b/bd-grpc/src/lib.rs index 59c4fdb1..dd2b764c 100644 --- a/bd-grpc/src/lib.rs +++ b/bd-grpc/src/lib.rs @@ -527,6 +527,7 @@ fn json_parse_options() -> ParseOptions { fn json_print_options() -> PrintOptions { PrintOptions { proto_field_name: true, + always_output_default_values: false, ..Default::default() } } diff --git a/bd-grpc/src/proto/test.proto b/bd-grpc/src/proto/test.proto index ae27cf87..58b60453 100644 --- a/bd-grpc/src/proto/test.proto +++ b/bd-grpc/src/proto/test.proto @@ -11,6 +11,11 @@ message EchoResponse { string echo = 1; } +message EchoRepeatedResponse { + repeated string values = 1; +} + service Test { rpc Echo(EchoRequest) returns (EchoResponse); + rpc EchoRepeated(EchoRequest) returns (EchoRepeatedResponse); } From 279ecd99d8a1af31387d0fa369c98a713d5c3b5f Mon Sep 17 00:00:00 2001 From: Snow Pettersen Date: Tue, 17 Mar 2026 16:28:22 -0700 Subject: [PATCH 4/4] no need to restate default option --- bd-grpc/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/bd-grpc/src/lib.rs b/bd-grpc/src/lib.rs index dd2b764c..59c4fdb1 100644 --- a/bd-grpc/src/lib.rs +++ b/bd-grpc/src/lib.rs @@ -527,7 +527,6 @@ fn json_parse_options() -> ParseOptions { fn json_print_options() -> PrintOptions { PrintOptions { proto_field_name: true, - always_output_default_values: false, ..Default::default() } }