Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions ginepro/src/balanced_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! periodic service discovery.

use crate::{
service_probe::{GrpcServiceProbe, GrpcServiceProbeConfig},
service_probe::{EndpointMiddleware, GrpcServiceProbe, GrpcServiceProbeConfig},
DnsResolver, LookupService, ServiceDefinition,
};
use anyhow::Context as _;
Expand Down Expand Up @@ -97,13 +97,14 @@ pub enum ResolutionStrategy {
}

/// Builder to configure and create a [`LoadBalancedChannel`].
pub struct LoadBalancedChannelBuilder<T, S> {
pub struct LoadBalancedChannelBuilder<T, S, M = ()> {
service_definition: S,
probe_interval: Option<Duration>,
resolution_strategy: ResolutionStrategy,
timeout: Option<Duration>,
tls_config: Option<ClientTlsConfig>,
lookup_service: Pin<Box<dyn Future<Output = Result<T, anyhow::Error>>>>,
middleware: M,
}

impl<S> LoadBalancedChannelBuilder<DnsResolver, S>
Expand All @@ -117,49 +118,52 @@ where
/// All the service endpoints of a [`ServiceDefinition`] will be
/// constructed by resolving all ips from [`ServiceDefinition::hostname`], and
/// using the portnumber [`ServiceDefinition::port`].
pub fn new_with_service(service_definition: S) -> LoadBalancedChannelBuilder<DnsResolver, S> {
pub fn new_with_service(service_definition: S) -> Self {
Self {
service_definition,
probe_interval: None,
timeout: None,
tls_config: None,
lookup_service: Box::pin(DnsResolver::from_system_config()),
resolution_strategy: ResolutionStrategy::Lazy,
middleware: (),
}
}
}

impl<T: LookupService + Send + Sync + 'static + Sized, S, M> LoadBalancedChannelBuilder<T, S, M>
where
S: TryInto<ServiceDefinition> + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
M: EndpointMiddleware,
{
/// Set a custom [`LookupService`].
pub fn lookup_service<T: LookupService + Send + Sync + 'static>(
pub fn lookup_service<Lookup: LookupService + Send + Sync + 'static>(
self,
lookup_service: T,
) -> LoadBalancedChannelBuilder<T, S> {
lookup_service: Lookup,
) -> LoadBalancedChannelBuilder<Lookup, S, M> {
LoadBalancedChannelBuilder {
lookup_service: Box::pin(async { Ok(lookup_service) }),
service_definition: self.service_definition,
probe_interval: self.probe_interval,
tls_config: self.tls_config,
timeout: self.timeout,
resolution_strategy: self.resolution_strategy,
middleware: self.middleware,
}
}
}

impl<T: LookupService + Send + Sync + 'static + Sized, S> LoadBalancedChannelBuilder<T, S>
where
S: TryInto<ServiceDefinition> + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
{
/// Set the how often, the client should probe for changes to gRPC server endpoints.
/// Default interval in seconds is 10.
pub fn dns_probe_interval(self, interval: Duration) -> LoadBalancedChannelBuilder<T, S> {
pub fn dns_probe_interval(self, interval: Duration) -> Self {
Self {
probe_interval: Some(interval),
..self
}
}

/// Set a timeout that will be applied to every new `Endpoint`.
pub fn timeout(self, timeout: Duration) -> LoadBalancedChannelBuilder<T, S> {
pub fn timeout(self, timeout: Duration) -> Self {
Self {
timeout: Some(timeout),
..self
Expand All @@ -175,10 +179,7 @@ where
/// Instead, if [`ResolutionStrategy::Eager`] is set the domain name will be attempted resolved
/// once before the [`LoadBalancedChannel`] is created, which ensures that the channel
/// will have a non-empty of IPs on startup. If it fails the channel creation will also fail.
pub fn resolution_strategy(
self,
resolution_strategy: ResolutionStrategy,
) -> LoadBalancedChannelBuilder<T, S> {
pub fn resolution_strategy(self, resolution_strategy: ResolutionStrategy) -> Self {
Self {
resolution_strategy,
..self
Expand All @@ -187,13 +188,29 @@ where

/// Configure the channel to use tls.
/// A `tls_config` MUST be specified to use the `HTTPS` scheme.
pub fn with_tls(self, tls_config: ClientTlsConfig) -> LoadBalancedChannelBuilder<T, S> {
pub fn with_tls(self, tls_config: ClientTlsConfig) -> Self {
Self {
tls_config: Some(tls_config),
..self
}
}

/// Adds an endpoint middleware layer that lets you add custom configuration
pub fn with_endpoint_layer<Layer: EndpointMiddleware>(
self,
layer: Layer,
) -> LoadBalancedChannelBuilder<T, S, (Layer, M)> {
LoadBalancedChannelBuilder {
lookup_service: self.lookup_service,
service_definition: self.service_definition,
probe_interval: self.probe_interval,
tls_config: self.tls_config,
timeout: self.timeout,
resolution_strategy: self.resolution_strategy,
middleware: (layer, self.middleware),
}
}

/// Construct a [`LoadBalancedChannel`] from the [`LoadBalancedChannelBuilder`] instance.
pub async fn channel(self) -> Result<LoadBalancedChannel, anyhow::Error> {
let (channel, sender) = Channel::balance_channel(GRPC_REPORT_ENDPOINTS_CHANNEL_SIZE);
Expand All @@ -213,16 +230,15 @@ where
.unwrap_or_else(|| Duration::from_secs(10)),
};

let tls_config = self.tls_config.map(|mut tls_config| {
let tls_config = self.tls_config.map(|tls_config| {
// Since we resolve the hostname to an IP, which is not a valid DNS name,
// we have to set the hostname explicitly on the tls config,
// otherwise the IP will be set as the domain name and tls handshake will fail.
tls_config = tls_config.domain_name(config.service_definition.hostname());

tls_config
tls_config.domain_name(config.service_definition.hostname())
});

let mut service_probe = GrpcServiceProbe::new_with_reporter(config, sender);
let mut service_probe =
GrpcServiceProbe::new_with_reporter(config, sender, self.middleware);

if let Some(tls_config) = tls_config {
service_probe = service_probe.with_tls(tls_config);
Expand Down
22 changes: 22 additions & 0 deletions ginepro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,27 @@
//! }
//! ```
//!
//! If needed, you can use the [`with_endpoint_layer`](LoadBalancedChannelBuilder::with_endpoint_layer)
//! method to add more configuration to the channel endpoints
//!
//! ```rust
//! #[tokio::main]
//! async fn main() {
//! use ginepro::LoadBalancedChannel;
//! use shared_proto::pb::tester_client::TesterClient;
//! use tonic::transport::Endpoint;
//!
//! // Create a load balanced channel with the default lookup implementation and a custom User-Agent.
//! let load_balanced_channel = LoadBalancedChannel::builder(("my.hostname", 5000))
//! .with_endpoint_layer(|endpoint: Endpoint| endpoint.user_agent("my ginepro client").ok())
//! .channel()
//! .await
//! .expect("failed to construct LoadBalancedChannel");
//!
//! let tester_client = TesterClient::new(load_balanced_channel);
//! }
//! ```
//!
//! # Internals
//! The tonic [`Channel`](tonic::transport::Channel) exposes the function
//! [`balance_channel`](tonic::transport::Channel::balance_channel) which returnes a bounded channel through which
Expand All @@ -137,3 +158,4 @@ pub use balanced_channel::*;
pub use dns_resolver::*;
pub use lookup_service::*;
pub use service_definition::*;
pub use service_probe::EndpointMiddleware;
50 changes: 39 additions & 11 deletions ginepro/src/service_probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,33 @@ pub enum ProbeError {
ChangesetSenderClosed(#[source] anyhow::Error),
}

/// A middleware to wrap an `Endpoint`. Useful for setting new endpoint configuration.
pub trait EndpointMiddleware: Send + Sync + 'static {
fn wrap(&self, endpoint: Endpoint) -> Option<Endpoint>;
}

impl<Head, Tail> EndpointMiddleware for (Head, Tail)
where
Head: EndpointMiddleware,
Tail: EndpointMiddleware,
{
fn wrap(&self, endpoint: Endpoint) -> Option<Endpoint> {
self.0.wrap(self.1.wrap(endpoint)?)
}
}

impl EndpointMiddleware for () {
fn wrap(&self, endpoint: Endpoint) -> Option<Endpoint> {
Some(endpoint)
}
}

impl<F: Fn(Endpoint) -> Option<Endpoint> + Send + Sync + 'static> EndpointMiddleware for F {
fn wrap(&self, endpoint: Endpoint) -> Option<Endpoint> {
(self)(endpoint)
}
}

/// [`GrpcServiceProbe`] looks up IP addresses associated with the configured `host_name`
/// once every `probe_interval`.
/// If a new IP address is discovered or an old one disappears it notifies the [`tonic`] gRPC client.
Expand All @@ -27,9 +54,10 @@ pub enum ProbeError {
/// and we have not instructed the removal of that server's address from the
/// set of endpoints known to the tonic client.
///
pub struct GrpcServiceProbe<Lookup>
pub struct GrpcServiceProbe<Lookup, Middleware>
where
Lookup: LookupService,
Middleware: EndpointMiddleware,
{
service_definition: ServiceDefinition,
scheme: http::uri::Scheme,
Expand All @@ -40,6 +68,7 @@ where
endpoints: HashSet<SocketAddr>,
endpoint_reporter: Sender<Change<SocketAddr, Endpoint>>,
tls_config: Option<ClientTlsConfig>,
middleware: Middleware,
}

/// Config parameters to customize the behavior of `GrpcServiceProbe`.
Expand All @@ -58,13 +87,14 @@ where
pub endpoint_timeout: Option<tokio::time::Duration>,
}

impl<Lookup: LookupService> GrpcServiceProbe<Lookup> {
impl<Lookup: LookupService, Middleware: EndpointMiddleware> GrpcServiceProbe<Lookup, Middleware> {
/// Construct `GrpcServiceProbe` with a `GrpcServiceProbeConfig` and
/// the channel `endpoint_reporter` that will send endpoint changes.
pub fn new_with_reporter(
config: GrpcServiceProbeConfig<Lookup>,
endpoint_reporter: Sender<Change<SocketAddr, Endpoint>>,
) -> GrpcServiceProbe<Lookup> {
middleware: Middleware,
) -> Self {
Self {
service_definition: config.service_definition,
dns_lookup: config.dns_lookup,
Expand All @@ -74,18 +104,16 @@ impl<Lookup: LookupService> GrpcServiceProbe<Lookup> {
endpoint_reporter,
scheme: http::uri::Scheme::HTTP,
tls_config: None,
middleware,
}
}

/// Enable tls for all endpoints.
pub fn with_tls(self, tls_config: ClientTlsConfig) -> GrpcServiceProbe<Lookup> {
Self {
tls_config: Some(tls_config),
scheme: http::uri::Scheme::HTTPS,
..self
}
pub fn with_tls(mut self, tls_config: ClientTlsConfig) -> Self {
self.scheme = http::uri::Scheme::HTTPS;
self.tls_config = Some(tls_config);
self
}

/// Start probing the provided `hostname` for IP address changes.
/// The function will error if the receiving end of the tonic balance channel
/// is closed, e.g, the client has been deconstructed.
Expand Down Expand Up @@ -225,6 +253,6 @@ impl<Lookup: LookupService> GrpcServiceProbe<Lookup> {
endpoint = endpoint.timeout(*timeout);
}

Some(endpoint)
self.middleware.wrap(endpoint)
}
}
24 changes: 24 additions & 0 deletions tests/tests/all/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,30 @@ impl Tester for TesterImpl {
}
}

#[derive(Clone)]
pub struct UserAgentTesterImpl;

#[async_trait::async_trait]
impl Tester for UserAgentTesterImpl {
async fn test(&self, req: tonic::Request<Ping>) -> Result<tonic::Response<Pong>, Status> {
let ua = req
.metadata()
.get("user-agent")
.ok_or_else(|| Status::new(tonic::Code::PermissionDenied, "no user agent supplied"))?;

let ua = ua
.to_str()
.map_err(|_e| {
Status::new(tonic::Code::InvalidArgument, "non utf8 user agent supplied")
})?
.to_owned();

Ok(tonic::Response::new(Pong {
payload: Some(Payload::Raw(ua)),
}))
}
}

#[derive(Clone)]
pub struct TestDnsResolver {
pub ips: Arc<RwLock<HashMap<String, String>>>,
Expand Down
Loading