Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions bd-grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ log.workspace = true
mockall = { workspace = true, optional = true }
prometheus.workspace = true
protobuf.workspace = true
protobuf-json-mapping.workspace = true
serde.workspace = true
serde_json.workspace = true
snap.workspace = true
Expand Down
39 changes: 39 additions & 0 deletions bd-grpc/src/grpc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::stats::EndpointStats;
use crate::{
CONNECT_PROTOCOL_VERSION,
CONTENT_TYPE,
CONTENT_TYPE_JSON,
CONTENT_TYPE_PROTO,
Code,
Error,
Expand Down Expand Up @@ -1026,6 +1027,44 @@ async fn connect_unary() {
);
}

#[tokio::test]
async fn unary_json_transcoding() {
let local_address = make_unary_server(Arc::new(EchoHandler::default()), |_| {}, None).await;
let client = reqwest::Client::builder().deflate(false).build().unwrap();
let address = AddressHelper::new(format!("http://{local_address}")).unwrap();
let response = client
.post(address.build(&service_method()).to_string())
.header(CONTENT_TYPE, CONTENT_TYPE_JSON)
.body("{\"echo\":\"json_echo\"}")
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
assert_eq!(
response.headers().get(CONTENT_TYPE).unwrap(),
CONTENT_TYPE_JSON
);
let body: serde_json::Value = serde_json::from_slice(&response.bytes().await.unwrap()).unwrap();
assert_eq!(body, serde_json::json!({ "echo": "json_echo" }));
}

#[tokio::test]
async fn server_streaming_json_request() {
let local_address = make_server_streaming_server(Arc::new(EchoHandler::default()), |_| {})
.await
.0;
let client = reqwest::Client::builder().deflate(false).build().unwrap();
let address = AddressHelper::new(format!("http://{local_address}")).unwrap();
let response = client
.post(address.build(&service_method()).to_string())
.header(CONTENT_TYPE, CONTENT_TYPE_JSON)
.body("{\"echo\":\"json_echo\"}")
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
}

#[tokio::test]
async fn connect_unary_error_stats() {
let stats = Helper::new();
Expand Down
39 changes: 39 additions & 0 deletions bd-grpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use http::{Extensions, HeaderMap};
use http_body::Frame;
use http_body_util::{BodyExt, StreamBody};
use protobuf::{Message, MessageFull};
use protobuf_json_mapping::{ParseOptions, PrintOptions};
use service::ServiceMethod;
use stats::{BandwidthStatsSummary, EndpointStats, StreamStats};
use status::Status;
Expand Down Expand Up @@ -486,6 +487,11 @@ async fn decode_request<Message: MessageFull>(
let message = if matches!(connect_protocol_type, Some(ConnectProtocolType::Unary)) {
Message::parse_from_tokio_bytes(&body_bytes)
.map_err(|e| Status::new(Code::InvalidArgument, format!("Invalid request: {e}")))?
} else if is_json_request_content_type(&parts.headers) {
let body_str = std::str::from_utf8(&body_bytes)
.map_err(|e| Status::new(Code::InvalidArgument, format!("Invalid request: {e}")))?;
protobuf_json_mapping::parse_from_str_with_options(body_str, &json_parse_options())
.map_err(|e| Status::new(Code::InvalidArgument, format!("Invalid request: {e}")))?
} else {
let mut grpc_decoder =
Decoder::<Message>::new(finalize_decompression(&parts.headers), OptimizeFor::Cpu);
Expand All @@ -504,6 +510,25 @@ async fn decode_request<Message: MessageFull>(
Ok((parts.headers, parts.extensions, message))
}

fn is_json_request_content_type(headers: &HeaderMap) -> bool {
headers
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.split(';').next())
.is_some_and(|value| value.trim().eq_ignore_ascii_case(CONTENT_TYPE_JSON))
}

fn json_parse_options() -> ParseOptions {
ParseOptions::default()
}

fn json_print_options() -> PrintOptions {
PrintOptions {
proto_field_name: true,
..Default::default()
}
}

async fn unary_connect_handler<OutgoingType: MessageFull, IncomingType: MessageFull>(
headers: HeaderMap,
extensions: Extensions,
Expand All @@ -527,6 +552,8 @@ pub async fn unary_handler<OutgoingType: MessageFull, IncomingType: MessageFull>
) -> Result<Response> {
let (headers, extensions, message) =
decode_request::<OutgoingType>(request, validate_request, connect_protocol_type).await?;
let json_transcoding = connect_protocol_type.is_none() && is_json_request_content_type(&headers);

if matches!(connect_protocol_type, Some(ConnectProtocolType::Unary)) {
return unary_connect_handler(headers, extensions, message, handler).await;
}
Expand All @@ -538,6 +565,18 @@ pub async fn unary_handler<OutgoingType: MessageFull, IncomingType: MessageFull>

let response = handler.handle(headers, extensions, message).await?;

if json_transcoding {
let json =
protobuf_json_mapping::print_to_string_with_options(&response, &json_print_options())
.map_err(|e| Status::new(Code::Internal, format!("Failed to encode response: {e}")))?;
return Ok(
Response::builder()
.header(CONTENT_TYPE, CONTENT_TYPE_JSON)
.body(json.into())
.unwrap(),
);
}

let (tx, rx) = mpsc::channel::<std::result::Result<_, Infallible>>(2);

let mut encoder = Encoder::new(compression);
Expand Down
2 changes: 0 additions & 2 deletions bd-workspace-hack/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ winnow = { version = "0.7" }

[target.aarch64-apple-darwin.dependencies]
errno = { version = "0.3" }
getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-features = false, features = ["std", "sys_rng"] }
hyper-util = { version = "0.1", default-features = false, features = ["client-proxy", "client-proxy-system"] }
libc = { version = "0.2", default-features = false, features = ["extra_traits"] }
miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] }
Expand All @@ -110,7 +109,6 @@ tower-http = { version = "0.6", features = ["compression-deflate", "decompressio

[target.aarch64-apple-darwin.build-dependencies]
errno = { version = "0.3" }
getrandom-9fbad63c4bcf4a8f = { package = "getrandom", version = "0.4", default-features = false, features = ["std", "sys_rng"] }
hyper-util = { version = "0.1", default-features = false, features = ["client-proxy", "client-proxy-system"] }
libc = { version = "0.2", default-features = false, features = ["extra_traits"] }
miniz_oxide = { version = "0.8", default-features = false, features = ["simd", "with-alloc"] }
Expand Down
Loading