Skip to content

Commit 8fa98c6

Browse files
authored
implement more bd-pgv validations (#431)
Also updates bd-grpc to return an Err when not all pgv validations are implemented for the request type in the router when pgv validation is enabled
1 parent 73f3a70 commit 8fa98c6

8 files changed

Lines changed: 1134 additions & 300 deletions

File tree

bd-grpc/src/generated/proto/test.rs

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,14 +390,139 @@ impl ::protobuf::reflect::ProtobufValue for EchoRepeatedResponse {
390390
type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage<Self>;
391391
}
392392

393+
// @@protoc_insertion_point(message:test.UnsupportedRequest)
394+
#[derive(PartialEq,Clone,Default,Debug)]
395+
pub struct UnsupportedRequest {
396+
// message fields
397+
// @@protoc_insertion_point(field:test.UnsupportedRequest.echo)
398+
pub echo: ::std::string::String,
399+
// special fields
400+
// @@protoc_insertion_point(special_field:test.UnsupportedRequest.special_fields)
401+
pub special_fields: ::protobuf::SpecialFields,
402+
}
403+
404+
impl<'a> ::std::default::Default for &'a UnsupportedRequest {
405+
fn default() -> &'a UnsupportedRequest {
406+
<UnsupportedRequest as ::protobuf::Message>::default_instance()
407+
}
408+
}
409+
410+
impl UnsupportedRequest {
411+
pub fn new() -> UnsupportedRequest {
412+
::std::default::Default::default()
413+
}
414+
415+
fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData {
416+
let mut fields = ::std::vec::Vec::with_capacity(1);
417+
let mut oneofs = ::std::vec::Vec::with_capacity(0);
418+
fields.push(::protobuf::reflect::rt::v2::make_simpler_field_accessor::<_, _>(
419+
"echo",
420+
|m: &UnsupportedRequest| { &m.echo },
421+
|m: &mut UnsupportedRequest| { &mut m.echo },
422+
));
423+
::protobuf::reflect::GeneratedMessageDescriptorData::new_2::<UnsupportedRequest>(
424+
"UnsupportedRequest",
425+
fields,
426+
oneofs,
427+
)
428+
}
429+
}
430+
431+
impl ::protobuf::Message for UnsupportedRequest {
432+
const NAME: &'static str = "UnsupportedRequest";
433+
434+
fn is_initialized(&self) -> bool {
435+
true
436+
}
437+
438+
fn merge_from(&mut self, is: &mut ::protobuf::CodedInputStream<'_>) -> ::protobuf::Result<()> {
439+
while let Some(tag) = is.read_raw_tag_or_eof()? {
440+
match tag {
441+
10 => {
442+
self.echo = is.read_string()?;
443+
},
444+
tag => {
445+
::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?;
446+
},
447+
};
448+
}
449+
::std::result::Result::Ok(())
450+
}
451+
452+
// Compute sizes of nested messages
453+
#[allow(unused_variables)]
454+
fn compute_size(&self) -> u64 {
455+
let mut my_size = 0;
456+
if !self.echo.is_empty() {
457+
my_size += ::protobuf::rt::string_size(1, &self.echo);
458+
}
459+
my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields());
460+
self.special_fields.cached_size().set(my_size as u32);
461+
my_size
462+
}
463+
464+
fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> {
465+
if !self.echo.is_empty() {
466+
os.write_string(1, &self.echo)?;
467+
}
468+
os.write_unknown_fields(self.special_fields.unknown_fields())?;
469+
::std::result::Result::Ok(())
470+
}
471+
472+
fn special_fields(&self) -> &::protobuf::SpecialFields {
473+
&self.special_fields
474+
}
475+
476+
fn mut_special_fields(&mut self) -> &mut ::protobuf::SpecialFields {
477+
&mut self.special_fields
478+
}
479+
480+
fn new() -> UnsupportedRequest {
481+
UnsupportedRequest::new()
482+
}
483+
484+
fn clear(&mut self) {
485+
self.echo.clear();
486+
self.special_fields.clear();
487+
}
488+
489+
fn default_instance() -> &'static UnsupportedRequest {
490+
static instance: UnsupportedRequest = UnsupportedRequest {
491+
echo: ::std::string::String::new(),
492+
special_fields: ::protobuf::SpecialFields::new(),
493+
};
494+
&instance
495+
}
496+
}
497+
498+
impl ::protobuf::MessageFull for UnsupportedRequest {
499+
fn descriptor() -> ::protobuf::reflect::MessageDescriptor {
500+
static descriptor: ::protobuf::rt::Lazy<::protobuf::reflect::MessageDescriptor> = ::protobuf::rt::Lazy::new();
501+
descriptor.get(|| file_descriptor().message_by_package_relative_name("UnsupportedRequest").unwrap()).clone()
502+
}
503+
}
504+
505+
impl ::std::fmt::Display for UnsupportedRequest {
506+
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
507+
::protobuf::text_format::fmt(self, f)
508+
}
509+
}
510+
511+
impl ::protobuf::reflect::ProtobufValue for UnsupportedRequest {
512+
type RuntimeType = ::protobuf::reflect::rt::RuntimeTypeMessage<Self>;
513+
}
514+
393515
static file_descriptor_proto_data: &'static [u8] = b"\
394516
\n\ntest.proto\x12\x04test\x1a\x17validate/validate.proto\"*\n\x0bEchoRe\
395517
quest\x12\x1b\n\x04echo\x18\x01\x20\x01(\tR\x04echoB\x07\xfaB\x04r\x02\
396518
\x10\x01\"\"\n\x0cEchoResponse\x12\x12\n\x04echo\x18\x01\x20\x01(\tR\x04\
397519
echo\".\n\x14EchoRepeatedResponse\x12\x16\n\x06values\x18\x01\x20\x03(\t\
398-
R\x06values2t\n\x04Test\x12-\n\x04Echo\x12\x11.test.EchoRequest\x1a\x12.\
399-
test.EchoResponse\x12=\n\x0cEchoRepeated\x12\x11.test.EchoRequest\x1a\
400-
\x1a.test.EchoRepeatedResponseb\x06proto3\
520+
R\x06values\"1\n\x12UnsupportedRequest\x12\x1b\n\x04echo\x18\x01\x20\x01\
521+
(\tR\x04echoB\x07\xfaB\x04r\x02(\x012\xb5\x01\n\x04Test\x12-\n\x04Echo\
522+
\x12\x11.test.EchoRequest\x1a\x12.test.EchoResponse\x12=\n\x0cEchoRepeat\
523+
ed\x12\x11.test.EchoRequest\x1a\x1a.test.EchoRepeatedResponse\x12?\n\x0f\
524+
UnsupportedEcho\x12\x18.test.UnsupportedRequest\x1a\x12.test.EchoRespons\
525+
eb\x06proto3\
401526
";
402527

403528
/// `FileDescriptorProto` object which was a source for this generated file
@@ -416,10 +541,11 @@ pub fn file_descriptor() -> &'static ::protobuf::reflect::FileDescriptor {
416541
let generated_file_descriptor = generated_file_descriptor_lazy.get(|| {
417542
let mut deps = ::std::vec::Vec::with_capacity(1);
418543
deps.push(super::validate::file_descriptor().clone());
419-
let mut messages = ::std::vec::Vec::with_capacity(3);
544+
let mut messages = ::std::vec::Vec::with_capacity(4);
420545
messages.push(EchoRequest::generated_message_descriptor_data());
421546
messages.push(EchoResponse::generated_message_descriptor_data());
422547
messages.push(EchoRepeatedResponse::generated_message_descriptor_data());
548+
messages.push(UnsupportedRequest::generated_message_descriptor_data());
423549
let mut enums = ::std::vec::Vec::with_capacity(0);
424550
::protobuf::reflect::GeneratedFileDescriptor::new_generated(
425551
file_descriptor_proto(),

bd-grpc/src/grpc_test.rs

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
use crate::client::{AddressHelper, Client};
99
use crate::compression::{Compression, ConnectSafeCompressionLayer};
1010
use crate::connect_protocol::ConnectProtocolType;
11-
use crate::generated::proto::test::{EchoRepeatedResponse, EchoRequest, EchoResponse};
11+
use crate::generated::proto::test::{
12+
EchoRepeatedResponse,
13+
EchoRequest,
14+
EchoResponse,
15+
UnsupportedRequest,
16+
};
1217
use crate::stats::EndpointStats;
1318
use crate::{
1419
CONNECT_PROTOCOL_VERSION,
@@ -67,6 +72,10 @@ fn repeated_service_method() -> ServiceMethod<EchoRequest, EchoRepeatedResponse>
6772
ServiceMethod::<EchoRequest, EchoRepeatedResponse>::new("Test", "EchoRepeated")
6873
}
6974

75+
fn unsupported_service_method() -> ServiceMethod<UnsupportedRequest, EchoResponse> {
76+
ServiceMethod::<UnsupportedRequest, EchoResponse>::new("Test", "UnsupportedEcho")
77+
}
78+
7079
async fn make_unary_server(
7180
handler: Arc<dyn Handler<EchoRequest, EchoResponse>>,
7281
error_handler: impl Fn(&crate::Error) + Clone + Send + Sync + 'static,
@@ -89,6 +98,7 @@ async fn make_unary_server_with_response_mutator(
8998
response_mutator,
9099
error_handler,
91100
)
101+
.unwrap()
92102
.layer(ConnectSafeCompressionLayer::new());
93103
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
94104
let local_address = listener.local_addr().unwrap();
@@ -111,6 +121,7 @@ async fn make_server_streaming_server(
111121
true,
112122
None,
113123
)
124+
.unwrap()
114125
.layer(ConnectSafeCompressionLayer::new());
115126
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
116127
let local_address = listener.local_addr().unwrap();
@@ -169,6 +180,9 @@ struct EchoHandler {
169180
#[derive(Default)]
170181
struct EchoRepeatedHandler;
171182

183+
#[derive(Default)]
184+
struct UnsupportedHandler;
185+
172186
enum StreamingTestEvent {
173187
Message(EchoResponse),
174188
EndStreamOk,
@@ -206,6 +220,21 @@ impl Handler<EchoRequest, EchoRepeatedResponse> for EchoRepeatedHandler {
206220
}
207221
}
208222

223+
#[async_trait]
224+
impl Handler<UnsupportedRequest, EchoResponse> for UnsupportedHandler {
225+
async fn handle(
226+
&self,
227+
_headers: HeaderMap,
228+
_extensions: Extensions,
229+
request: UnsupportedRequest,
230+
) -> Result<EchoResponse> {
231+
Ok(EchoResponse {
232+
echo: request.echo,
233+
..Default::default()
234+
})
235+
}
236+
}
237+
209238
#[async_trait]
210239
impl ServerStreamingHandler<EchoResponse, EchoRequest> for EchoHandler {
211240
async fn stream(
@@ -247,6 +276,26 @@ impl ServerStreamingHandler<EchoResponse, EchoRequest> for EchoHandler {
247276
}
248277
}
249278

279+
#[async_trait]
280+
impl ServerStreamingHandler<EchoResponse, UnsupportedRequest> for UnsupportedHandler {
281+
async fn stream(
282+
&self,
283+
_headers: HeaderMap,
284+
_extensions: Extensions,
285+
request: UnsupportedRequest,
286+
sender: &mut StreamingApiSender<EchoResponse>,
287+
) -> Result<()> {
288+
sender
289+
.send(EchoResponse {
290+
echo: request.echo,
291+
..Default::default()
292+
})
293+
.await
294+
.unwrap();
295+
Ok(())
296+
}
297+
}
298+
250299
//
251300
// ErrorHandler
252301
//
@@ -996,6 +1045,53 @@ async fn connect_unary_error() {
9961045
);
9971046
}
9981047

1048+
#[test]
1049+
fn make_unary_router_rejects_unsupported_pgv_validation() {
1050+
assert_matches!(
1051+
make_unary_router_with_response_mutator(
1052+
&unsupported_service_method(),
1053+
Arc::new(UnsupportedHandler),
1054+
None,
1055+
true,
1056+
None,
1057+
|_| {},
1058+
),
1059+
Err(Error::ProtoValidation(bd_pgv::error::Error::ProtoValidation(message)))
1060+
if message == "not implemented: string rules max_bytes"
1061+
);
1062+
}
1063+
1064+
#[test]
1065+
fn make_unary_router_allows_unsupported_pgv_when_validation_disabled() {
1066+
assert!(
1067+
make_unary_router_with_response_mutator(
1068+
&unsupported_service_method(),
1069+
Arc::new(UnsupportedHandler),
1070+
None,
1071+
false,
1072+
None,
1073+
|_| {},
1074+
)
1075+
.is_ok()
1076+
);
1077+
}
1078+
1079+
#[test]
1080+
fn make_server_streaming_router_rejects_unsupported_pgv_validation() {
1081+
assert_matches!(
1082+
make_server_streaming_router(
1083+
&unsupported_service_method(),
1084+
Arc::new(UnsupportedHandler),
1085+
|_| {},
1086+
None,
1087+
true,
1088+
None,
1089+
),
1090+
Err(Error::ProtoValidation(bd_pgv::error::Error::ProtoValidation(message)))
1091+
if message == "not implemented: string rules max_bytes"
1092+
);
1093+
}
1094+
9991095
#[tokio::test]
10001096
async fn connect_unary() {
10011097
let local_address = make_unary_server(Arc::new(EchoHandler::default()), |_| {}, None).await;
@@ -1118,6 +1214,7 @@ async fn unary_json_transcoding_omits_empty_repeated_fields() {
11181214
None,
11191215
|_| {},
11201216
)
1217+
.unwrap()
11211218
.layer(ConnectSafeCompressionLayer::new());
11221219
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
11231220
let local_address = listener.local_addr().unwrap();

0 commit comments

Comments
 (0)