diff --git a/config/quickwit.yaml b/config/quickwit.yaml index 035bd2fb241..f9e29446665 100644 --- a/config/quickwit.yaml +++ b/config/quickwit.yaml @@ -49,6 +49,12 @@ version: 0.8 # # How often cert_path/key_path are polled for changes and hot-reloaded; an # # immediate reload can also be triggered with SIGHUP. Defaults to 5m. # cert_reload_interval: 5m +# # Maximum lifetime of a connection before the server sends an HTTP/2 GOAWAY and the +# # client reconnects. Disabled when unset. +# max_connection_age: 30m +# # Grace period after the GOAWAY before a still-draining connection is forcefully +# # closed. Requires `max_connection_age` to be set. +# max_connection_age_grace: 30s # # Optional plaintext health-check server. Disabled unless `listen_port` is set (or the # `QW_HEALTH_LISTEN_PORT` environment variable). It serves only `/health/livez` and @@ -75,6 +81,12 @@ version: 0.8 # # How often cert_path/key_path are polled for changes and hot-reloaded; an # # immediate reload can also be triggered with SIGHUP. Defaults to 5m. # cert_reload_interval: 5m +# # Maximum lifetime of an inbound connection before the server sends an HTTP/2 GOAWAY +# # and the peer reconnects. Disabled when unset. +# max_connection_age: 30m +# # Grace period after the GOAWAY before a still-draining connection is forcefully +# # closed. Requires `max_connection_age` to be set. +# max_connection_age_grace: 30s # # IP address advertised by the node, i.e. the IP address that peer nodes should use to connect to the node for RPCs. # The environment variable `QW_ADVERTISE_ADDRESS` can also be used to override this value. diff --git a/quickwit/Cargo.lock b/quickwit/Cargo.lock index 608c3204836..2d72d718683 100644 --- a/quickwit/Cargo.lock +++ b/quickwit/Cargo.lock @@ -8898,6 +8898,7 @@ dependencies = [ "bytesize", "datafusion", "futures-util", + "http-body-util", "hyper 1.9.0", "hyper-util", "itertools 0.14.0", diff --git a/quickwit/quickwit-config/resources/tests/node_config/quickwit.json b/quickwit/quickwit-config/resources/tests/node_config/quickwit.json index 1548273cb89..a60becaee7e 100644 --- a/quickwit/quickwit-config/resources/tests/node_config/quickwit.json +++ b/quickwit/quickwit-config/resources/tests/node_config/quickwit.json @@ -24,13 +24,30 @@ "extra_headers": { "x-header-1": "header-value-1", "x-header-2": "header-value-2" - } + }, + "tls": { + "cert_path": "/path/to/rest.crt", + "key_path": "/path/to/rest.key", + "ca_path": "/path/to/ca.crt", + "verify_client_cert": true + }, + "max_connection_age": "30m", + "max_connection_age_grace": "30s" }, "health": { "listen_port": 4444 }, "grpc": { - "max_message_size": "10 MB" + "max_message_size": "10 MB", + "tls": { + "cert_path": "/path/to/grpc.crt", + "key_path": "/path/to/grpc.key", + "ca_path": "/path/to/ca.crt", + "verify_client_cert": true, + "expected_name": "quickwit.local" + }, + "max_connection_age": "1h", + "max_connection_age_grace": "10s" }, "storage": { "azure": { diff --git a/quickwit/quickwit-config/resources/tests/node_config/quickwit.toml b/quickwit/quickwit-config/resources/tests/node_config/quickwit.toml index 3c97620f185..665d0f2a624 100644 --- a/quickwit/quickwit-config/resources/tests/node_config/quickwit.toml +++ b/quickwit/quickwit-config/resources/tests/node_config/quickwit.toml @@ -15,16 +15,33 @@ default_index_root_uri = "s3://quickwit-indexes" [rest] listen_port = 1111 +max_connection_age = "30m" +max_connection_age_grace = "30s" [rest.extra_headers] x-header-1 = "header-value-1" x-header-2 = "header-value-2" +[rest.tls] +cert_path = "/path/to/rest.crt" +key_path = "/path/to/rest.key" +ca_path = "/path/to/ca.crt" +verify_client_cert = true + [health] listen_port = 4444 [grpc] max_message_size = "10 MB" +max_connection_age = "1h" +max_connection_age_grace = "10s" + +[grpc.tls] +cert_path = "/path/to/grpc.crt" +key_path = "/path/to/grpc.key" +ca_path = "/path/to/ca.crt" +verify_client_cert = true +expected_name = "quickwit.local" [storage.azure] account = "quickwit-dev" diff --git a/quickwit/quickwit-config/resources/tests/node_config/quickwit.yaml b/quickwit/quickwit-config/resources/tests/node_config/quickwit.yaml index 7d4551b714a..46f988e447c 100644 --- a/quickwit/quickwit-config/resources/tests/node_config/quickwit.yaml +++ b/quickwit/quickwit-config/resources/tests/node_config/quickwit.yaml @@ -22,12 +22,27 @@ rest: extra_headers: x-header-1: header-value-1 x-header-2: header-value-2 + tls: + cert_path: /path/to/rest.crt + key_path: /path/to/rest.key + ca_path: /path/to/ca.crt + verify_client_cert: true + max_connection_age: 30m + max_connection_age_grace: 30s health: listen_port: 4444 grpc: max_message_size: 10 MB + tls: + cert_path: /path/to/grpc.crt + key_path: /path/to/grpc.key + ca_path: /path/to/ca.crt + verify_client_cert: true + expected_name: quickwit.local + max_connection_age: 1h + max_connection_age_grace: 10s storage: azure: diff --git a/quickwit/quickwit-config/src/node_config/mod.rs b/quickwit/quickwit-config/src/node_config/mod.rs index 1c60891cf58..bd937c956c7 100644 --- a/quickwit/quickwit-config/src/node_config/mod.rs +++ b/quickwit/quickwit-config/src/node_config/mod.rs @@ -52,6 +52,13 @@ pub struct RestConfig { pub extra_headers: HeaderMap, #[serde(default, rename = "tls")] pub tls_config: Option, + // See `GrpcConfig::max_connection_age`. Closes long-lived keep-alive connections so an updated + // TLS certificate is eventually presented. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age: Option, + // See `GrpcConfig::max_connection_age_grace`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age_grace: Option, } /// Configuration for the optional plaintext health-check HTTP server. @@ -77,6 +84,14 @@ pub struct GrpcConfig { // keep alive ping request. #[serde(default, skip_serializing_if = "Option::is_none")] pub keep_alive: Option, + // Maximum lifetime of an inbound connection before the server sends an HTTP/2 GOAWAY and the + // peer reconnects. Disabled when unset. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age: Option, + // Grace period after the GOAWAY before a still-draining connection is forcefully closed. + // Requires `max_connection_age` to be set. Waits indefinitely when unset. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age_grace: Option, } fn default_http2_keep_alive_interval() -> HumanDuration { @@ -116,6 +131,10 @@ impl GrpcConfig { if let Some(tls_config) = &self.tls_config { tls_config.validate()?; } + ensure!( + !(self.max_connection_age_grace.is_some() && self.max_connection_age.is_none()), + "`grpc.max_connection_age_grace` requires `grpc.max_connection_age` to be set" + ); Ok(()) } } @@ -126,6 +145,8 @@ impl Default for GrpcConfig { max_message_size: Self::default_max_message_size(), tls_config: None, keep_alive: None, + max_connection_age: None, + max_connection_age_grace: None, } } } @@ -1065,19 +1086,37 @@ mod tests { fn test_grpc_config_validate() { let grpc_config = GrpcConfig { max_message_size: ByteSize::mb(1), - tls_config: None, - keep_alive: None, + ..Default::default() }; assert!(grpc_config.validate().is_ok()); let grpc_config = GrpcConfig { max_message_size: ByteSize::kb(1), - tls_config: None, - keep_alive: None, + ..Default::default() }; assert!(grpc_config.validate().is_err()); } + #[test] + fn test_grpc_config_validate_rejects_connection_age_grace_without_age() { + let grpc_config = GrpcConfig { + max_connection_age_grace: Some(HumanDuration::try_from("10s".to_string()).unwrap()), + ..Default::default() + }; + let error = grpc_config.validate().unwrap_err().to_string(); + assert!( + error.contains("requires `grpc.max_connection_age`"), + "unexpected error: {error}" + ); + + let grpc_config = GrpcConfig { + max_connection_age: Some(HumanDuration::try_from("1h".to_string()).unwrap()), + max_connection_age_grace: Some(HumanDuration::try_from("10s".to_string()).unwrap()), + ..Default::default() + }; + assert!(grpc_config.validate().is_ok()); + } + fn tls_config(reload_interval: &str) -> TlsConfig { TlsConfig { cert_path: "/path/to/server.crt".to_string(), @@ -1105,7 +1144,7 @@ mod tests { let grpc_config = GrpcConfig { max_message_size: ByteSize::mib(20), tls_config: Some(tls_config("0s")), - keep_alive: None, + ..Default::default() }; assert!(grpc_config.validate().is_err()); } diff --git a/quickwit/quickwit-config/src/node_config/serialize.rs b/quickwit/quickwit-config/src/node_config/serialize.rs index 8b53cb6cf74..71c1797a4e0 100644 --- a/quickwit/quickwit-config/src/node_config/serialize.rs +++ b/quickwit/quickwit-config/src/node_config/serialize.rs @@ -17,7 +17,7 @@ use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::time::Duration; -use anyhow::{Context, bail}; +use anyhow::{Context, bail, ensure}; use bytesize::ByteSize; use http::HeaderMap; use quickwit_common::fs::get_disk_size; @@ -31,6 +31,7 @@ use tracing::{info, warn}; use super::{GrpcConfig, HealthConfig, RestConfig}; use crate::config_value::ConfigValue; use crate::qw_env_vars::*; +use crate::serde_utils::HumanDuration; use crate::service::QuickwitService; use crate::storage_config::StorageConfigs; use crate::templating::render_config; @@ -457,6 +458,10 @@ struct RestConfigBuilder { pub extra_headers: HeaderMap, #[serde(default, rename = "tls")] pub tls_config: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_connection_age_grace: Option, } impl RestConfigBuilder { @@ -475,11 +480,17 @@ impl RestConfigBuilder { if let Some(tls_config) = &self.tls_config { tls_config.validate()?; } + ensure!( + !(self.max_connection_age_grace.is_some() && self.max_connection_age.is_none()), + "`rest.max_connection_age_grace` requires `rest.max_connection_age` to be set" + ); let rest_config = RestConfig { listen_addr: SocketAddr::new(listen_ip, listen_port), cors_allow_origins: self.cors_allow_origins, extra_headers: self.extra_headers, tls_config: self.tls_config, + max_connection_age: self.max_connection_age, + max_connection_age_grace: self.max_connection_age_grace, }; Ok(rest_config) } @@ -550,6 +561,8 @@ pub fn node_config_for_tests_from_ports( cors_allow_origins: Vec::new(), extra_headers: HeaderMap::new(), tls_config: None, + max_connection_age: None, + max_connection_age_grace: None, }; NodeConfig { cluster_id: default_cluster_id().unwrap(), @@ -624,7 +637,46 @@ mod tests { config.rest_config.extra_headers.get("x-header-2").unwrap(), "header-value-2" ); + assert_eq!( + config.rest_config.tls_config, + Some(TlsConfig { + cert_path: "/path/to/rest.crt".to_string(), + key_path: "/path/to/rest.key".to_string(), + ca_path: "/path/to/ca.crt".to_string(), + expected_name: None, + verify_client_cert: true, + cert_reload_interval: HumanDuration::try_from("5m".to_string()).unwrap(), + }) + ); + assert_eq!( + config.rest_config.max_connection_age, + Some(HumanDuration::try_from("30m".to_string()).unwrap()) + ); + assert_eq!( + config.rest_config.max_connection_age_grace, + Some(HumanDuration::try_from("30s".to_string()).unwrap()) + ); + assert_eq!(config.grpc_config.max_message_size, ByteSize::mb(10)); + assert_eq!( + config.grpc_config.tls_config, + Some(TlsConfig { + cert_path: "/path/to/grpc.crt".to_string(), + key_path: "/path/to/grpc.key".to_string(), + ca_path: "/path/to/ca.crt".to_string(), + expected_name: Some("quickwit.local".to_string()), + verify_client_cert: true, + cert_reload_interval: HumanDuration::try_from("5m".to_string()).unwrap(), + }) + ); + assert_eq!( + config.grpc_config.max_connection_age, + Some(HumanDuration::try_from("1h".to_string()).unwrap()) + ); + assert_eq!( + config.grpc_config.max_connection_age_grace, + Some(HumanDuration::try_from("10s".to_string()).unwrap()) + ); assert_eq!( config diff --git a/quickwit/quickwit-integration-tests/Cargo.toml b/quickwit/quickwit-integration-tests/Cargo.toml index aeffa0bd1be..a4ed408c7b1 100644 --- a/quickwit/quickwit-integration-tests/Cargo.toml +++ b/quickwit/quickwit-integration-tests/Cargo.toml @@ -50,6 +50,7 @@ quickwit-parquet-engine = { workspace = true, optional = true } anyhow = { workspace = true } aws-sdk-sqs = { workspace = true } futures-util = { workspace = true } +http-body-util = { workspace = true } hyper = { workspace = true } hyper-util = { workspace = true } itertools = { workspace = true } diff --git a/quickwit/quickwit-integration-tests/src/tests/tls_tests.rs b/quickwit/quickwit-integration-tests/src/tests/tls_tests.rs index 05463145155..4b2848cd3fa 100644 --- a/quickwit/quickwit-integration-tests/src/tests/tls_tests.rs +++ b/quickwit/quickwit-integration-tests/src/tests/tls_tests.rs @@ -15,7 +15,10 @@ use std::net::SocketAddr; use std::time::Duration; -use hyper_util::rt::TokioExecutor; +use http_body_util::Empty; +use hyper::body::Bytes; +use hyper::client::conn::http2::SendRequest; +use hyper_util::rt::{TokioExecutor, TokioIo}; use quickwit_common::test_utils::wait_until_predicate; use quickwit_config::service::QuickwitService; use quickwit_config::{HumanDuration, TlsConfig}; @@ -282,6 +285,146 @@ async fn test_health_check_server_plaintext_with_mtls_rest() { sandbox.shutdown().await.unwrap(); } +/// Opens a long-lived HTTP/2 connection to `addr` over TLS, trusting `ca_path`. Returns the DER of +/// the leaf certificate the server presented, a `SendRequest` handle that must be kept alive to +/// keep the connection open, and a handle to the driving task that resolves once the server closes +/// the connection. +async fn open_http2_connection( + addr: SocketAddr, + ca_path: &str, +) -> ( + Vec, + SendRequest>, + tokio::task::JoinHandle<()>, +) { + let tls_config = TlsConfig { + ca_path: ca_path.to_string(), + ..fixture_tls_config() + }; + let client_config = quickwit_transport::make_tls_client_config(&tls_config).unwrap(); + let tls_connector = tokio_rustls::TlsConnector::from(client_config); + let tcp_stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + let server_name = rustls::pki_types::ServerName::IpAddress(addr.ip().into()); + let tls_stream = tls_connector + .connect(server_name, tcp_stream) + .await + .unwrap(); + + let leaf_cert_der = { + let (_tcp_stream, connection) = tls_stream.get_ref(); + let peer_certs = connection + .peer_certificates() + .expect("the server must present a certificate"); + peer_certs[0].as_ref().to_vec() + }; + + let (send_request, connection) = hyper::client::conn::http2::handshake::<_, _, Empty>( + TokioExecutor::new(), + TokioIo::new(tls_stream), + ) + .await + .expect("HTTP/2 handshake should succeed"); + // The connection future resolves when the server closes the connection (e.g. after the max + // connection age elapses), even though this client never initiates a disconnect. + let connection_handle = tokio::spawn(async move { + let _ = connection.await; + }); + (leaf_cert_der, send_request, connection_handle) +} + +#[tokio::test] +async fn test_tls_rest_max_connection_age() { + quickwit_common::setup_logging_for_tests(); + // The node reads its certificate from a temp directory we can mutate to simulate a rotation. + let temp_dir = tempfile::tempdir().unwrap(); + let cert_path = temp_dir.path().join("server.crt"); + let key_path = temp_dir.path().join("server.key"); + std::fs::copy(format!("{TLS_FIXTURES_DIR}/server.crt"), &cert_path).unwrap(); + std::fs::copy(format!("{TLS_FIXTURES_DIR}/server.key"), &key_path).unwrap(); + let ca_path = format!("{TLS_FIXTURES_DIR}/ca.crt"); + + let mut sandbox_config = ClusterSandboxBuilder::default() + .add_node(QuickwitService::supported_services()) + .build_config() + .await; + sandbox_config.node_configs[0].0.rest_config.tls_config = Some(TlsConfig { + cert_path: cert_path.to_str().unwrap().to_string(), + key_path: key_path.to_str().unwrap().to_string(), + ca_path: ca_path.clone(), + // Long interval: only `reload_tls_cert()` reloads the cert here, never the periodic poll. + cert_reload_interval: HumanDuration::try_from("1h".to_string()).unwrap(), + ..fixture_tls_config() + }); + // Short max connection age: the server must send a GOAWAY and close a long-lived connection + // within this window, forcing the client to reconnect and re-handshake with the current cert. + sandbox_config.node_configs[0] + .0 + .rest_config + .max_connection_age = Some(HumanDuration::try_from("2s".to_string()).unwrap()); + let sandbox = sandbox_config.start().await; + let rest_addr = sandbox.node_configs[0].0.rest_config.listen_addr; + + // Open a long-lived connection and record the certificate it negotiated. + let (served_before, _send_request, connection_handle) = + open_http2_connection(rest_addr, &ca_path).await; + let server1_der = leaf_cert_der(&format!("{TLS_FIXTURES_DIR}/server.crt")); + assert_eq!(served_before, server1_der); + + // Rotate the certificate on disk and reload; the long-lived connection is still on the old one. + let server2_der = leaf_cert_der(&format!("{TLS_FIXTURES_DIR}/server2.crt")); + assert_ne!(server1_der, server2_der); + std::fs::copy(format!("{TLS_FIXTURES_DIR}/server2.crt"), &cert_path).unwrap(); + std::fs::copy(format!("{TLS_FIXTURES_DIR}/server2.key"), &key_path).unwrap(); + quickwit_serve::reload_tls_cert(); + + // The server closes the long-lived connection on its own once the max connection age elapses. + // An idle HTTP/2 connection without a max age would otherwise stay open indefinitely, so a + // bounded close demonstrates the feature. + tokio::time::timeout(Duration::from_secs(10), connection_handle) + .await + .expect("server should close the connection within the max connection age") + .expect("connection task should not panic"); + + // After reconnecting, the server presents the rotated certificate. + let (served_after, _send_request, _connection_handle) = + open_http2_connection(rest_addr, &ca_path).await; + assert_eq!(served_after, server2_der); + + sandbox.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_grpc_max_connection_age() { + quickwit_common::setup_logging_for_tests(); + let ca_path = format!("{TLS_FIXTURES_DIR}/ca.crt"); + + let mut sandbox_config = ClusterSandboxBuilder::default() + .add_node(QuickwitService::supported_services()) + .build_config() + .await; + // One-way TLS (no client-cert verification) so the bare HTTP/2 probe below, which presents no + // client identity, can complete the handshake. The max connection age behavior is independent + // of mTLS. + sandbox_config.node_configs[0].0.grpc_config.tls_config = Some(fixture_tls_config()); + sandbox_config.node_configs[0] + .0 + .grpc_config + .max_connection_age = Some(HumanDuration::try_from("2s".to_string()).unwrap()); + let sandbox = sandbox_config.start().await; + let grpc_addr = sandbox.node_configs[0].0.grpc_listen_addr; + + // The gRPC server (tonic, `h2` only) must close a long-lived connection once the max connection + // age elapses. + let (_served, _send_request, connection_handle) = + open_http2_connection(grpc_addr, &ca_path).await; + tokio::time::timeout(Duration::from_secs(10), connection_handle) + .await + .expect("gRPC server should close the connection within the max connection age") + .expect("connection task should not panic"); + + sandbox.shutdown().await.unwrap(); +} + #[tokio::test] async fn test_tls_rest_cert_hot_reload() { quickwit_common::setup_logging_for_tests(); diff --git a/quickwit/quickwit-serve/src/grpc.rs b/quickwit/quickwit-serve/src/grpc.rs index c0cac95a146..a44b24f44cd 100644 --- a/quickwit/quickwit-serve/src/grpc.rs +++ b/quickwit/quickwit-serve/src/grpc.rs @@ -59,6 +59,18 @@ pub(crate) async fn start_grpc_server( // no way to reload it without restarting the process. let mut server = Server::builder(); + // Bound the lifetime of inbound connections so a hot-reloaded TLS certificate is eventually + // presented on long-lived (otherwise idle-but-alive) inter-node connections: tonic sends an + // HTTP/2 GOAWAY after `max_connection_age`, the peer reconnects, and the fresh handshake picks + // up the current certificate. Without this, a reloaded certificate would only take effect on + // the next reconnection, which may never happen for a healthy connection. + if let Some(max_connection_age) = &grpc_config.max_connection_age { + server = server.max_connection_age(**max_connection_age); + + if let Some(max_connection_age_grace) = &grpc_config.max_connection_age_grace { + server = server.max_connection_age_grace(**max_connection_age_grace); + } + } let cluster_grpc_service = cluster_grpc_server(services.cluster.clone()); file_descriptor_sets.push(quickwit_proto::cluster::CLUSTER_PLANE_FILE_DESCRIPTOR_SET); diff --git a/quickwit/quickwit-serve/src/rest.rs b/quickwit/quickwit-serve/src/rest.rs index 39bfb0e2580..0cb9449c80e 100644 --- a/quickwit/quickwit-serve/src/rest.rs +++ b/quickwit/quickwit-serve/src/rest.rs @@ -16,25 +16,29 @@ use std::fmt::Formatter; use std::io; use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use futures_util::{Stream, StreamExt}; use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::server::conn::auto::Builder; +use hyper_util::server::graceful::GracefulConnection; use hyper_util::service::TowerToHyperService; use quickwit_common::tower::BoxFutureInfaillible; use quickwit_config::{disable_ingest_v1, enable_ingest_v2}; use quickwit_metrics::{counter, histogram, labels}; use quickwit_search::SearchService; use tokio::net::{TcpListener, TcpStream}; +use tokio::task::JoinSet; use tokio_rustls::TlsAcceptor; use tokio_rustls::server::TlsStream; use tokio_util::either::Either; +use tokio_util::sync::CancellationToken; use tower::ServiceBuilder; use tower_http::compression::CompressionLayer; use tower_http::compression::predicate::{NotForContentType, Predicate, SizeAbove}; use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::trace::TraceLayer; -use tracing::{error, info}; +use tracing::{error, info, warn}; use warp::filters::log::Info; use warp::hyper::http::HeaderValue; use warp::hyper::{Method, StatusCode, http}; @@ -232,16 +236,27 @@ pub(crate) async fn start_rest_server( } else { None }; + let rest_config = &quickwit_services.node_config.rest_config; + // `max_connection_age_grace` without `max_connection_age` is rejected at config validation, so + // the grace is only carried when an age is present. + let max_connection_age_opt = + rest_config + .max_connection_age + .as_ref() + .map(|max_connection_age| MaxConnectionAge { + age: **max_connection_age, + grace: rest_config + .max_connection_age_grace + .as_ref() + .map(|max_connection_age_grace| **max_connection_age_grace), + }); serve_warp_routes( "REST", tcp_listener, rest_routes, - quickwit_services - .node_config - .rest_config - .cors_allow_origins - .clone(), + rest_config.cors_allow_origins.clone(), tls_acceptor_opt, + max_connection_age_opt, readiness_trigger, shutdown_signal, ) @@ -274,20 +289,36 @@ pub(crate) async fn start_health_check_server( health_check_routes, Vec::new(), None, + None, readiness_trigger, shutdown_signal, ) .await } +/// Bounds the lifetime of an accepted connection so a hot-reloaded TLS certificate eventually +/// reaches long-lived clients, which only pick up a new certificate when they reconnect. `grace` +/// is how long the connection may keep draining after the GOAWAY before it is forcefully closed; +/// `None` waits indefinitely. Grouping the two fields makes a grace-without-age combination +/// unrepresentable. +#[derive(Clone, Copy)] +struct MaxConnectionAge { + age: Duration, + grace: Option, +} + /// Serves a set of warp `routes` over `tcp_listener` until `shutdown_signal` resolves, optionally /// terminating TLS. Shared by the main REST server and the health-check server. +// `serve_warp_routes` wires together several independent concerns (routing, CORS, TLS, connection +// lifetime, readiness, shutdown); bundling them further would not aid readability. +#[allow(clippy::too_many_arguments)] async fn serve_warp_routes( server_name: &str, tcp_listener: TcpListener, routes: F, cors_allow_origins: Vec, tls_acceptor_opt: Option, + max_connection_age_opt: Option, readiness_trigger: BoxFutureInfaillible<()>, shutdown_signal: BoxFutureInfaillible<()>, ) -> anyhow::Result<()> @@ -324,7 +355,15 @@ where let service = TowerToHyperService::new(service); let server = Builder::new(TokioExecutor::new()); - let graceful = hyper_util::server::graceful::GracefulShutdown::new(); + // Triggers a graceful shutdown (HTTP/2 GOAWAY) on every live connection. Fired once on server + // shutdown; each connection also drains on its own when `max_connection_age` elapses. + let cancellation_token = CancellationToken::new(); + // Tracks in-flight connection tasks so we can wait for them to drain on shutdown. We do not use + // `hyper_util`'s `GracefulShutdown` helper because it takes ownership of each connection, which + // would prevent us from also triggering a per-connection `graceful_shutdown` when the + // connection's max age elapses (the `GracefulConnection` trait is sealed, so we cannot wrap + // it). + let mut connection_tasks: JoinSet<()> = JoinSet::new(); let mut shutdown_signal = std::pin::pin!(shutdown_signal); readiness_trigger.await; @@ -352,14 +391,18 @@ where } }; let serve_connection_fut = server - .serve_connection_with_upgrades(TokioIo::new(connection), service.clone()); - let serve_with_shutdown_fut = graceful.watch(serve_connection_fut.into_owned()); - tokio::spawn(async move { - if let Err(serve_error) = serve_with_shutdown_fut.await { - error!("failed to serve connection: {serve_error:#}"); - } - }); + .serve_connection_with_upgrades(TokioIo::new(connection), service.clone()) + .into_owned(); + let cancellation_token = cancellation_token.clone(); + connection_tasks.spawn(drain_connection( + serve_connection_fut, + cancellation_token, + max_connection_age_opt, + )); }, + // Reap finished connection tasks so the set does not grow without bound on a + // long-running server. Disabled while empty so the branch does not busy-loop. + _ = connection_tasks.join_next(), if !connection_tasks.is_empty() => {}, _ = &mut shutdown_signal => { info!("{server_name} server shutdown signal received"); break; @@ -367,12 +410,75 @@ where } } info!("shutting down {server_name} server"); - graceful.shutdown().await; + // Ask every live connection to drain, then wait for the tasks to finish. + cancellation_token.cancel(); + while connection_tasks.join_next().await.is_some() {} info!("{server_name} server successfully shut down"); Ok(()) } +/// Drives a single accepted connection to completion, sending an HTTP/2 GOAWAY and then waiting for +/// it to drain when either the connection's max age (`max_connection_age_opt`) elapses or a global +/// drain is requested via `cancellation_token`. When a grace period is configured, the connection +/// is forcefully closed (dropped) if it has not finished draining within that period. +/// +/// Bounding the connection lifetime is what lets a hot-reloaded TLS certificate eventually reach +/// long-lived clients: the new certificate is only presented on a fresh handshake, so the client +/// must reconnect to pick it up. +async fn drain_connection( + connection: C, + cancellation_token: CancellationToken, + max_connection_age_opt: Option, +) where + C: GracefulConnection, + C::Error: std::fmt::Display, +{ + let mut connection = std::pin::pin!(connection); + + // A connection without a max age only drains on global shutdown; model that as a never-firing + // timer so the `select!` arms stay uniform without recreating futures across iterations. + let max_age_sleep = match max_connection_age_opt { + Some(max_connection_age) => Either::Left(tokio::time::sleep(max_connection_age.age)), + None => Either::Right(std::future::pending::<()>()), + }; + + // Phase 1: serve until the connection ends on its own, its max age elapses, or a global drain + // is requested. + tokio::select! { + connection_res = connection.as_mut() => { + if let Err(serve_error) = connection_res { + error!("failed to serve connection: {serve_error:#}"); + } + return; + } + _ = max_age_sleep => {} + _ = cancellation_token.cancelled() => {} + } + + // Phase 2: we asked the peer to reconnect; send GOAWAY and let in-flight requests drain, + // bounded by the optional grace period. + connection.as_mut().graceful_shutdown(); + let grace_opt = match max_connection_age_opt { + Some(max_connection_age) => max_connection_age.grace, + None => None, + }; + let grace_sleep = match grace_opt { + Some(grace) => Either::Left(tokio::time::sleep(grace)), + None => Either::Right(std::future::pending::<()>()), + }; + tokio::select! { + connection_res = connection.as_mut() => { + if let Err(serve_error) = connection_res { + error!("failed to serve connection: {serve_error:#}"); + } + } + _ = grace_sleep => { + warn!("connection did not drain within the grace period; closing it forcefully"); + } + } +} + fn search_routes( search_service: Arc, ) -> impl Filter + Clone {