Skip to content

Commit be3bfce

Browse files
committed
refactor
1 parent baf9a5d commit be3bfce

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

src/channel.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use {
22
etcd_client::{BalancedChannelBuilder, Channel, Client, ConnectOptions, Error},
33
std::{
44
collections::{HashMap, HashSet},
5+
future,
56
sync::{Arc, Mutex},
67
time::Duration,
78
},
@@ -31,7 +32,6 @@ pub struct ReliableChannelStats {
3132
pub struct ReliableChannelStatsRegistry {
3233
inner: Arc<Mutex<ReliableChannelStats>>,
3334
endpoint_status: Arc<Mutex<HashMap<String, EndpointStatus>>>,
34-
probe_notify: Arc<Notify>,
3535
call_counter: Arc<dashmap::DashMap<Uri, usize>>,
3636
}
3737

@@ -43,10 +43,6 @@ impl ReliableChannelStatsRegistry {
4343
.expect("reliable channel stats lock poisoned")
4444
}
4545

46-
pub fn trigger_probe(&self) {
47-
self.probe_notify.notify_one();
48-
}
49-
5046
pub fn call_counts_snapshot(&self) -> HashMap<String, usize> {
5147
self.call_counter
5248
.iter()
@@ -77,6 +73,25 @@ impl ReliableChannelStatsRegistry {
7773
}
7874
}
7975

76+
#[derive(Debug, Clone, Default)]
77+
pub struct ProbeTrigger {
78+
notify: Arc<Notify>,
79+
}
80+
81+
impl ProbeTrigger {
82+
pub fn new() -> Self {
83+
Self::default()
84+
}
85+
86+
pub fn trigger_probe(&self) {
87+
self.notify.notify_one();
88+
}
89+
90+
async fn notified(&self) {
91+
self.notify.notified().await;
92+
}
93+
}
94+
8095
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
8196
pub enum EndpointStatus {
8297
Active,
@@ -94,6 +109,7 @@ pub struct ReliableBalancedChannelBuilder {
94109
pub probe_timeout: Duration,
95110
pub quarantine_retry_interval: Duration,
96111
stats_registry: ReliableChannelStatsRegistry,
112+
probe_trigger: Option<ProbeTrigger>,
97113
}
98114

99115
impl Default for ReliableBalancedChannelBuilder {
@@ -102,14 +118,35 @@ impl Default for ReliableBalancedChannelBuilder {
102118
probe_timeout: Duration::from_secs(5),
103119
quarantine_retry_interval: Duration::from_secs(15),
104120
stats_registry: ReliableChannelStatsRegistry::default(),
121+
probe_trigger: None,
105122
}
106123
}
107124
}
108125

109126
impl ReliableBalancedChannelBuilder {
127+
pub fn new_with_probe_trigger(probe_trigger: ProbeTrigger) -> Self {
128+
Self {
129+
probe_trigger: Some(probe_trigger),
130+
..Self::default()
131+
}
132+
}
133+
110134
pub fn stats_registry(&self) -> ReliableChannelStatsRegistry {
111135
self.stats_registry.clone()
112136
}
137+
138+
pub fn probe_trigger(&self) -> Option<ProbeTrigger> {
139+
self.probe_trigger.clone()
140+
}
141+
142+
pub fn set_probe_trigger(&mut self, probe_trigger: Option<ProbeTrigger>) {
143+
self.probe_trigger = probe_trigger;
144+
}
145+
146+
pub fn with_probe_trigger(mut self, probe_trigger: ProbeTrigger) -> Self {
147+
self.probe_trigger = Some(probe_trigger);
148+
self
149+
}
113150
}
114151

115152
impl BalancedChannelBuilder for ReliableBalancedChannelBuilder {
@@ -138,6 +175,7 @@ impl BalancedChannelBuilder for ReliableBalancedChannelBuilder {
138175
discover_updater,
139176
spy_quarantine_rx,
140177
self.stats_registry.clone(),
178+
self.probe_trigger.clone(),
141179
self.probe_timeout,
142180
self.quarantine_retry_interval,
143181
));
@@ -151,6 +189,7 @@ async fn run_endpoint_manager(
151189
discover_updater: EndpointUpdater,
152190
mut spy_quarantine_rx: mpsc::UnboundedReceiver<Uri>,
153191
stats_registry: ReliableChannelStatsRegistry,
192+
probe_trigger: Option<ProbeTrigger>,
154193
probe_timeout: Duration,
155194
quarantine_retry_interval: Duration,
156195
) {
@@ -186,7 +225,13 @@ async fn run_endpoint_manager(
186225
&stats_registry,
187226
);
188227
}
189-
_ = stats_registry.probe_notify.notified() => {
228+
_ = async {
229+
if let Some(trigger) = &probe_trigger {
230+
trigger.notified().await;
231+
} else {
232+
future::pending::<()>().await;
233+
}
234+
} => {
190235
sweep_active(
191236
&desired,
192237
&mut active,

tests/test_channel.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use tonic::transport::{Endpoint, channel::Change};
55
use tower::{Service, ServiceExt};
66

77
use rust_etcd_utils::channel::{
8-
EndpointStatus, ReliableBalancedChannelBuilder, connect_with_reliable_balanced_channel,
8+
EndpointStatus, ProbeTrigger, ReliableBalancedChannelBuilder,
9+
connect_with_reliable_balanced_channel,
910
};
1011

1112
mod common;
@@ -233,6 +234,8 @@ async fn request_fails_when_mock_endpoint_goes_down_then_recovers_after_restart(
233234
let mut builder = ReliableBalancedChannelBuilder::default();
234235
builder.probe_timeout = Duration::from_millis(100);
235236
builder.quarantine_retry_interval = Duration::from_millis(150);
237+
let probe_trigger = ProbeTrigger::new();
238+
builder.set_probe_trigger(Some(probe_trigger.clone()));
236239
let stats_registry = builder.stats_registry();
237240

238241
let (mut channel, updater) = builder
@@ -307,7 +310,7 @@ async fn request_fails_when_mock_endpoint_goes_down_then_recovers_after_restart(
307310
);
308311

309312
// Phase 3: trigger a one-shot health probe and verify quarantine.
310-
stats_registry.trigger_probe();
313+
probe_trigger.trigger_probe();
311314
tokio::time::timeout(Duration::from_secs(3), async {
312315
loop {
313316
if stats_registry
@@ -487,6 +490,8 @@ async fn fair_share_then_failover_drains_to_1000_successful_requests() {
487490
let mut builder = ReliableBalancedChannelBuilder::default();
488491
builder.probe_timeout = Duration::from_millis(100);
489492
builder.quarantine_retry_interval = Duration::from_millis(150);
493+
let probe_trigger = ProbeTrigger::new();
494+
builder.set_probe_trigger(Some(probe_trigger.clone()));
490495
let stats_registry = builder.stats_registry();
491496

492497
let (mut channel, updater) = builder
@@ -590,7 +595,7 @@ async fn fair_share_then_failover_drains_to_1000_successful_requests() {
590595
.await;
591596
}
592597

593-
stats_registry.trigger_probe();
598+
probe_trigger.trigger_probe();
594599

595600
let mut successful_requests = first_phase_requests;
596601
let deadline = Instant::now() + Duration::from_secs(30);

0 commit comments

Comments
 (0)