diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index eb6f823a3..33bb3e6f3 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -51,8 +51,8 @@ use crate::rt::Runtime; use crate::rt::Sleep; use crate::rt::TaskHandle; use crate::rt::TcpOptions; -use crate::rt::default_runtime; use crate::rt::tokio::TokioRuntime; +use crate::rt::tracker::TrackedRuntime; const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); @@ -178,9 +178,10 @@ pub(crate) async fn dns_basic() { let builder = global_registry().get("dns").unwrap(); let target = &"dns:///localhost:1234".parse().unwrap(); let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: default_runtime(), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let mut resolver = builder.build(target, opts); @@ -192,6 +193,9 @@ pub(crate) async fn dns_basic() { // A successful endpoint update should be received. let update = update_rx.recv().await.unwrap(); assert!(update.endpoints.unwrap().len() > 1); + + drop(resolver); + waiter.wait_for_tasks().await; } #[tokio::test] @@ -200,9 +204,10 @@ pub(crate) async fn invalid_target() { let builder = global_registry().get("dns").unwrap(); let target = &"dns:///:1234".parse().unwrap(); let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: default_runtime(), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let mut resolver = builder.build(target, opts); @@ -220,6 +225,9 @@ pub(crate) async fn invalid_target() { .unwrap() .contains(&target.to_string()) ); + + drop(resolver); + waiter.wait_for_tasks().await; } #[derive(Clone, Debug)] @@ -277,16 +285,17 @@ pub(crate) async fn dns_lookup_error() { let builder = global_registry().get("dns").unwrap(); let target = &"dns:///grpc.io:1234".parse().unwrap(); let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); - let runtime = FakeRuntime { + let fake_rt = FakeRuntime { inner: TokioRuntime::default(), dns: FakeDns { latency: Duration::from_secs(0), lookup_result: Err("test_error".to_string()), }, }; + let (rt, waiter) = TrackedRuntime::new(fake_rt); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: GrpcRuntime::new(runtime), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let mut resolver = builder.build(target, opts); @@ -298,22 +307,26 @@ pub(crate) async fn dns_lookup_error() { // An error endpoint update should be received. let update = update_rx.recv().await.unwrap(); assert!(update.endpoints.unwrap_err().contains("test_error")); + + drop(resolver); + waiter.wait_for_tasks().await; } #[tokio::test] pub(crate) async fn dns_lookup_timeout() { let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); - let runtime = FakeRuntime { + let fake_dns = FakeDns { + latency: Duration::from_secs(20), + lookup_result: Ok(Vec::new()), + }; + let fake_rt = FakeRuntime { inner: TokioRuntime::default(), - dns: FakeDns { - latency: Duration::from_secs(20), - lookup_result: Ok(Vec::new()), - }, + dns: fake_dns.clone(), }; - let dns_client = runtime.dns.clone(); + let (rt, waiter) = TrackedRuntime::new(fake_rt); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: GrpcRuntime::new(runtime), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let dns_opts = DnsOptions { @@ -323,7 +336,7 @@ pub(crate) async fn dns_lookup_timeout() { host: "grpc.io".to_string(), port: 1234, }; - let mut resolver = DnsResolver::new(Box::new(dns_client), opts, dns_opts); + let mut resolver = DnsResolver::new(Box::new(fake_dns), opts, dns_opts); // Wait for schedule work to be called. work_rx.recv().await.unwrap(); @@ -333,14 +346,18 @@ pub(crate) async fn dns_lookup_timeout() { // An error endpoint update should be received. let update = update_rx.recv().await.unwrap(); assert!(update.endpoints.unwrap_err().contains("Timed out")); + + drop(resolver); + waiter.wait_for_tasks().await; } #[tokio::test] pub(crate) async fn rate_limit() { let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: default_runtime(), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let dns_client = opts @@ -376,14 +393,18 @@ pub(crate) async fn rate_limit() { } }; } + + drop(resolver); + waiter.wait_for_tasks().await; } #[tokio::test] pub(crate) async fn re_resolution_after_success() { let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: default_runtime(), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let dns_opts = DnsOptions { @@ -413,14 +434,18 @@ pub(crate) async fn re_resolution_after_success() { resolver.work(&mut channel_controller); let update = update_rx.recv().await.unwrap(); assert!(update.endpoints.unwrap().len() > 1); + + drop(resolver); + waiter.wait_for_tasks().await; } #[tokio::test] pub(crate) async fn backoff_on_error() { let (work_scheduler, mut work_rx) = TestWorkScheduler::new_pair(); + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let opts = ResolverOptions { authority: "ignored".to_string(), - runtime: default_runtime(), + runtime: GrpcRuntime::new(rt), work_scheduler: work_scheduler.clone(), }; let dns_opts = DnsOptions { @@ -472,4 +497,6 @@ pub(crate) async fn backoff_on_error() { println!("No event received from resolver."); } }; + drop(resolver); + waiter.wait_for_tasks().await; } diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs index bfd6a216e..70fbc0ee0 100644 --- a/grpc/src/client/transport/tonic/test.rs +++ b/grpc/src/client/transport/tonic/test.rs @@ -88,6 +88,7 @@ use crate::metadata::AsciiMetadataKey; use crate::metadata::MetadataMap; use crate::rt::GrpcRuntime; use crate::rt::tokio::TokioRuntime; +use crate::rt::tracker::TrackedRuntime; #[derive(Debug)] struct MockCallCredentials { @@ -156,10 +157,11 @@ pub(crate) async fn tonic_transport_rpc() { authority: Authority::new("localhost".to_string(), None), handshake_info: ClientHandshakeInfo::default(), }; + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let (conn, _sec_info, mut disconnection_listener) = builder .dyn_connect( addr.to_string(), - GrpcRuntime::new(TokioRuntime::default()), + GrpcRuntime::new(rt), &securty_opts, &config, ) @@ -224,6 +226,8 @@ pub(crate) async fn tonic_transport_rpc() { .unwrap(); assert_eq!(res, Ok(())); server_handle.await.unwrap(); + drop(conn); + waiter.wait_for_tasks().await; } #[tokio::test] @@ -673,10 +677,11 @@ async fn tonic_transport_invalid_base64_headers() { authority: Authority::new("localhost".to_string(), None), handshake_info: ClientHandshakeInfo::default(), }; + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let (conn, _sec_info, _disconnection_listener) = builder .dyn_connect( addr.to_string(), - GrpcRuntime::new(TokioRuntime::default()), + GrpcRuntime::new(rt), &securty_opts, &config, ) @@ -715,6 +720,8 @@ async fn tonic_transport_invalid_base64_headers() { shutdown_notify.notify_one(); server_handle.await.unwrap(); + drop(conn); + waiter.wait_for_tasks().await; } #[tokio::test] @@ -748,10 +755,11 @@ async fn tonic_transport_recv_drop_cancels_send() { authority: Authority::new("localhost".to_string(), None), handshake_info: ClientHandshakeInfo::default(), }; + let (rt, waiter) = TrackedRuntime::new(TokioRuntime::default()); let (conn, _sec_info, _disconnection_listener) = builder .dyn_connect( addr.to_string(), - GrpcRuntime::new(TokioRuntime::default()), + GrpcRuntime::new(rt), &securty_opts, &config, ) @@ -781,6 +789,8 @@ async fn tonic_transport_recv_drop_cancels_send() { shutdown_notify.notify_one(); server_handle.await.unwrap(); + drop(conn); + waiter.wait_for_tasks().await; } struct WrappedEchoRequest(EchoRequest); diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs index fb1b3166f..0cb3b66eb 100644 --- a/grpc/src/rt/mod.rs +++ b/grpc/src/rt/mod.rs @@ -43,6 +43,8 @@ use crate::private; pub(crate) mod hyper_wrapper; #[cfg(feature = "_runtime-tokio")] pub(crate) mod tokio; +#[cfg(test)] +pub(crate) mod tracker; pub type BoxFuture = Pin + Send>>; pub type BoxedTaskHandle = Box; diff --git a/grpc/src/rt/tracker.rs b/grpc/src/rt/tracker.rs new file mode 100644 index 000000000..84dc8c2c9 --- /dev/null +++ b/grpc/src/rt/tracker.rs @@ -0,0 +1,267 @@ +/* + * + * Copyright 2026 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + * + */ + +use std::backtrace::Backtrace; +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use crate::rt::BoxFuture; +use crate::rt::BoxedTaskHandle; +use crate::rt::DnsResolver; +use crate::rt::GrpcEndpoint; +use crate::rt::ResolverOptions; +use crate::rt::Runtime; +use crate::rt::Sleep; +use crate::rt::TcpOptions; +use crate::rt::UnixSocketOptions; + +const DEFAULT_TEST_DURATION: Duration = Duration::from_secs(10); + +#[derive(Debug)] +struct SharedInnerState { + tasks: HashMap, + next_id: u64, +} + +#[derive(Debug)] +struct SharedState { + inner: Mutex, + notify: tokio::sync::Notify, +} + +struct TaskGuard { + id: u64, + state: Arc, +} + +impl Drop for TaskGuard { + fn drop(&mut self) { + let mut inner = self.state.inner.lock().unwrap(); + inner.tasks.remove(&self.id); + if inner.tasks.is_empty() { + self.state.notify.notify_one(); + } + } +} + +/// A `Runtime` wrapper that tracks spawned tasks. +#[derive(Debug)] +pub(crate) struct TrackedRuntime { + inner: R, + state: Arc, +} + +/// A handle to wait for tasks tracked by `TrackedRuntime`. +pub(crate) struct TaskTracker { + wait_timeout: Duration, + state: Arc, + have_waited: bool, +} + +impl TrackedRuntime { + /// Creates a new tracked runtime and its associated tracker. + /// + /// Callers must call `wait_for_tasks` on the returned tracker at the end of + /// the test. + /// + /// ```rust + /// let (tracked_rt, tracker) = TrackedRuntime::new(rt); + /// tracked_rt.spawn(Box::pin(async { + /// // ... + /// })); + /// tracker.wait_for_tasks().await; + /// ``` + pub(crate) fn new(inner: R) -> (Self, TaskTracker) { + let state = Arc::new(SharedState { + inner: Mutex::new(SharedInnerState { + tasks: HashMap::new(), + next_id: 0, + }), + notify: tokio::sync::Notify::new(), + }); + ( + Self { + inner, + state: state.clone(), + }, + TaskTracker { + wait_timeout: DEFAULT_TEST_DURATION, + state, + have_waited: false, + }, + ) + } +} + +impl TaskTracker { + /// Waits for all tracked tasks to finish or until timeout. + /// + /// It waits for 10 seconds. If the tasks do not finish within this time, + /// it panics and prints the backtrace of all orphaned tasks. + /// + /// Callers MUST call this method before the `TaskTracker` is dropped. + /// Dropping the `TaskTracker` without calling this method will result in + /// a panic. + pub(crate) async fn wait_for_tasks(mut self) { + self.have_waited = true; + let notified = self.state.notify.notified(); + + if self.state.inner.lock().unwrap().tasks.is_empty() { + return; + }; + + if tokio::time::timeout(self.wait_timeout, notified) + .await + .is_ok() + { + return; + } + + let callsites: Vec = self + .state + .inner + .lock() + .unwrap() + .tasks + .values() + .map(|bt| format!("{}", bt)) + .collect(); + + if callsites.is_empty() { + // Tasks ended after the timeout expired. + return; + } + + panic!( + "TrackedRuntime: tasks did not end within timeout. Running tasks spawned at:\n{}", + callsites.join("\n\n---\n\n") + ); + } +} + +impl Runtime for TrackedRuntime { + fn spawn(&self, task: Pin + Send + 'static>>) -> BoxedTaskHandle { + let bt = Backtrace::force_capture(); + + let id = { + let mut inner = self.state.inner.lock().unwrap(); + let id = inner.next_id; + inner.next_id += 1; + inner.tasks.insert(id, bt); + id + }; + + let guard = TaskGuard { + id, + state: self.state.clone(), + }; + + let tracked_task = async move { + // Guard stays alive during await and is dropped when done or + // cancelled. + let _guard = guard; + task.await; + }; + + self.inner.spawn(Box::pin(tracked_task)) + } + fn get_dns_resolver(&self, opts: ResolverOptions) -> Result, String> { + self.inner.get_dns_resolver(opts) + } + + fn sleep(&self, duration: std::time::Duration) -> Pin> { + self.inner.sleep(duration) + } + + fn tcp_stream( + &self, + target: SocketAddr, + opts: TcpOptions, + ) -> BoxFuture, String>> { + self.inner.tcp_stream(target, opts) + } + + fn unix_stream( + &self, + path: PathBuf, + opts: UnixSocketOptions, + ) -> BoxFuture, String>> { + self.inner.unix_stream(path, opts) + } +} + +impl Drop for TaskTracker { + fn drop(&mut self) { + // Check if wait_for_tasks was called and that we are not already + // panicking to avoid double panics. + if !self.have_waited && !std::thread::panicking() { + panic!("TaskTracker was dropped without calling wait_for_tasks!"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::rt::tokio::TokioRuntime; + + #[tokio::test] + async fn wait_success() { + let rt = TokioRuntime::default(); + let (tracked_rt, tracker) = TrackedRuntime::new(rt); + + tracked_rt.spawn(Box::pin(async { + tokio::time::sleep(Duration::from_millis(1)).await; + })); + + tracker.wait_for_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "TrackedRuntime: tasks did not end within timeout")] + async fn wait_timeout() { + let rt = TokioRuntime::default(); + let (tracked_rt, mut tracker) = TrackedRuntime::new(rt); + tracker.wait_timeout = Duration::from_millis(1); + + tracked_rt.spawn(Box::pin(async { + tokio::time::sleep(DEFAULT_TEST_DURATION).await; + })); + + tracker.wait_for_tasks().await; + } + + #[tokio::test] + #[should_panic(expected = "TaskTracker was dropped without calling wait_for_tasks!")] + async fn panic_on_drop() { + let rt = TokioRuntime::default(); + let (_tracked_rt, _tracker) = TrackedRuntime::new(rt); + } +}