Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 35 additions & 35 deletions bd-grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,42 @@ version = "1.0.0"
doctest = false

[dependencies]
anyhow.workspace = true
assert_matches.workspace = true
async-trait.workspace = true
axum.workspace = true
base64ct.workspace = true
bd-grpc-codec.path = "../bd-grpc-codec"
bd-log.path = "../bd-log"
bd-pgv.path = "../bd-pgv"
bd-server-stats.path = "../bd-server-stats"
bd-shutdown.path = "../bd-shutdown"
bd-stats-common.path = "../bd-stats-common"
bd-time.path = "../bd-time"
bd-workspace-hack.workspace = true
bytes.workspace = true
futures.workspace = true
http.workspace = true
http-body.workspace = true
http-body-util.workspace = true
hyper.workspace = true
hyper-util.workspace = true
log.workspace = true
mockall = { workspace = true, optional = true }
prometheus.workspace = true
protobuf.workspace = true
anyhow.workspace = true
assert_matches.workspace = true
async-trait.workspace = true
axum.workspace = true
base64ct.workspace = true
bd-grpc-codec.path = "../bd-grpc-codec"
bd-log.path = "../bd-log"
bd-pgv.path = "../bd-pgv"
bd-server-stats.path = "../bd-server-stats"
bd-shutdown.path = "../bd-shutdown"
bd-stats-common.path = "../bd-stats-common"
bd-time.path = "../bd-time"
bd-workspace-hack.workspace = true
bytes.workspace = true
futures.workspace = true
http.workspace = true
http-body.workspace = true
http-body-util.workspace = true
hyper.workspace = true
hyper-util.workspace = true
log.workspace = true
mockall = { workspace = true, optional = true }
prometheus.workspace = true
protobuf.workspace = true
protobuf-json-mapping.workspace = true
serde.workspace = true
serde_json.workspace = true
snap.workspace = true
thiserror.workspace = true
time.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tower.workspace = true
tower-http.workspace = true
unwrap-infallible.workspace = true
urlencoding.workspace = true
serde.workspace = true
serde_json.workspace = true
snap.workspace = true
thiserror.workspace = true
time.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tower.workspace = true
tower-http.workspace = true
unwrap-infallible.workspace = true
urlencoding.workspace = true

[dev-dependencies]
assert_matches.workspace = true
Expand Down
131 changes: 128 additions & 3 deletions bd-grpc/src/generated/proto/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,136 @@ impl ::protobuf::reflect::ProtobufValue for EchoResponse {
type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage<Self>;
}

// @@protoc_insertion_point(message:test.EchoRepeatedResponse)
#[derive(PartialEq,Clone,Default,Debug)]
pub struct EchoRepeatedResponse {
// message fields
// @@protoc_insertion_point(field:test.EchoRepeatedResponse.values)
pub values: ::std::vec::Vec<::std::string::String>,
// special fields
// @@protoc_insertion_point(special_field:test.EchoRepeatedResponse.special_fields)
pub special_fields: ::protobuf::SpecialFields,
}

impl<'a> ::std::default::Default for &'a EchoRepeatedResponse {
fn default() -> &'a EchoRepeatedResponse {
<EchoRepeatedResponse as ::protobuf::Message>::default_instance()
}
}

impl EchoRepeatedResponse {
pub fn new() -> EchoRepeatedResponse {
::std::default::Default::default()
}

fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData {
let mut fields = ::std::vec::Vec::with_capacity(1);
let mut oneofs = ::std::vec::Vec::with_capacity(0);
fields.push(::protobuf::reflect::rt::v2::make_vec_simpler_accessor::<_, _>(
"values",
|m: &EchoRepeatedResponse| { &m.values },
|m: &mut EchoRepeatedResponse| { &mut m.values },
));
::protobuf::reflect::GeneratedMessageDescriptorData::new_2::<EchoRepeatedResponse>(
"EchoRepeatedResponse",
fields,
oneofs,
)
}
}

impl ::protobuf::Message for EchoRepeatedResponse {
const NAME: &'static str = "EchoRepeatedResponse";

fn is_initialized(&self) -> bool {
true
}

fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> {
while let Some(tag) = is.read_raw_tag_or_eof()? {
match tag {
10 => {
self.values.push(is.read_string()?);
},
tag => {
::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?;
},
};
}
::std::result::Result::Ok(())
}

// Compute sizes of nested messages
#[allow(unused_variables)]
fn compute_size(&self) -> u64 {
let mut my_size = 0;
for value in &self.values {
my_size += ::protobuf::rt::string_size(1, &value);
};
my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields());
self.special_fields.cached_size().set(my_size as u32);
my_size
}

fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> {
for v in &self.values {
os.write_string(1, &v)?;
};
os.write_unknown_fields(self.special_fields.unknown_fields())?;
::std::result::Result::Ok(())
}

fn special_fields(&self) -> &::protobuf::SpecialFields {
&self.special_fields
}

fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields {
&mut self.special_fields
}

fn new() -> EchoRepeatedResponse {
EchoRepeatedResponse::new()
}

fn clear(&mut self) {
self.values.clear();
self.special_fields.clear();
}

fn default_instance() -> &'static EchoRepeatedResponse {
static instance: EchoRepeatedResponse = EchoRepeatedResponse {
values: ::std::vec::Vec::new(),
special_fields: ::protobuf::SpecialFields::new(),
};
&instance
}
}

impl ::protobuf::MessageFull for EchoRepeatedResponse {
fn descriptor() -> ::protobuf::reflect::MessageDescriptor {
static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new();
descriptor.get(|| file_descriptor().message_by_package_relative_name("EchoRepeatedResponse").unwrap()).clone()
}
}

impl ::std::fmt::Display for EchoRepeatedResponse {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
::protobuf::text_format::fmt(self, f)
}
}

impl ::protobuf::reflect::ProtobufValue for EchoRepeatedResponse {
type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage<Self>;
}

static file_descriptor_proto_data: &'static [u8] = b"\
\n\ntest.proto\x12\x04test\x1a\x17validate/validate.proto\"*\n\x0bEchoRe\
quest\x12\x1b\n\x04echo\x18\x01\x20\x01(\tR\x04echoB\x07\xfaB\x04r\x02\
\x10\x01\"\"\n\x0cEchoResponse\x12\x12\n\x04echo\x18\x01\x20\x01(\tR\x04\
echo25\n\x04Test\x12-\n\x04Echo\x12\x11.test.EchoRequest\x1a\x12.test.Ec\
hoResponseb\x06proto3\
echo\".\n\x14EchoRepeatedResponse\x12\x16\n\x06values\x18\x01\x20\x03(\t\
R\x06values2t\n\x04Test\x12-\n\x04Echo\x12\x11.test.EchoRequest\x1a\x12.\
test.EchoResponse\x12=\n\x0cEchoRepeated\x12\x11.test.EchoRequest\x1a\
\x1a.test.EchoRepeatedResponseb\x06proto3\
";

/// `FileDescriptorProto` object which was a source for this generated file
Expand All @@ -292,9 +416,10 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor {
let generated_file_descriptor = generated_file_descriptor_lazy.get(|| {
let mut deps = ::std::vec::Vec::with_capacity(1);
deps.push(super::validate::file_descriptor().clone());
let mut messages = ::std::vec::Vec::with_capacity(2);
let mut messages = ::std::vec::Vec::with_capacity(3);
messages.push(EchoRequest::generated_message_descriptor_data());
messages.push(EchoResponse::generated_message_descriptor_data());
messages.push(EchoRepeatedResponse::generated_message_descriptor_data());
let mut enums = ::std::vec::Vec::with_capacity(0);
::protobuf::reflect::GeneratedFileDescriptor::new_generated(
file_descriptor_proto(),
Expand Down
102 changes: 98 additions & 4 deletions bd-grpc/src/grpc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use crate::client::{AddressHelper, Client};
use crate::compression::{Compression, ConnectSafeCompressionLayer};
use crate::connect_protocol::ConnectProtocolType;
use crate::generated::proto::test::{EchoRequest, EchoResponse};
use crate::generated::proto::test::{EchoRepeatedResponse, EchoRequest, EchoResponse};
use crate::stats::EndpointStats;
use crate::{
CONNECT_PROTOCOL_VERSION,
Expand All @@ -25,7 +25,7 @@ use crate::{
StreamingApi,
StreamingApiSender,
make_server_streaming_router,
make_unary_router,
make_unary_router_with_response_mutator,
new_grpc_response,
};
use assert_matches::assert_matches;
Expand Down Expand Up @@ -63,17 +63,31 @@ fn service_method() -> ServiceMethod<EchoRequest, EchoResponse> {
ServiceMethod::<EchoRequest, EchoResponse>::new("Test", "Echo")
}

fn repeated_service_method() -> ServiceMethod<EchoRequest, EchoRepeatedResponse> {
ServiceMethod::<EchoRequest, EchoRepeatedResponse>::new("Test", "EchoRepeated")
}

async fn make_unary_server(
handler: Arc<dyn Handler<EchoRequest, EchoResponse>>,
error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static,
endpoint_stats: Option<&EndpointStats>,
) -> SocketAddr {
let router = make_unary_router(
make_unary_server_with_response_mutator(handler, error_handler, endpoint_stats, None).await
}

async fn make_unary_server_with_response_mutator(
handler: Arc<dyn Handler<EchoRequest, EchoResponse>>,
error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static,
endpoint_stats: Option<&EndpointStats>,
response_mutator: Option<crate::UnaryResponseMutator<EchoResponse>>,
) -> SocketAddr {
let router = make_unary_router_with_response_mutator(
&service_method(),
handler,
error_handler,
endpoint_stats,
true,
response_mutator,
error_handler,
)
.layer(ConnectSafeCompressionLayer::new());
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
Expand Down Expand Up @@ -151,6 +165,10 @@ struct EchoHandler {
do_sleep: bool,
streaming_event_sender: Mutex<Option<mpsc::Receiver<StreamingTestEvent>>>,
}

#[derive(Default)]
struct EchoRepeatedHandler;

enum StreamingTestEvent {
Message(EchoResponse),
EndStreamOk,
Expand All @@ -176,6 +194,18 @@ impl Handler<EchoRequest, EchoResponse> for EchoHandler {
}
}

#[async_trait]
impl Handler<EchoRequest, EchoRepeatedResponse> for EchoRepeatedHandler {
async fn handle(
&self,
_headers: HeaderMap,
_extensions: Extensions,
_request: EchoRequest,
) -> Result<EchoRepeatedResponse> {
Ok(EchoRepeatedResponse::default())
}
}

#[async_trait]
impl ServerStreamingHandler<EchoResponse, EchoRequest> for EchoHandler {
async fn stream(
Expand Down Expand Up @@ -1048,6 +1078,70 @@ async fn unary_json_transcoding() {
assert_eq!(body, serde_json::json!({ "echo": "json_echo" }));
}

#[tokio::test]
async fn unary_json_transcoding_applies_response_mutator() {
let local_address = make_unary_server_with_response_mutator(
Arc::new(EchoHandler::default()),
|_| {},
None,
Some(Arc::new(|response: &mut EchoResponse| {
response.echo = "mutated".to_string();
Ok(())
})),
)
.await;
let client = reqwest::Client::builder().deflate(false).build().unwrap();
let address = AddressHelper::new(format!("http://{local_address}")).unwrap();
let response = client
.post(address.build(&service_method()).to_string())
.header(CONTENT_TYPE, CONTENT_TYPE_JSON)
.body("{\"echo\":\"json_echo\"}")
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
assert_eq!(
response.headers().get(CONTENT_TYPE).unwrap(),
CONTENT_TYPE_JSON
);
let body: serde_json::Value = serde_json::from_slice(&response.bytes().await.unwrap()).unwrap();
assert_eq!(body, serde_json::json!({ "echo": "mutated" }));
}

#[tokio::test]
async fn unary_json_transcoding_omits_empty_repeated_fields() {
let router = make_unary_router_with_response_mutator(
&repeated_service_method(),
Arc::new(EchoRepeatedHandler),
None,
true,
None,
|_| {},
)
.layer(ConnectSafeCompressionLayer::new());
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_address = listener.local_addr().unwrap();
let server = axum::serve(listener, router.into_make_service());
tokio::spawn(async { server.await.unwrap() });

let client = reqwest::Client::builder().deflate(false).build().unwrap();
let address = AddressHelper::new(format!("http://{local_address}")).unwrap();
let response = client
.post(address.build(&repeated_service_method()).to_string())
.header(CONTENT_TYPE, CONTENT_TYPE_JSON)
.body("{\"echo\":\"json_echo\"}")
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
assert_eq!(
response.headers().get(CONTENT_TYPE).unwrap(),
CONTENT_TYPE_JSON
);
let body: serde_json::Value = serde_json::from_slice(&response.bytes().await.unwrap()).unwrap();
assert_eq!(body, serde_json::json!({}));
}

#[tokio::test]
async fn server_streaming_json_request() {
let local_address = make_server_streaming_server(Arc::new(EchoHandler::default()), |_| {})
Expand Down
Loading
Loading