diff --git a/interop/Cargo.toml b/interop/Cargo.toml index 83dfb2043..cbe19ea8a 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -23,7 +23,7 @@ http-body-util = "0.1" prost = "0.14" tokio = {version = "1.0", features = ["rt-multi-thread", "time", "macros"]} tokio-stream = "0.1" -tonic = {path = "../tonic", features = ["tls-ring"]} +tonic = {path = "../tonic", features = ["tls-ring", "gzip"]} tonic-prost = {path = "../tonic-prost"} tower = "0.5" tracing-subscriber = {version = "0.3"} @@ -35,6 +35,8 @@ protobuf = { version = "4.34.0-release" } tonic-protobuf = {path = "../tonic-protobuf"} grpc-protobuf = {path = "../grpc-protobuf"} rustls = { version = "0.23", default-features = false, features = ["ring"] } +base64 = "0.22" + [build-dependencies] tonic-prost-build = {path = "../tonic-prost-build"} diff --git a/interop/src/bin/client.rs b/interop/src/bin/client.rs index 7c2124123..de120d0d0 100644 --- a/interop/src/bin/client.rs +++ b/interop/src/bin/client.rs @@ -16,11 +16,28 @@ use tonic::transport::Certificate; use tonic::transport::ClientTlsConfig; use tonic::transport::Endpoint; +#[allow(dead_code)] #[derive(Debug)] struct Opts { use_tls: bool, test_case: Vec, codec: Codec, + server_host: String, + server_port: u16, + server_host_override: Option, + use_test_ca: bool, + default_service_account: Option, + oauth_scope: Option, + service_account_key_file: Option, + service_config_json: Option, + additional_metadata: Option, + google_c2p_universe_domain: Option, + soak_iterations: usize, + soak_max_failures: usize, + soak_per_iteration_max_acceptable_latency_ms: u32, + soak_overall_timeout_seconds: Option, + soak_min_time_ms_between_rpcs: u32, + soak_num_threads: usize, } #[derive(Debug)] @@ -50,6 +67,35 @@ impl Opts { test_case.split(',').map(Testcase::from_str).collect() })?, codec: pargs.value_from_str("--codec")?, + server_host: pargs + .opt_value_from_str("--server_host")? + .unwrap_or_else(|| "localhost".to_string()), + server_port: pargs.opt_value_from_str("--server_port")?.unwrap_or(10000), + server_host_override: pargs.opt_value_from_str("--server_host_override")?, + use_test_ca: match pargs.opt_value_from_str::<_, bool>("--use_test_ca") { + Ok(Some(val)) => val, + Ok(None) => true, + Err(_) => true, + }, + default_service_account: pargs.opt_value_from_str("--default_service_account")?, + oauth_scope: pargs.opt_value_from_str("--oauth_scope")?, + service_account_key_file: pargs.opt_value_from_str("--service_account_key_file")?, + service_config_json: pargs.opt_value_from_str("--service_config_json")?, + additional_metadata: pargs.opt_value_from_str("--additional_metadata")?, + google_c2p_universe_domain: pargs.opt_value_from_str("--google_c2p_universe_domain")?, + soak_iterations: pargs.opt_value_from_str("--soak_iterations")?.unwrap_or(10), + soak_max_failures: pargs + .opt_value_from_str("--soak_max_failures")? + .unwrap_or(0), + soak_per_iteration_max_acceptable_latency_ms: pargs + .opt_value_from_str("--soak_per_iteration_max_acceptable_latency_ms")? + .unwrap_or(1000), + soak_overall_timeout_seconds: pargs + .opt_value_from_str("--soak_overall_timeout_seconds")?, + soak_min_time_ms_between_rpcs: pargs + .opt_value_from_str("--soak_min_time_ms_between_rpcs")? + .unwrap_or(0), + soak_num_threads: pargs.opt_value_from_str("--soak_num_threads")?.unwrap_or(1), }) } } @@ -62,53 +108,102 @@ async fn main() -> Result<(), Box> { let test_cases = matches.test_case; + let additional_metadata = if let Some(ref am) = matches.additional_metadata { + let mut map = tonic::metadata::MetadataMap::new(); + for pair in am.split(';') { + if pair.is_empty() { + continue; + } + if let Some(colon_idx) = pair.find(':') { + let (key_str, val_str) = pair.split_at(colon_idx); + let val_str = &val_str[1..]; // strip the leading colon + let key_str = key_str.trim(); + let val_str = val_str.trim(); + + if key_str.ends_with("-bin") { + use base64::Engine; + let decoded_val = base64::engine::general_purpose::STANDARD.decode(val_str)?; + let key = tonic::metadata::BinaryMetadataKey::from_str(key_str)?; + let value = tonic::metadata::MetadataValue::from_bytes(&decoded_val); + map.insert_bin(key, value); + } else { + let key = tonic::metadata::AsciiMetadataKey::from_str(key_str)?; + let value = tonic::metadata::MetadataValue::try_from(val_str)?; + map.insert(key, value); + } + } + } + Some(map) + } else { + None + }; + let (mut client, mut unimplemented_client): ( Box, Box, ) = match matches.codec { Codec::Prost => { let scheme = if matches.use_tls { "https" } else { "http" }; - let mut endpoint = Endpoint::try_from(format!("{scheme}://localhost:10000"))? + let host = &matches.server_host; + let port = matches.server_port; + let mut endpoint = Endpoint::try_from(format!("{scheme}://{host}:{port}"))? .timeout(Duration::from_secs(5)) .concurrency_limit(30); if matches.use_tls { - let pem = std::fs::read_to_string("interop/data/ca.pem")?; - let ca = Certificate::from_pem(pem); - endpoint = endpoint.tls_config( - ClientTlsConfig::new() - .ca_certificate(ca) - .domain_name("foo.test.google.fr"), - )?; + let mut tls_config = ClientTlsConfig::new(); + if matches.use_test_ca { + let pem = std::fs::read_to_string("interop/data/ca.pem")?; + let ca = Certificate::from_pem(pem); + tls_config = tls_config.ca_certificate(ca); + } + let domain_name = matches + .server_host_override + .as_deref() + .unwrap_or("foo.test.google.fr"); + tls_config = tls_config.domain_name(domain_name); + endpoint = endpoint.tls_config(tls_config)?; } let channel = endpoint.connect().await?; + let interceptor = interop::client::MetadataInterceptor { + metadata: additional_metadata.unwrap_or_default(), + }; ( - Box::new(client_prost::TestClient::new(channel.clone())), - Box::new(client_prost::UnimplementedClient::new(channel)), + Box::new(client_prost::TestClient::new( + tonic::codegen::InterceptedService::new(channel.clone(), interceptor.clone()), + )), + Box::new(client_prost::UnimplementedClient::new( + tonic::codegen::InterceptedService::new(channel, interceptor), + )), ) } Codec::Protobuf => { + let host = &matches.server_host; + let port = matches.server_port; + let target_uri = format!("dns:///{host}:{port}"); + let channel = if matches.use_tls { let _ = rustls::crypto::ring::default_provider().install_default(); - let pem = std::fs::read_to_string("interop/data/ca.pem")?; - let root_certs = RootCertificates::from_pem(pem); - let creds = RustlsChannelCredendials::new( - GrpcClientTlsConfig::new() - .with_root_certificates_provider(StaticProvider::new(root_certs)), - )?; - let channel_options = - ChannelOptions::default().override_authority("test.test.google.fr"); - grpc::client::Channel::new( - "dns:///localhost:10000", - Arc::new(creds), - channel_options, - ) + let mut tls_config = GrpcClientTlsConfig::new(); + if matches.use_test_ca { + let pem = std::fs::read_to_string("interop/data/ca.pem")?; + let root_certs = RootCertificates::from_pem(pem); + tls_config = + tls_config.with_root_certificates_provider(StaticProvider::new(root_certs)); + } + let creds = RustlsChannelCredendials::new(tls_config)?; + let domain_name = matches + .server_host_override + .as_deref() + .unwrap_or("test.test.google.fr"); + let channel_options = ChannelOptions::default().override_authority(domain_name); + grpc::client::Channel::new(&target_uri, Arc::new(creds), channel_options) } else { grpc::client::Channel::new( - "dns:///localhost:10000", + &target_uri, Arc::new(LocalChannelCredentials::new()), ChannelOptions::default(), ) @@ -129,6 +224,13 @@ async fn main() -> Result<(), Box> { match test_case { Testcase::EmptyUnary => client.empty_unary(&mut test_results).await, + Testcase::CacheableUnary => client.cacheable_unary(&mut test_results).await, + Testcase::ClientCompressedUnary => { + client.client_compressed_unary(&mut test_results).await + } + Testcase::ServerCompressedUnary => { + client.server_compressed_unary(&mut test_results).await + } Testcase::LargeUnary => client.large_unary(&mut test_results).await, Testcase::ClientStreaming => client.client_streaming(&mut test_results).await, Testcase::ServerStreaming => client.server_streaming(&mut test_results).await, @@ -147,6 +249,13 @@ async fn main() -> Result<(), Box> { .await } Testcase::CustomMetadata => client.custom_metadata(&mut test_results).await, + Testcase::CancelAfterBegin => client.cancel_after_begin(&mut test_results).await, + Testcase::CancelAfterFirstResponse => { + client.cancel_after_first_response(&mut test_results).await + } + Testcase::TimeoutOnSleepingServer => { + client.timeout_on_sleeping_server(&mut test_results).await + } _ => unimplemented!(), } diff --git a/interop/src/bin/server.rs b/interop/src/bin/server.rs index b0a5d8e69..52db29524 100644 --- a/interop/src/bin/server.rs +++ b/interop/src/bin/server.rs @@ -7,6 +7,8 @@ use tonic::transport::{Identity, ServerTlsConfig}; struct Opts { use_tls: bool, codec: Codec, + port: u16, + address_type: AddressType, } #[derive(Debug)] @@ -27,12 +29,36 @@ impl FromStr for Codec { } } +#[derive(Debug, Clone, Copy)] +enum AddressType { + Ipv4, + Ipv6, + Ipv4Ipv6, +} + +impl FromStr for AddressType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_uppercase().as_str() { + "IPV4" => Ok(AddressType::Ipv4), + "IPV6" => Ok(AddressType::Ipv6), + "IPV4_IPV6" => Ok(AddressType::Ipv4Ipv6), + _ => Err(format!("Invalid address type: {}", s)), + } + } +} + impl Opts { fn parse() -> Result { let mut pargs = pico_args::Arguments::from_env(); Ok(Self { use_tls: pargs.contains("--use_tls"), codec: pargs.value_from_str("--codec")?, + port: pargs.opt_value_from_str("--port")?.unwrap_or(10000), + address_type: pargs + .opt_value_from_str("--address_type")? + .unwrap_or(AddressType::Ipv4Ipv6), }) } } @@ -43,7 +69,12 @@ async fn main() -> std::result::Result<(), Box> { let matches = Opts::parse()?; - let addr = "127.0.0.1:10000".parse().unwrap(); + let host = match matches.address_type { + AddressType::Ipv4 => "127.0.0.1", + AddressType::Ipv6 => "[::1]", + AddressType::Ipv4Ipv6 => "[::]", + }; + let addr = format!("{host}:{}", matches.port).parse().unwrap(); let mut builder = Server::builder(); @@ -58,7 +89,9 @@ async fn main() -> std::result::Result<(), Box> { match matches.codec { Codec::Prost => { let test_service = - server_prost::TestServiceServer::new(server_prost::TestService::default()); + server_prost::TestServiceServer::new(server_prost::TestService::default()) + .accept_compressed(tonic::codec::CompressionEncoding::Gzip) + .send_compressed(tonic::codec::CompressionEncoding::Gzip); let unimplemented_service = server_prost::UnimplementedServiceServer::new( server_prost::UnimplementedService::default(), ); diff --git a/interop/src/client.rs b/interop/src/client.rs index 1e448d652..efac6b4d6 100644 --- a/interop/src/client.rs +++ b/interop/src/client.rs @@ -1,6 +1,30 @@ use crate::TestAssertion; use tonic::async_trait; +#[derive(Clone)] +pub struct MetadataInterceptor { + pub metadata: tonic::metadata::MetadataMap, +} + +impl tonic::service::Interceptor for MetadataInterceptor { + fn call( + &mut self, + mut request: tonic::Request<()>, + ) -> Result, tonic::Status> { + for key_and_val in self.metadata.iter() { + match key_and_val { + tonic::metadata::KeyAndValueRef::Ascii(key, val) => { + request.metadata_mut().insert(key.clone(), val.clone()); + } + tonic::metadata::KeyAndValueRef::Binary(key, val) => { + request.metadata_mut().insert_bin(key.clone(), val.clone()); + } + } + } + Ok(request) + } +} + #[async_trait] pub trait InteropTest: Send { async fn empty_unary(&mut self, assertions: &mut Vec); @@ -22,6 +46,18 @@ pub trait InteropTest: Send { async fn unimplemented_method(&mut self, assertions: &mut Vec); async fn custom_metadata(&mut self, assertions: &mut Vec); + + async fn cacheable_unary(&mut self, assertions: &mut Vec); + + async fn client_compressed_unary(&mut self, assertions: &mut Vec); + + async fn server_compressed_unary(&mut self, assertions: &mut Vec); + + async fn cancel_after_begin(&mut self, assertions: &mut Vec); + + async fn cancel_after_first_response(&mut self, assertions: &mut Vec); + + async fn timeout_on_sleeping_server(&mut self, assertions: &mut Vec); } #[async_trait] diff --git a/interop/src/client_prost.rs b/interop/src/client_prost.rs index 6790513c1..cfe8db990 100644 --- a/interop/src/client_prost.rs +++ b/interop/src/client_prost.rs @@ -9,8 +9,12 @@ use tonic::async_trait; use tonic::transport::Channel; use tonic::{Code, Request, Response, Status, metadata::MetadataValue}; -pub type TestClient = TestServiceClient; -pub type UnimplementedClient = UnimplementedServiceClient; +pub type TestClient = TestServiceClient< + tonic::codegen::InterceptedService, +>; +pub type UnimplementedClient = UnimplementedServiceClient< + tonic::codegen::InterceptedService, +>; const LARGE_REQ_SIZE: usize = 271_828; const LARGE_RSP_SIZE: i32 = 314_159; @@ -384,6 +388,312 @@ impl InteropTest for TestClient { format!("result={:?}", trailers.get_bin(key1)) )); } + + async fn cacheable_unary(&mut self, assertions: &mut Vec) { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + .to_string(); + let payload = Payload { + body: timestamp.into_bytes(), + ..Default::default() + }; + let req = SimpleRequest { + response_type: PayloadType::Compressable as i32, + payload: Some(payload), + ..Default::default() + }; + + let mut req1 = Request::new(req.clone()); + req1.metadata_mut() + .insert("x-user-ip", "1.2.3.4".parse().unwrap()); + + let result1 = self.cacheable_unary_call(req1).await; + + assertions.push(test_assert!( + "first call must be successful", + result1.is_ok(), + format!("result={:?}", result1) + )); + + let mut req2 = Request::new(req); + req2.metadata_mut() + .insert("x-user-ip", "1.2.3.4".parse().unwrap()); + let result2 = self.cacheable_unary_call(req2).await; + + assertions.push(test_assert!( + "second call must be successful", + result2.is_ok(), + format!("result={:?}", result2) + )); + + if let (Ok(res1), Ok(res2)) = (result1, result2) { + let body1 = res1.into_inner(); + let body2 = res2.into_inner(); + assertions.push(test_assert!( + "payload body of both responses is the same", + body1 == body2, + format!("body1={:?}, body2={:?}", body1, body2) + )); + } + } + + async fn client_compressed_unary(&mut self, assertions: &mut Vec) { + // 1. Probe + let req = SimpleRequest { + expect_compressed: Some(crate::pb::BoolValue { value: true }), + response_size: LARGE_RSP_SIZE, + payload: Some(crate::client_payload(LARGE_REQ_SIZE)), + ..Default::default() + }; + let result = self.unary_call(Request::new(req.clone())).await; + assertions.push(test_assert!( + "First call failed with INVALID_ARGUMENT status", + match &result { + Err(status) => status.code() == Code::InvalidArgument, + _ => false, + }, + format!("result={:?}", result) + )); + + // 2. Compressed + let mut compressed_client = self + .clone() + .send_compressed(tonic::codec::CompressionEncoding::Gzip); + let result = compressed_client + .unary_call(Request::new(req.clone())) + .await; + assertions.push(test_assert!( + "Second call (compressed) must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + if let Ok(response) = result { + let body = response.into_inner(); + assertions.push(test_assert!( + "response payload body is 314159 bytes in size", + body.payload.as_ref().map_or(0, |p| p.body.len()) == LARGE_RSP_SIZE as usize, + format!( + "body.payload.len={:?}", + body.payload.as_ref().map(|p| p.body.len()) + ) + )); + } + + // 3. Uncompressed + let req = SimpleRequest { + expect_compressed: Some(crate::pb::BoolValue { value: false }), + response_size: LARGE_RSP_SIZE, + payload: Some(crate::client_payload(LARGE_REQ_SIZE)), + ..Default::default() + }; + let result = self.unary_call(Request::new(req)).await; + assertions.push(test_assert!( + "Third call (uncompressed) must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + if let Ok(response) = result { + let body = response.into_inner(); + assertions.push(test_assert!( + "response payload body is 314159 bytes in size", + body.payload.as_ref().map_or(0, |p| p.body.len()) == LARGE_RSP_SIZE as usize, + format!( + "body.payload.len={:?}", + body.payload.as_ref().map(|p| p.body.len()) + ) + )); + } + } + + async fn server_compressed_unary(&mut self, assertions: &mut Vec) { + // 1. Request compressed response + let req = SimpleRequest { + response_compressed: Some(crate::pb::BoolValue { value: true }), + response_size: LARGE_RSP_SIZE, + payload: Some(crate::client_payload(LARGE_REQ_SIZE)), + ..Default::default() + }; + + let mut client = self + .clone() + .accept_compressed(tonic::codec::CompressionEncoding::Gzip); + + let result = client.unary_call(Request::new(req.clone())).await; + + assertions.push(test_assert!( + "Call with response_compressed=true must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + assertions.push(test_assert!( + "Response must have grpc-encoding: gzip", + response.metadata().get("grpc-encoding") + == Some(&tonic::metadata::MetadataValue::from_static("gzip")), + format!("metadata={:?}", response.metadata()) + )); + let body = response.into_inner(); + assertions.push(test_assert!( + "response payload body is 314159 bytes in size", + body.payload.as_ref().map_or(0, |p| p.body.len()) == LARGE_RSP_SIZE as usize, + format!( + "body.payload.len={:?}", + body.payload.as_ref().map(|p| p.body.len()) + ) + )); + } + + // 2. Request uncompressed response + let req = SimpleRequest { + response_compressed: Some(crate::pb::BoolValue { value: false }), + response_size: LARGE_RSP_SIZE, + payload: Some(crate::client_payload(LARGE_REQ_SIZE)), + ..Default::default() + }; + + let result = client.unary_call(Request::new(req)).await; + + assertions.push(test_assert!( + "Call with response_compressed=false must be successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let body = response.into_inner(); + assertions.push(test_assert!( + "response payload body is 314159 bytes in size", + body.payload.as_ref().map_or(0, |p| p.body.len()) == LARGE_RSP_SIZE as usize, + format!( + "body.payload.len={:?}", + body.payload.as_ref().map(|p| p.body.len()) + ) + )); + } + } + + async fn cancel_after_begin(&mut self, assertions: &mut Vec) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + + let mut client = self.clone(); + + let handle = + tokio::spawn(async move { client.streaming_input_call(Request::new(stream)).await }); + + handle.abort(); + + let result = handle.await; + + assertions.push(test_assert!( + "Call must be cancelled", + match &result { + Err(e) => e.is_cancelled(), + _ => false, + }, + format!("result={:?}", result) + )); + + // Suppress unused variable warning for tx + drop(tx); + } + + async fn cancel_after_first_response(&mut self, assertions: &mut Vec) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + tx.send(make_ping_pong_request(0)).unwrap(); + + let (signal_tx, mut signal_rx) = tokio::sync::mpsc::channel(1); + + let mut client = self.clone(); + + let handle = tokio::spawn(async move { + let response = client + .full_duplex_call(Request::new( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx), + )) + .await?; + let mut stream = response.into_inner(); + let first_msg = stream.next().await; + + // Notify outside + signal_tx.send(first_msg).await.unwrap(); + + // Wait forever to be cancelled + std::future::pending::<()>().await; + + Ok::<_, Status>(()) + }); + + // Wait for signal + let first_msg = signal_rx.recv().await; + + let success = matches!(&first_msg, Some(Some(Ok(_)))); + assertions.push(test_assert!( + "Received first response", + success, + format!("first_msg={:?}", first_msg) + )); + + // Cancel the task + handle.abort(); + + let result = handle.await; + + assertions.push(test_assert!( + "Call must be cancelled", + match &result { + Err(e) => e.is_cancelled(), + _ => false, + }, + format!("result={:?}", result) + )); + + drop(tx); + } + + async fn timeout_on_sleeping_server(&mut self, assertions: &mut Vec) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + + let mut req = make_ping_pong_request(0); + if let Some(param) = req.response_parameters.first_mut() { + param.interval_us = 100000; + } + tx.send(req).unwrap(); + + let mut client = self.clone(); + + let mut request = Request::new(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)); + request.set_timeout(std::time::Duration::from_millis(50)); + + let result = client.full_duplex_call(request).await; + + // For streaming calls, the timeout might occur during the stream poll, + // and Tonic might return it as a Status or it might be handled differently. + // But usually it returns Err(Status) with DeadlineExceeded. + + assertions.push(test_assert!( + "Initial call was successful", + result.is_ok(), + format!("result={:?}", result) + )); + + if let Ok(response) = result { + let mut stream = response.into_inner(); + let stream_result = + tokio::time::timeout(std::time::Duration::from_millis(50), stream.next()).await; + + assertions.push(test_assert!( + "Stream must time out (DEADLINE_EXCEEDED)", + stream_result.is_err(), + format!("stream_result={:?}", stream_result) + )); + } + + drop(tx); + } } #[async_trait] diff --git a/interop/src/client_protobuf.rs b/interop/src/client_protobuf.rs index e063be946..a86a19384 100644 --- a/interop/src/client_protobuf.rs +++ b/interop/src/client_protobuf.rs @@ -407,6 +407,100 @@ impl InteropTest for TestClient { format!("result={:?}", response_trailers.get_bin(key1)) )); } + + async fn cacheable_unary(&mut self, assertions: &mut Vec) { + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos() + .to_string(); + let req = proto!(SimpleRequest { + response_type: PayloadType::Compressable, + payload: proto!(Payload { + body: timestamp.into_bytes(), + }), + }); + + let mut md1 = MetadataMap::new(); + md1.insert("x-user-ip", "1.2.3.4".parse().unwrap()); + let attacher1 = AttachHeadersInterceptor::new(md1); + + let result1 = self + .cacheable_unary_call(req.clone()) + .with_interceptor(attacher1) + .await; + + assertions.push(test_assert!( + "first call must be successful", + result1.is_ok(), + format!("result={:?}", result1) + )); + + let mut md2 = MetadataMap::new(); + md2.insert("x-user-ip", "1.2.3.4".parse().unwrap()); + let attacher2 = AttachHeadersInterceptor::new(md2); + + let result2 = self + .cacheable_unary_call(req) + .with_interceptor(attacher2) + .await; + + assertions.push(test_assert!( + "second call must be successful", + result2.is_ok(), + format!("result={:?}", result2) + )); + + if let (Ok(res1), Ok(res2)) = (result1, result2) { + let body1 = res1.payload().body(); + let body2 = res2.payload().body(); + assertions.push(test_assert!( + "payload body of both responses is the same", + body1 == body2, + format!("body1={:?}, body2={:?}", body1, body2) + )); + } + } + + async fn client_compressed_unary(&mut self, assertions: &mut Vec) { + assertions.push(test_assert!( + "client_compressed_unary is implemented for protobuf client", + false, + "Not implemented".to_string() + )); + } + + async fn server_compressed_unary(&mut self, assertions: &mut Vec) { + assertions.push(test_assert!( + "server_compressed_unary is implemented for protobuf client", + false, + "Not implemented".to_string() + )); + } + + async fn cancel_after_begin(&mut self, assertions: &mut Vec) { + assertions.push(test_assert!( + "cancel_after_begin is implemented for protobuf client", + false, + "Not implemented".to_string() + )); + } + + async fn cancel_after_first_response(&mut self, assertions: &mut Vec) { + assertions.push(test_assert!( + "cancel_after_first_response is implemented for protobuf client", + false, + "Not implemented".to_string() + )); + } + + async fn timeout_on_sleeping_server(&mut self, assertions: &mut Vec) { + assertions.push(test_assert!( + "timeout_on_sleeping_server is implemented for protobuf client", + false, + "Not implemented".to_string() + )); + } } #[async_trait] diff --git a/interop/src/server_prost.rs b/interop/src/server_prost.rs index 6a8d12228..6a2215b14 100644 --- a/interop/src/server_prost.rs +++ b/interop/src/server_prost.rs @@ -29,8 +29,20 @@ impl pb::test_service_server::TestService for TestService { } async fn unary_call(&self, request: Request) -> Result { + let is_compressed = request.metadata().get("grpc-encoding") + == Some(&tonic::metadata::MetadataValue::from_static("gzip")); + let req = request.into_inner(); + if let Some(expect_compressed) = req.expect_compressed { + if expect_compressed.value && !is_compressed { + return Err(Status::new( + Code::InvalidArgument, + "Requested compression but message was not compressed", + )); + } + } + if let Some(echo_status) = req.response_status { let status = Status::new(Code::from_i32(echo_status.code), echo_status.message); return Err(status); @@ -51,11 +63,24 @@ impl pb::test_service_server::TestService for TestService { ..Default::default() }; - Ok(Response::new(res)) + let mut response = Response::new(res); + let compress = req.response_compressed.map_or(false, |v| v.value); + if !compress { + response.disable_compression(); + } + Ok(response) } - async fn cacheable_unary_call(&self, _: Request) -> Result { - unimplemented!() + async fn cacheable_unary_call( + &self, + request: Request, + ) -> Result { + let req = request.into_inner(); + let res = SimpleResponse { + payload: req.payload, + ..Default::default() + }; + Ok(Response::new(res)) } type StreamingOutputCallStream = Stream;