From e80434bbfa50b37bcc0bb36f5bc5fb23777d370a Mon Sep 17 00:00:00 2001 From: Saurav Date: Thu, 9 Apr 2026 10:13:22 +0000 Subject: [PATCH 1/2] Add interceptor API --- grpc/src/server/interceptor.rs | 258 +++++++++++++++++++++++++++++++++ grpc/src/server/mod.rs | 2 + 2 files changed, 260 insertions(+) create mode 100644 grpc/src/server/interceptor.rs diff --git a/grpc/src/server/interceptor.rs b/grpc/src/server/interceptor.rs new file mode 100644 index 000000000..f73b6dc56 --- /dev/null +++ b/grpc/src/server/interceptor.rs @@ -0,0 +1,258 @@ +/* + * + * Copyright 2026 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use crate::client::CallOptions; +use crate::core::{RequestHeaders, Trailers}; +use crate::server::{Handle, RecvStream, SendStream}; + +/// A trait which allows intercepting an incoming RPC call to a [`Handle`] implementation. +#[trait_variant::make(Send)] +pub trait Intercept: Sync + 'static { + /// Intercepts an incoming call. + /// + /// Implementations can wrap `tx` and `rx` before passing them to `next`. + async fn intercept( + &self, + headers: RequestHeaders, + options: CallOptions, + tx: &mut impl SendStream, + rx: impl RecvStream + 'static, + next: &impl Handle, + ) -> Trailers; +} + +/// Wraps a [`Handle`] and an [`Intercept`] and implements [`Handle`] for the combination. +pub struct Intercepted { + handle: H, + intercept: I, +} + +impl Intercepted { + /// Creates a new `Intercepted` wrapper combining a handle and an interceptor. + pub fn new(handle: H, intercept: I) -> Self { + Self { handle, intercept } + } +} + +impl Handle for Intercepted +where + H: Handle + 'static, + I: Intercept + 'static, +{ + async fn handle( + &self, + headers: RequestHeaders, + options: CallOptions, + tx: &mut impl SendStream, + rx: impl RecvStream + 'static, + ) -> Trailers { + self.intercept + .intercept(headers, options, tx, rx, &self.handle) + .await + } +} + +/// Implements methods for combining [`Handle`] implementations with [`Intercept`] interceptors. +pub trait HandleExt: Handle + Sized { + /// Wraps this [`Handle`] with the given [`Intercept`] interceptor. + fn with_interceptor(self, interceptor: I) -> Intercepted + where + I: Intercept, + { + Intercepted::new(self, interceptor) + } +} + +impl HandleExt for T {} + +#[cfg(test)] +mod test { + use super::*; + use crate::client::CallOptions; + use crate::core::RequestHeaders; + use crate::core::{RecvMessage, ServerResponseStreamItem}; + use crate::server::SendOptions; + use std::sync::Arc; + use tokio::sync::Mutex; + + struct MockSendStream; + impl SendStream for MockSendStream { + async fn send<'a>( + &mut self, + _item: ServerResponseStreamItem<'a>, + _options: SendOptions, + ) -> Result<(), ()> { + Ok(()) + } + } + + struct MockRecvStream; + impl RecvStream for MockRecvStream { + async fn next(&mut self, _msg: &mut dyn RecvMessage) -> Option> { + None + } + } + + struct MockHandler { + called: Arc>, + } + + impl Handle for MockHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + _tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + let mut called = self.called.lock().await; + *called = true; + Trailers::new(Ok(())) + } + } + + struct MockInterceptor { + called: Arc>, + } + + impl Intercept for MockInterceptor { + async fn intercept( + &self, + headers: RequestHeaders, + options: CallOptions, + tx: &mut impl SendStream, + rx: impl RecvStream + 'static, + next: &impl Handle, + ) -> Trailers { + let mut called = self.called.lock().await; + *called = true; + drop(called); + next.handle(headers, options, tx, rx).await + } + } + + #[tokio::test] + async fn test_simple_interceptor() { + let handler_called = Arc::new(Mutex::new(false)); + let interceptor_called = Arc::new(Mutex::new(false)); + + let handler = MockHandler { + called: handler_called.clone(), + }; + let interceptor = MockInterceptor { + called: interceptor_called.clone(), + }; + + let chain = handler.with_interceptor(interceptor); + + let mut tx = MockSendStream; + let rx = MockRecvStream; + + chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + + assert!(*interceptor_called.lock().await); + assert!(*handler_called.lock().await); + } + + #[tokio::test] + async fn test_interceptor_chaining_order() { + let order = Arc::new(Mutex::new(Vec::new())); + + struct OrderInterceptor { + id: usize, + order: Arc>>, + } + + impl Intercept for OrderInterceptor { + async fn intercept( + &self, + headers: RequestHeaders, + options: CallOptions, + tx: &mut impl SendStream, + rx: impl RecvStream + 'static, + next: &impl Handle, + ) -> Trailers { + let mut order = self.order.lock().await; + order.push(self.id); + drop(order); + next.handle(headers, options, tx, rx).await + } + } + + struct OrderHandler { + order: Arc>>, + } + + impl Handle for OrderHandler { + async fn handle( + &self, + _h: RequestHeaders, + _o: CallOptions, + _tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + let mut order = self.order.lock().await; + order.push(0); // 0 represents the handler + Trailers::new(Ok(())) + } + } + + let handler = OrderHandler { + order: order.clone(), + }; + let int1 = OrderInterceptor { + id: 1, + order: order.clone(), + }; + let int2 = OrderInterceptor { + id: 2, + order: order.clone(), + }; + + // This should run int1 first, then int2, then handler. + let chain = handler.with_interceptor(int2).with_interceptor(int1); + + let mut tx = MockSendStream; + let rx = MockRecvStream; + + chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + + let final_order = order.lock().await; + assert_eq!(*final_order, vec![1, 2, 0]); + } +} diff --git a/grpc/src/server/mod.rs b/grpc/src/server/mod.rs index 486bdafda..01056ee2d 100644 --- a/grpc/src/server/mod.rs +++ b/grpc/src/server/mod.rs @@ -32,6 +32,8 @@ use crate::core::ServerResponseStreamItem; use crate::core::Trailers; use tokio::sync::oneshot; +pub(crate) mod interceptor; + pub struct Server { handler: Option>, } From 986af3c5644e5e21bbd6f578583b81cb24304a5d Mon Sep 17 00:00:00 2001 From: Saurav Date: Thu, 9 Apr 2026 13:03:43 +0000 Subject: [PATCH 2/2] feat(grpc): implement protocol validation interceptor for server streams Enforces strict gRPC stream state transitions and sequence validation on the server side to prevent malformed responses and improper polling. Key Changes: - **Send Stream Validation (`ServerSendStreamValidator`)**: Tracks state transitions (`Init` -> `HeadersSent` -> `MessagesSent` -> `Done`) to guarantee headers are sent exactly once and always precede messages, rejecting invalid operations. - **Receive Stream Validation (`ServerRecvStreamValidator`)**: Prevents erroneous polling by returning stable errors once the receive stream reaches EOF or a terminal error state. - **Preemptive Error Interception (`StreamValidationInterceptor`)**: Wraps streams with channel-aware validators to immediately preempt handler execution via `tokio::select!` upon detecting any protocol violation or underlying transport error, automatically returning an `Internal` status. --- grpc/src/server/handler_validation.rs | 598 ++++++++++++++++++++++++++ grpc/src/server/mod.rs | 22 +- 2 files changed, 616 insertions(+), 4 deletions(-) create mode 100644 grpc/src/server/handler_validation.rs diff --git a/grpc/src/server/handler_validation.rs b/grpc/src/server/handler_validation.rs new file mode 100644 index 000000000..3f74f7260 --- /dev/null +++ b/grpc/src/server/handler_validation.rs @@ -0,0 +1,598 @@ +/* + * + * Copyright 2026 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use crate::client::CallOptions; +use crate::core::{RecvMessage, RequestHeaders, ServerResponseStreamItem, Trailers}; +use crate::server::interceptor::Intercept; +use crate::server::{Handle, RecvStream, SendOptions, SendStream}; +use crate::{StatusCodeError, StatusError}; +use tokio::sync::mpsc::channel; + +struct ServerSendStreamValidator { + inner: S, + state: SendStreamState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SendStreamState { + Init, + HeadersSent, + MessagesSent, + Done, +} + +impl ServerSendStreamValidator { + fn new(inner: S) -> Self { + Self { + inner, + state: SendStreamState::Init, + } + } +} + +impl SendStream for ServerSendStreamValidator { + async fn send<'a>( + &mut self, + item: ServerResponseStreamItem<'a>, + options: SendOptions, + ) -> Result<(), ()> { + if self.state == SendStreamState::Done { + // Protocol error: Attempted to send an item on a completed or failed stream. + return Err(()); + } + + let next_state = match &item { + ServerResponseStreamItem::Headers(_) => match self.state { + SendStreamState::Init => SendStreamState::HeadersSent, + _ => { + // Protocol error: Received multiple headers frames. + self.state = SendStreamState::Done; + return Err(()); + } + }, + ServerResponseStreamItem::Message(_) => match self.state { + SendStreamState::HeadersSent | SendStreamState::MessagesSent => { + SendStreamState::MessagesSent + } + _ => { + // Protocol error: Attempted to send a message before headers. + self.state = SendStreamState::Done; + return Err(()); + } + }, + }; + + let res = self.inner.send(item, options).await; + match res { + Ok(()) => self.state = next_state, + Err(_) => { + self.state = SendStreamState::Done; + } + } + res + } +} + +struct ServerRecvStreamValidator { + inner: R, + done: bool, +} + +impl ServerRecvStreamValidator { + fn new(inner: R) -> Self { + Self { inner, done: false } + } +} + +impl RecvStream for ServerRecvStreamValidator { + async fn next(&mut self, msg: &mut dyn RecvMessage) -> Option> { + if self.done { + // Protocol error: Attempted to receive a message after reaching a terminal state (EOF or error). + return Some(Err(())); + } + + let res = self.inner.next(msg).await; + match res { + Some(Ok(())) => Some(Ok(())), + None => { + self.done = true; + None + } + Some(Err(())) => { + self.done = true; + Some(Err(())) + } + } + } +} + +struct ChannelAwareSendStreamValidator { + inner: ServerSendStreamValidator, + error_tx: tokio::sync::mpsc::Sender<()>, +} + +impl ChannelAwareSendStreamValidator { + fn new(inner: S, error_tx: tokio::sync::mpsc::Sender<()>) -> Self { + Self { + inner: ServerSendStreamValidator::new(inner), + error_tx, + } + } + + fn report_error(&self) { + let _ = self.error_tx.try_send(()); + } +} + +impl SendStream for ChannelAwareSendStreamValidator { + async fn send<'a>( + &mut self, + item: ServerResponseStreamItem<'a>, + options: SendOptions, + ) -> Result<(), ()> { + let res = self.inner.send(item, options).await; + if res.is_err() { + self.report_error(); + } + res + } +} + +struct ChannelAwareRecvStreamValidator { + inner: ServerRecvStreamValidator, + error_tx: tokio::sync::mpsc::Sender<()>, +} + +impl ChannelAwareRecvStreamValidator { + fn new(inner: R, error_tx: tokio::sync::mpsc::Sender<()>) -> Self { + Self { + inner: ServerRecvStreamValidator::new(inner), + error_tx, + } + } + + fn report_error(&self) { + let _ = self.error_tx.try_send(()); + } +} + +impl RecvStream for ChannelAwareRecvStreamValidator { + async fn next(&mut self, msg: &mut dyn RecvMessage) -> Option> { + let res = self.inner.next(msg).await; + if let Some(Err(())) = res { + self.report_error(); + } + res + } +} + +pub struct StreamValidationInterceptor; + +impl Intercept for StreamValidationInterceptor { + async fn intercept( + &self, + headers: RequestHeaders, + options: CallOptions, + tx: &mut impl SendStream, + rx: impl RecvStream + 'static, + next: &impl Handle, + ) -> Trailers { + let (error_tx, mut error_rx) = channel::<()>(1); + let mut wrapped_tx = ChannelAwareSendStreamValidator::new(tx, error_tx.clone()); + let wrapped_rx = ChannelAwareRecvStreamValidator::new(rx, error_tx); + + tokio::select! { + res = next.handle(headers, options, &mut wrapped_tx, wrapped_rx) => { + if error_rx.try_recv().is_ok() { + Trailers::new(Err(StatusError::new(StatusCodeError::Internal, "Stream validation error"))) + } else { + res + } + } + _ = error_rx.recv() => { + Trailers::new(Err(StatusError::new(StatusCodeError::Internal, "Stream validation error"))) + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::StatusCodeError; + use crate::client::CallOptions; + use crate::core::{ + RecvMessage, RequestHeaders, ResponseHeaders, SendMessage, ServerResponseStreamItem, + Trailers, + }; + use crate::server::SendOptions; + use crate::server::interceptor::HandleExt; + use bytes::{Buf, Bytes}; + + impl SendMessage for () { + fn encode(&self) -> Result, String> { + Ok(Box::new(Bytes::new())) + } + } + + struct NopRecvMessage; + impl RecvMessage for NopRecvMessage { + fn decode(&mut self, _data: &mut dyn Buf) -> Result<(), String> { + Ok(()) + } + } + + struct MockSendStream; + impl SendStream for MockSendStream { + async fn send<'a>( + &mut self, + _item: ServerResponseStreamItem<'a>, + _options: SendOptions, + ) -> Result<(), ()> { + Ok(()) + } + } + + struct FailingMockSendStream; + impl SendStream for FailingMockSendStream { + async fn send<'a>( + &mut self, + _item: ServerResponseStreamItem<'a>, + _options: SendOptions, + ) -> Result<(), ()> { + Err(()) + } + } + + struct ConfigurableMockRecvStream { + items: Vec>>, + index: usize, + } + + impl ConfigurableMockRecvStream { + fn new(items: Vec>>) -> Self { + Self { items, index: 0 } + } + } + + impl RecvStream for ConfigurableMockRecvStream { + async fn next(&mut self, _msg: &mut dyn RecvMessage) -> Option> { + if self.index < self.items.len() { + let res = self.items[self.index]; + self.index += 1; + res + } else { + None + } + } + } + + #[tokio::test] + async fn test_interceptor_successful_multi_message_streaming() { + struct StreamingHandler; + impl Handle for StreamingHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + tx: &mut impl SendStream, + mut rx: impl RecvStream + 'static, + ) -> Trailers { + let mut msg = NopRecvMessage; + let mut recv_count = 0; + while let Some(Ok(())) = rx.next(&mut msg).await { + recv_count += 1; + } + assert_eq!(recv_count, 3); + + tx.send( + ServerResponseStreamItem::Headers(ResponseHeaders::default()), + SendOptions::default(), + ) + .await + .unwrap(); + + tx.send( + ServerResponseStreamItem::Message(&()), + SendOptions::default(), + ) + .await + .unwrap(); + tx.send( + ServerResponseStreamItem::Message(&()), + SendOptions::default(), + ) + .await + .unwrap(); + + Trailers::new(Ok(())) + } + } + + let chain = StreamingHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + // Stream providing 3 valid messages followed by EOF + let rx = + ConfigurableMockRecvStream::new(vec![Some(Ok(())), Some(Ok(())), Some(Ok(())), None]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + assert!(trailers.status().is_ok()); + } + + #[tokio::test] + async fn test_interceptor_successful_trailers_only_response() { + struct TrailersOnlyHandler; + impl Handle for TrailersOnlyHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + _tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + // Send no headers or messages; return trailers directly. + Trailers::new(Ok(())) + } + } + + let chain = TrailersOnlyHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + let rx = ConfigurableMockRecvStream::new(vec![None]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + assert!(trailers.status().is_ok()); + } + + #[tokio::test] + async fn test_interceptor_sending_headers_twice() { + struct DoubleHeadersHandler; + impl Handle for DoubleHeadersHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + // First headers frame should succeed + let _ = tx + .send( + ServerResponseStreamItem::Headers(ResponseHeaders::default()), + SendOptions::default(), + ) + .await; + + // Second headers frame violates protocol sequence; pure error-breaking loop/termination + if tx + .send( + ServerResponseStreamItem::Headers(ResponseHeaders::default()), + SendOptions::default(), + ) + .await + .is_err() + { + return Trailers::new(Ok(())); + } + + Trailers::new(Ok(())) + } + } + + let chain = DoubleHeadersHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + let rx = ConfigurableMockRecvStream::new(vec![None]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + let err = trailers.status().as_ref().unwrap_err(); + assert_eq!(err.code(), StatusCodeError::Internal); + assert!(err.message().contains("Stream validation error")); + } + + #[tokio::test] + async fn test_interceptor_underlying_send_error() { + struct SendFailureHandler; + impl Handle for SendFailureHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + // Valid sequence, but underlying transport fails; loop terminates purely on error + loop { + if tx + .send( + ServerResponseStreamItem::Headers(ResponseHeaders::default()), + SendOptions::default(), + ) + .await + .is_err() + { + break; + } + } + Trailers::new(Ok(())) + } + } + + let chain = SendFailureHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = FailingMockSendStream; + let rx = ConfigurableMockRecvStream::new(vec![None]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + let err = trailers.status().as_ref().unwrap_err(); + assert_eq!(err.code(), StatusCodeError::Internal); + assert!(err.message().contains("Stream validation error")); + } + + #[tokio::test] + async fn test_interceptor_terminal_receive_error() { + struct ActiveRecvErrorHandler; + impl Handle for ActiveRecvErrorHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + _tx: &mut impl SendStream, + mut rx: impl RecvStream + 'static, + ) -> Trailers { + let mut msg = NopRecvMessage; + while let Some(Ok(())) = rx.next(&mut msg).await {} + Trailers::new(Ok(())) + } + } + + let chain = ActiveRecvErrorHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + // Stream encounters terminal receive error actively + let rx = ConfigurableMockRecvStream::new(vec![Some(Err(()))]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + let err = trailers.status().as_ref().unwrap_err(); + assert_eq!(err.code(), StatusCodeError::Internal); + assert!(err.message().contains("Stream validation error")); + } + + #[tokio::test] + async fn test_interceptor_poll_after_done() { + struct DoneRecvStream; + impl RecvStream for DoneRecvStream { + async fn next(&mut self, _msg: &mut dyn RecvMessage) -> Option> { + None + } + } + struct PollAfterDoneHandler; + impl Handle for PollAfterDoneHandler { + async fn handle( + &self, + _h: RequestHeaders, + _o: CallOptions, + _tx: &mut impl SendStream, + mut rx: impl RecvStream + 'static, + ) -> Trailers { + let mut msg = NopRecvMessage; + assert!(rx.next(&mut msg).await.is_none()); + // Polling after None triggers validation error and preemption + let res = rx.next(&mut msg).await; + assert!(matches!(res, Some(Err(())))); + Trailers::new(Ok(())) + } + } + let chain = PollAfterDoneHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + let rx = DoneRecvStream; + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + let err = trailers.status().as_ref().unwrap_err(); + assert_eq!(err.code(), StatusCodeError::Internal); + assert!(err.message().contains("Stream validation error")); + } + + #[tokio::test] + async fn test_interceptor_send_message_before_headers() { + struct MessageBeforeHeadersHandler; + impl Handle for MessageBeforeHeadersHandler { + async fn handle( + &self, + _headers: RequestHeaders, + _options: CallOptions, + tx: &mut impl SendStream, + _rx: impl RecvStream + 'static, + ) -> Trailers { + // Invalid sequence: message before headers; pure error-breaking termination + loop { + if tx + .send( + ServerResponseStreamItem::Message(&()), + SendOptions::default(), + ) + .await + .is_err() + { + break; + } + } + Trailers::new(Ok(())) + } + } + + let chain = MessageBeforeHeadersHandler.with_interceptor(StreamValidationInterceptor); + let mut tx = MockSendStream; + let rx = ConfigurableMockRecvStream::new(vec![None]); + + let trailers = chain + .handle( + RequestHeaders::default(), + CallOptions::default(), + &mut tx, + rx, + ) + .await; + let err = trailers.status().as_ref().unwrap_err(); + assert_eq!(err.code(), StatusCodeError::Internal); + assert!(err.message().contains("Stream validation error")); + } +} diff --git a/grpc/src/server/mod.rs b/grpc/src/server/mod.rs index 01056ee2d..3d4f92622 100644 --- a/grpc/src/server/mod.rs +++ b/grpc/src/server/mod.rs @@ -32,6 +32,7 @@ use crate::core::ServerResponseStreamItem; use crate::core::Trailers; use tokio::sync::oneshot; +pub(crate) mod handler_validation; pub(crate) mod interceptor; pub struct Server { @@ -120,10 +121,23 @@ impl DynHandle for T { &self, headers: RequestHeaders, options: CallOptions, - mut tx: &mut dyn DynSendStream, + tx: &mut dyn DynSendStream, rx: BoxedRecvStream, ) -> Trailers { - self.handle(headers, options, &mut tx, rx).await + let mut tx_wrapper = SendStreamRef(tx); + self.handle(headers, options, &mut tx_wrapper, rx).await + } +} + +struct SendStreamRef<'a>(&'a mut dyn DynSendStream); + +impl<'a> SendStream for SendStreamRef<'a> { + async fn send<'b>( + &mut self, + item: ServerResponseStreamItem<'b>, + options: SendOptions, + ) -> Result<(), ()> { + self.0.dyn_send(item, options).await } } @@ -188,13 +202,13 @@ impl DynSendStream for T { } } -impl<'b> SendStream for &mut (dyn DynSendStream + 'b) { +impl SendStream for &mut T { async fn send<'a>( &mut self, item: ServerResponseStreamItem<'a>, options: SendOptions, ) -> Result<(), ()> { - (**self).dyn_send(item, options).await + (**self).send(item, options).await } }