diff --git a/ginepro/src/balanced_channel.rs b/ginepro/src/balanced_channel.rs index a8d3dce..ffa1446 100644 --- a/ginepro/src/balanced_channel.rs +++ b/ginepro/src/balanced_channel.rs @@ -2,7 +2,7 @@ //! periodic service discovery. use crate::{ - service_probe::{GrpcServiceProbe, GrpcServiceProbeConfig}, + service_probe::{EndpointMiddleware, GrpcServiceProbe, GrpcServiceProbeConfig}, DnsResolver, LookupService, ServiceDefinition, }; use anyhow::Context as _; @@ -97,13 +97,14 @@ pub enum ResolutionStrategy { } /// Builder to configure and create a [`LoadBalancedChannel`]. -pub struct LoadBalancedChannelBuilder { +pub struct LoadBalancedChannelBuilder { service_definition: S, probe_interval: Option, resolution_strategy: ResolutionStrategy, timeout: Option, tls_config: Option, lookup_service: Pin>>>, + middleware: M, } impl LoadBalancedChannelBuilder @@ -117,7 +118,7 @@ where /// All the service endpoints of a [`ServiceDefinition`] will be /// constructed by resolving all ips from [`ServiceDefinition::hostname`], and /// using the portnumber [`ServiceDefinition::port`]. - pub fn new_with_service(service_definition: S) -> LoadBalancedChannelBuilder { + pub fn new_with_service(service_definition: S) -> Self { Self { service_definition, probe_interval: None, @@ -125,14 +126,22 @@ where tls_config: None, lookup_service: Box::pin(DnsResolver::from_system_config()), resolution_strategy: ResolutionStrategy::Lazy, + middleware: (), } } +} +impl LoadBalancedChannelBuilder +where + S: TryInto + 'static, + S::Error: Into> + Send + Sync, + M: EndpointMiddleware, +{ /// Set a custom [`LookupService`]. - pub fn lookup_service( + pub fn lookup_service( self, - lookup_service: T, - ) -> LoadBalancedChannelBuilder { + lookup_service: Lookup, + ) -> LoadBalancedChannelBuilder { LoadBalancedChannelBuilder { lookup_service: Box::pin(async { Ok(lookup_service) }), service_definition: self.service_definition, @@ -140,18 +149,13 @@ where tls_config: self.tls_config, timeout: self.timeout, resolution_strategy: self.resolution_strategy, + middleware: self.middleware, } } -} -impl LoadBalancedChannelBuilder -where - S: TryInto + 'static, - S::Error: Into> + Send + Sync, -{ /// Set the how often, the client should probe for changes to gRPC server endpoints. /// Default interval in seconds is 10. - pub fn dns_probe_interval(self, interval: Duration) -> LoadBalancedChannelBuilder { + pub fn dns_probe_interval(self, interval: Duration) -> Self { Self { probe_interval: Some(interval), ..self @@ -159,7 +163,7 @@ where } /// Set a timeout that will be applied to every new `Endpoint`. - pub fn timeout(self, timeout: Duration) -> LoadBalancedChannelBuilder { + pub fn timeout(self, timeout: Duration) -> Self { Self { timeout: Some(timeout), ..self @@ -175,10 +179,7 @@ where /// Instead, if [`ResolutionStrategy::Eager`] is set the domain name will be attempted resolved /// once before the [`LoadBalancedChannel`] is created, which ensures that the channel /// will have a non-empty of IPs on startup. If it fails the channel creation will also fail. - pub fn resolution_strategy( - self, - resolution_strategy: ResolutionStrategy, - ) -> LoadBalancedChannelBuilder { + pub fn resolution_strategy(self, resolution_strategy: ResolutionStrategy) -> Self { Self { resolution_strategy, ..self @@ -187,13 +188,29 @@ where /// Configure the channel to use tls. /// A `tls_config` MUST be specified to use the `HTTPS` scheme. - pub fn with_tls(self, tls_config: ClientTlsConfig) -> LoadBalancedChannelBuilder { + pub fn with_tls(self, tls_config: ClientTlsConfig) -> Self { Self { tls_config: Some(tls_config), ..self } } + /// Adds an endpoint middleware layer that lets you add custom configuration + pub fn with_endpoint_layer( + self, + layer: Layer, + ) -> LoadBalancedChannelBuilder { + LoadBalancedChannelBuilder { + lookup_service: self.lookup_service, + service_definition: self.service_definition, + probe_interval: self.probe_interval, + tls_config: self.tls_config, + timeout: self.timeout, + resolution_strategy: self.resolution_strategy, + middleware: (layer, self.middleware), + } + } + /// Construct a [`LoadBalancedChannel`] from the [`LoadBalancedChannelBuilder`] instance. pub async fn channel(self) -> Result { let (channel, sender) = Channel::balance_channel(GRPC_REPORT_ENDPOINTS_CHANNEL_SIZE); @@ -213,16 +230,15 @@ where .unwrap_or_else(|| Duration::from_secs(10)), }; - let tls_config = self.tls_config.map(|mut tls_config| { + let tls_config = self.tls_config.map(|tls_config| { // Since we resolve the hostname to an IP, which is not a valid DNS name, // we have to set the hostname explicitly on the tls config, // otherwise the IP will be set as the domain name and tls handshake will fail. - tls_config = tls_config.domain_name(config.service_definition.hostname()); - - tls_config + tls_config.domain_name(config.service_definition.hostname()) }); - let mut service_probe = GrpcServiceProbe::new_with_reporter(config, sender); + let mut service_probe = + GrpcServiceProbe::new_with_reporter(config, sender, self.middleware); if let Some(tls_config) = tls_config { service_probe = service_probe.with_tls(tls_config); diff --git a/ginepro/src/lib.rs b/ginepro/src/lib.rs index 6f07c88..3075b8e 100644 --- a/ginepro/src/lib.rs +++ b/ginepro/src/lib.rs @@ -121,6 +121,27 @@ //! } //! ``` //! +//! If needed, you can use the [`with_endpoint_layer`](LoadBalancedChannelBuilder::with_endpoint_layer) +//! method to add more configuration to the channel endpoints +//! +//! ```rust +//! #[tokio::main] +//! async fn main() { +//! use ginepro::LoadBalancedChannel; +//! use shared_proto::pb::tester_client::TesterClient; +//! use tonic::transport::Endpoint; +//! +//! // Create a load balanced channel with the default lookup implementation and a custom User-Agent. +//! let load_balanced_channel = LoadBalancedChannel::builder(("my.hostname", 5000)) +//! .with_endpoint_layer(|endpoint: Endpoint| endpoint.user_agent("my ginepro client").ok()) +//! .channel() +//! .await +//! .expect("failed to construct LoadBalancedChannel"); +//! +//! let tester_client = TesterClient::new(load_balanced_channel); +//! } +//! ``` +//! //! # Internals //! The tonic [`Channel`](tonic::transport::Channel) exposes the function //! [`balance_channel`](tonic::transport::Channel::balance_channel) which returnes a bounded channel through which @@ -137,3 +158,4 @@ pub use balanced_channel::*; pub use dns_resolver::*; pub use lookup_service::*; pub use service_definition::*; +pub use service_probe::EndpointMiddleware; diff --git a/ginepro/src/service_probe.rs b/ginepro/src/service_probe.rs index 7d056ec..b9664f3 100644 --- a/ginepro/src/service_probe.rs +++ b/ginepro/src/service_probe.rs @@ -13,6 +13,33 @@ pub enum ProbeError { ChangesetSenderClosed(#[source] anyhow::Error), } +/// A middleware to wrap an `Endpoint`. Useful for setting new endpoint configuration. +pub trait EndpointMiddleware: Send + Sync + 'static { + fn wrap(&self, endpoint: Endpoint) -> Option; +} + +impl EndpointMiddleware for (Head, Tail) +where + Head: EndpointMiddleware, + Tail: EndpointMiddleware, +{ + fn wrap(&self, endpoint: Endpoint) -> Option { + self.0.wrap(self.1.wrap(endpoint)?) + } +} + +impl EndpointMiddleware for () { + fn wrap(&self, endpoint: Endpoint) -> Option { + Some(endpoint) + } +} + +impl Option + Send + Sync + 'static> EndpointMiddleware for F { + fn wrap(&self, endpoint: Endpoint) -> Option { + (self)(endpoint) + } +} + /// [`GrpcServiceProbe`] looks up IP addresses associated with the configured `host_name` /// once every `probe_interval`. /// If a new IP address is discovered or an old one disappears it notifies the [`tonic`] gRPC client. @@ -27,9 +54,10 @@ pub enum ProbeError { /// and we have not instructed the removal of that server's address from the /// set of endpoints known to the tonic client. /// -pub struct GrpcServiceProbe +pub struct GrpcServiceProbe where Lookup: LookupService, + Middleware: EndpointMiddleware, { service_definition: ServiceDefinition, scheme: http::uri::Scheme, @@ -40,6 +68,7 @@ where endpoints: HashSet, endpoint_reporter: Sender>, tls_config: Option, + middleware: Middleware, } /// Config parameters to customize the behavior of `GrpcServiceProbe`. @@ -58,13 +87,14 @@ where pub endpoint_timeout: Option, } -impl GrpcServiceProbe { +impl GrpcServiceProbe { /// Construct `GrpcServiceProbe` with a `GrpcServiceProbeConfig` and /// the channel `endpoint_reporter` that will send endpoint changes. pub fn new_with_reporter( config: GrpcServiceProbeConfig, endpoint_reporter: Sender>, - ) -> GrpcServiceProbe { + middleware: Middleware, + ) -> Self { Self { service_definition: config.service_definition, dns_lookup: config.dns_lookup, @@ -74,18 +104,16 @@ impl GrpcServiceProbe { endpoint_reporter, scheme: http::uri::Scheme::HTTP, tls_config: None, + middleware, } } /// Enable tls for all endpoints. - pub fn with_tls(self, tls_config: ClientTlsConfig) -> GrpcServiceProbe { - Self { - tls_config: Some(tls_config), - scheme: http::uri::Scheme::HTTPS, - ..self - } + pub fn with_tls(mut self, tls_config: ClientTlsConfig) -> Self { + self.scheme = http::uri::Scheme::HTTPS; + self.tls_config = Some(tls_config); + self } - /// Start probing the provided `hostname` for IP address changes. /// The function will error if the receiving end of the tonic balance channel /// is closed, e.g, the client has been deconstructed. @@ -225,6 +253,6 @@ impl GrpcServiceProbe { endpoint = endpoint.timeout(*timeout); } - Some(endpoint) + self.middleware.wrap(endpoint) } } diff --git a/tests/tests/all/lookup.rs b/tests/tests/all/lookup.rs index dab8a7c..4ea2b00 100644 --- a/tests/tests/all/lookup.rs +++ b/tests/tests/all/lookup.rs @@ -34,6 +34,30 @@ impl Tester for TesterImpl { } } +#[derive(Clone)] +pub struct UserAgentTesterImpl; + +#[async_trait::async_trait] +impl Tester for UserAgentTesterImpl { + async fn test(&self, req: tonic::Request) -> Result, Status> { + let ua = req + .metadata() + .get("user-agent") + .ok_or_else(|| Status::new(tonic::Code::PermissionDenied, "no user agent supplied"))?; + + let ua = ua + .to_str() + .map_err(|_e| { + Status::new(tonic::Code::InvalidArgument, "non utf8 user agent supplied") + })? + .to_owned(); + + Ok(tonic::Response::new(Pong { + payload: Some(Payload::Raw(ua)), + })) + } +} + #[derive(Clone)] pub struct TestDnsResolver { pub ips: Arc>>, diff --git a/tests/tests/all/service_probe.rs b/tests/tests/all/service_probe.rs index 52378e7..4b8552e 100644 --- a/tests/tests/all/service_probe.rs +++ b/tests/tests/all/service_probe.rs @@ -1,5 +1,6 @@ use crate::lookup::TestDnsResolver; use crate::lookup::TesterImpl; +use crate::lookup::UserAgentTesterImpl; use ginepro::{LoadBalancedChannel, LoadBalancedChannelBuilder, LookupService, ServiceDefinition}; use shared_proto::pb::pong::Payload; use shared_proto::pb::tester_client::TesterClient; @@ -8,6 +9,7 @@ use std::sync::Arc; use std::{collections::HashSet, net::SocketAddr}; use std::{net::AddrParseError, time::Duration}; use tokio::sync::Mutex; +use tonic::transport::Endpoint; fn get_payload_raw(payload: Payload) -> String { match payload { @@ -248,15 +250,75 @@ async fn builder_and_resolve_shall_succeed_when_ips_are_returned() { } } - assert!( - LoadBalancedChannel::builder(ServiceDefinition::from_parts("test.com", 5000).unwrap(),) - .lookup_service(SucceedResolve) - .timeout(tokio::time::Duration::from_millis(500)) - .resolution_strategy(ginepro::ResolutionStrategy::Eager { - timeout: Duration::from_secs(20), - }) - .channel() + LoadBalancedChannel::builder(ServiceDefinition::from_parts("test.com", 5000).unwrap()) + .timeout(tokio::time::Duration::from_millis(500)) + .resolution_strategy(ginepro::ResolutionStrategy::Eager { + timeout: Duration::from_secs(20), + }) + .channel() + .await + .unwrap(); +} + +#[tokio::test] +async fn builder_with_middleware_layers() { + let uris = Arc::new(std::sync::Mutex::new(Vec::new())); + let uris2 = Arc::clone(&uris); + + let mut resolver = TestDnsResolver::default(); + let probe_interval = tokio::time::Duration::from_millis(3); + + let load_balanced_channel = LoadBalancedChannel::builder(("www.test.com", 5000)) + .lookup_service(resolver.clone()) + .dns_probe_interval(probe_interval) + .with_endpoint_layer(|endpoint: Endpoint| Some(endpoint.concurrency_limit(1))) + .with_endpoint_layer(move |endpoint: Endpoint| { + // record the uri so we can assert that all the layers are run + uris2.lock().unwrap().push(endpoint.uri().clone()); + Some(endpoint) + }) + .with_endpoint_layer(|endpoint: Endpoint| endpoint.user_agent("my ginepro client").ok()) + .channel() + .await + .unwrap(); + let mut client = TesterClient::new(load_balanced_channel); + + assert!(uris.lock().unwrap().is_empty()); // no URIs yet, no layers called + + // add a new server and check that the layers are run + { + resolver + .add_server_with_provided_impl("server2".to_string(), UserAgentTesterImpl) + .await; + + // Give time to the DNS probe to run + tokio::time::sleep(probe_interval * 3).await; + + assert_eq!(uris.lock().unwrap().len(), 1); // new URI registered, layers called + } + + // check that our endpoint actually has the user agent we configured + { + let res = client + .test(tonic::Request::new(Ping {})) .await - .is_ok() - ); + .expect("failed to call server"); + + assert!( + get_payload_raw(res.into_inner().payload.expect("no payload")) + .starts_with("my ginepro client") + ); + } + + // add a new server and check that the layers are run again + { + resolver + .add_server_with_provided_impl("server2".to_string(), UserAgentTesterImpl) + .await; + + // Give time to the DNS probe to run + tokio::time::sleep(probe_interval * 3).await; + + assert_eq!(uris.lock().unwrap().len(), 2); // new URI registered, layers called + } }