Skip to content

Commit 36329a1

Browse files
authored
feat(inference): allow setting custom inference timeout (NVIDIA#672)
* feat(inference): add timeout * feat(inference): fix dynamic timeout change * feat(inference): update docs * feat(inference): fix formatting
1 parent 0815f82 commit 36329a1

13 files changed

Lines changed: 141 additions & 20 deletions

File tree

architecture/inference-routing.md

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ File: `proto/inference.proto`
9292

9393
Key messages:
9494

95-
- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + optional `no_verify` override, with verification enabled by default
96-
- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `version`
95+
- `SetClusterInferenceRequest` -- `provider_name` + `model_id` + `timeout_secs` + optional `no_verify` override, with verification enabled by default
96+
- `SetClusterInferenceResponse` -- `provider_name` + `model_id` + `timeout_secs` + `version`
9797
- `GetInferenceBundleResponse` -- `repeated ResolvedRoute routes` + `revision` + `generated_at_ms`
98-
- `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type`
98+
- `ResolvedRoute` -- `name`, `base_url`, `protocols`, `api_key`, `model_id`, `provider_type`, `timeout_secs`
9999

100100
## Data Plane (Sandbox)
101101

@@ -106,7 +106,7 @@ Files:
106106
- `crates/openshell-sandbox/src/lib.rs` -- inference context initialization, route refresh
107107
- `crates/openshell-sandbox/src/grpc_client.rs` -- `fetch_inference_bundle()`
108108

109-
In cluster mode, the sandbox starts a background refresh loop as soon as the inference context is created. The loop polls the gateway every 5 seconds by default (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS` override) and uses the bundle revision hash to skip no-op cache writes.
109+
In cluster mode, the sandbox starts a background refresh loop as soon as the inference context is created. The loop polls the gateway every 5 seconds by default (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS` override) and uses the bundle revision hash to skip no-op cache writes. The revision hash covers all route fields including `timeout_secs`, so any configuration change (provider, model, or timeout) triggers a cache update on the next poll.
110110

111111
### Interception flow
112112

@@ -143,7 +143,7 @@ If no pattern matches, the proxy returns `403 Forbidden` with `{"error": "connec
143143
### Route cache
144144

145145
- `InferenceContext` holds a `Router`, the pattern list, and an `Arc<RwLock<Vec<ResolvedRoute>>>` route cache.
146-
- In cluster mode, `spawn_route_refresh()` polls `GetInferenceBundle` every 30 seconds (`ROUTE_REFRESH_INTERVAL_SECS`). On failure, stale routes are kept.
146+
- In cluster mode, `spawn_route_refresh()` polls `GetInferenceBundle` every 5 seconds (`OPENSHELL_ROUTE_REFRESH_INTERVAL_SECS`). On failure, stale routes are kept.
147147
- In file mode (`--inference-routes`), routes load once at startup from YAML. No refresh task is spawned.
148148
- In cluster mode, an empty initial bundle still enables the inference context so the refresh task can pick up later configuration.
149149

@@ -209,9 +209,11 @@ File: `crates/openshell-router/src/mock.rs`
209209

210210
Routes with `mock://` scheme endpoints return canned responses without making HTTP requests. Mock responses are protocol-aware (OpenAI chat completion, OpenAI completion, Anthropic messages, or generic JSON). Mock routes include an `x-openshell-mock: true` response header.
211211

212-
### HTTP client
212+
### Per-request timeout
213213

214-
The router uses a `reqwest::Client` with a 60-second timeout. Timeouts and connection failures map to `RouterError::UpstreamUnavailable`.
214+
Each `ResolvedRoute` carries a `timeout` field (`Duration`). The `reqwest::Client` has no global timeout; instead, each outgoing request applies `.timeout(route.timeout)` on the request builder. When `timeout_secs` is `0` in the proto message, the default of 60 seconds is used (defined as `DEFAULT_ROUTE_TIMEOUT` in `config.rs`). Timeouts and connection failures map to `RouterError::UpstreamUnavailable`.
215+
216+
Timeout changes propagate dynamically to running sandboxes. The bundle revision hash includes `timeout_secs`, so when the timeout is updated via `openshell inference update --timeout`, the refresh loop detects the revision change and updates the route cache within one polling interval (5 seconds by default).
215217

216218
## Standalone Route File
217219

@@ -297,13 +299,16 @@ The system route is stored as a separate `InferenceRoute` record in the gateway
297299

298300
Cluster inference commands:
299301

300-
- `openshell inference set --provider <name> --model <id>` -- configures user-facing cluster inference
301-
- `openshell inference set --system --provider <name> --model <id>` -- configures system inference
302+
- `openshell inference set --provider <name> --model <id> [--timeout <secs>]` -- configures user-facing cluster inference
303+
- `openshell inference set --system --provider <name> --model <id> [--timeout <secs>]` -- configures system inference
304+
- `openshell inference update [--provider <name>] [--model <id>] [--timeout <secs>]` -- updates individual fields without resetting others
302305
- `openshell inference get` -- displays both user and system inference configuration
303306
- `openshell inference get --system` -- displays only the system inference configuration
304307

305308
The `--provider` flag references a provider record name (not a provider type). The provider must already exist in the cluster and have a supported inference type (`openai`, `anthropic`, or `nvidia`).
306309

310+
The `--timeout` flag sets the per-request timeout in seconds for upstream inference calls. When omitted or set to `0`, the default of 60 seconds applies. Timeout changes propagate to running sandboxes within the route refresh interval (5 seconds by default).
311+
307312
Inference writes verify by default. `--no-verify` is the explicit opt-out for endpoints that are not up yet.
308313

309314
## Provider Discovery

crates/openshell-cli/src/main.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,10 @@ enum InferenceCommands {
937937
/// Skip endpoint verification before saving the route.
938938
#[arg(long)]
939939
no_verify: bool,
940+
941+
/// Request timeout in seconds for inference calls (0 = default 60s).
942+
#[arg(long, default_value_t = 0)]
943+
timeout: u64,
940944
},
941945

942946
/// Update gateway-level inference configuration (partial update).
@@ -957,6 +961,10 @@ enum InferenceCommands {
957961
/// Skip endpoint verification before saving the route.
958962
#[arg(long)]
959963
no_verify: bool,
964+
965+
/// Request timeout in seconds for inference calls (0 = default 60s, unchanged if omitted).
966+
#[arg(long)]
967+
timeout: Option<u64>,
960968
},
961969

962970
/// Get gateway-level inference provider and model.
@@ -2026,10 +2034,11 @@ async fn main() -> Result<()> {
20262034
model,
20272035
system,
20282036
no_verify,
2037+
timeout,
20292038
} => {
20302039
let route_name = if system { "sandbox-system" } else { "" };
20312040
run::gateway_inference_set(
2032-
endpoint, &provider, &model, route_name, no_verify, &tls,
2041+
endpoint, &provider, &model, route_name, no_verify, timeout, &tls,
20332042
)
20342043
.await?;
20352044
}
@@ -2038,6 +2047,7 @@ async fn main() -> Result<()> {
20382047
model,
20392048
system,
20402049
no_verify,
2050+
timeout,
20412051
} => {
20422052
let route_name = if system { "sandbox-system" } else { "" };
20432053
run::gateway_inference_update(
@@ -2046,6 +2056,7 @@ async fn main() -> Result<()> {
20462056
model.as_deref(),
20472057
route_name,
20482058
no_verify,
2059+
timeout,
20492060
&tls,
20502061
)
20512062
.await?;

crates/openshell-cli/src/run.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3481,6 +3481,7 @@ pub async fn gateway_inference_set(
34813481
model_id: &str,
34823482
route_name: &str,
34833483
no_verify: bool,
3484+
timeout_secs: u64,
34843485
tls: &TlsOptions,
34853486
) -> Result<()> {
34863487
let progress = if std::io::stdout().is_terminal() {
@@ -3504,6 +3505,7 @@ pub async fn gateway_inference_set(
35043505
route_name: route_name.to_string(),
35053506
verify: false,
35063507
no_verify,
3508+
timeout_secs,
35073509
})
35083510
.await;
35093511

@@ -3525,6 +3527,7 @@ pub async fn gateway_inference_set(
35253527
println!(" {} {}", "Provider:".dimmed(), configured.provider_name);
35263528
println!(" {} {}", "Model:".dimmed(), configured.model_id);
35273529
println!(" {} {}", "Version:".dimmed(), configured.version);
3530+
print_timeout(configured.timeout_secs);
35283531
if configured.validation_performed {
35293532
println!(" {}", "Validated Endpoints:".dimmed());
35303533
for endpoint in configured.validated_endpoints {
@@ -3540,11 +3543,12 @@ pub async fn gateway_inference_update(
35403543
model_id: Option<&str>,
35413544
route_name: &str,
35423545
no_verify: bool,
3546+
timeout_secs: Option<u64>,
35433547
tls: &TlsOptions,
35443548
) -> Result<()> {
3545-
if provider_name.is_none() && model_id.is_none() {
3549+
if provider_name.is_none() && model_id.is_none() && timeout_secs.is_none() {
35463550
return Err(miette::miette!(
3547-
"at least one of --provider or --model must be specified"
3551+
"at least one of --provider, --model, or --timeout must be specified"
35483552
));
35493553
}
35503554

@@ -3561,6 +3565,7 @@ pub async fn gateway_inference_update(
35613565

35623566
let provider = provider_name.unwrap_or(&current.provider_name);
35633567
let model = model_id.unwrap_or(&current.model_id);
3568+
let timeout = timeout_secs.unwrap_or(current.timeout_secs);
35643569

35653570
let progress = if std::io::stdout().is_terminal() {
35663571
let spinner = ProgressBar::new_spinner();
@@ -3582,6 +3587,7 @@ pub async fn gateway_inference_update(
35823587
route_name: route_name.to_string(),
35833588
verify: false,
35843589
no_verify,
3590+
timeout_secs: timeout,
35853591
})
35863592
.await;
35873593

@@ -3603,6 +3609,7 @@ pub async fn gateway_inference_update(
36033609
println!(" {} {}", "Provider:".dimmed(), configured.provider_name);
36043610
println!(" {} {}", "Model:".dimmed(), configured.model_id);
36053611
println!(" {} {}", "Version:".dimmed(), configured.version);
3612+
print_timeout(configured.timeout_secs);
36063613
if configured.validation_performed {
36073614
println!(" {}", "Validated Endpoints:".dimmed());
36083615
for endpoint in configured.validated_endpoints {
@@ -3639,6 +3646,7 @@ pub async fn gateway_inference_get(
36393646
println!(" {} {}", "Provider:".dimmed(), configured.provider_name);
36403647
println!(" {} {}", "Model:".dimmed(), configured.model_id);
36413648
println!(" {} {}", "Version:".dimmed(), configured.version);
3649+
print_timeout(configured.timeout_secs);
36423650
} else {
36433651
// Show both routes by default.
36443652
print_inference_route(&mut client, "Gateway inference", "").await;
@@ -3666,6 +3674,7 @@ async fn print_inference_route(
36663674
println!(" {} {}", "Provider:".dimmed(), configured.provider_name);
36673675
println!(" {} {}", "Model:".dimmed(), configured.model_id);
36683676
println!(" {} {}", "Version:".dimmed(), configured.version);
3677+
print_timeout(configured.timeout_secs);
36693678
}
36703679
Err(e) if e.code() == Code::NotFound => {
36713680
println!("{}", format!("{label}:").cyan().bold());
@@ -3680,6 +3689,14 @@ async fn print_inference_route(
36803689
}
36813690
}
36823691

3692+
fn print_timeout(timeout_secs: u64) {
3693+
if timeout_secs == 0 {
3694+
println!(" {} {}s (default)", "Timeout:".dimmed(), 60);
3695+
} else {
3696+
println!(" {} {}s", "Timeout:".dimmed(), timeout_secs);
3697+
}
3698+
}
3699+
36833700
fn format_inference_status(status: Status) -> miette::Report {
36843701
let message = status.message().trim();
36853702

crates/openshell-router/src/backend.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ async fn send_backend_request(
149149
}
150150
Err(_) => body,
151151
};
152-
builder = builder.body(body);
152+
builder = builder.body(body).timeout(route.timeout);
153153

154154
builder.send().await.map_err(|e| {
155155
if e.is_timeout() {
@@ -468,6 +468,7 @@ mod tests {
468468
protocols: protocols.iter().map(|p| (*p).to_string()).collect(),
469469
auth,
470470
default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())],
471+
timeout: crate::config::DEFAULT_ROUTE_TIMEOUT,
471472
}
472473
}
473474

crates/openshell-router/src/config.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33

44
use serde::Deserialize;
55
use std::path::Path;
6+
use std::time::Duration;
67

78
pub use openshell_core::inference::AuthHeader;
89

910
use crate::RouterError;
1011

12+
pub const DEFAULT_ROUTE_TIMEOUT: Duration = Duration::from_secs(60);
13+
1114
#[derive(Debug, Clone, Deserialize)]
1215
pub struct RouterConfig {
1316
pub routes: Vec<RouteConfig>,
@@ -45,6 +48,8 @@ pub struct ResolvedRoute {
4548
pub auth: AuthHeader,
4649
/// Extra headers injected on every request (e.g. `anthropic-version`).
4750
pub default_headers: Vec<(String, String)>,
51+
/// Per-request timeout for proxied inference calls.
52+
pub timeout: Duration,
4853
}
4954

5055
impl std::fmt::Debug for ResolvedRoute {
@@ -57,6 +62,7 @@ impl std::fmt::Debug for ResolvedRoute {
5762
.field("protocols", &self.protocols)
5863
.field("auth", &self.auth)
5964
.field("default_headers", &self.default_headers)
65+
.field("timeout", &self.timeout)
6066
.finish()
6167
}
6268
}
@@ -129,6 +135,7 @@ impl RouteConfig {
129135
protocols,
130136
auth,
131137
default_headers,
138+
timeout: DEFAULT_ROUTE_TIMEOUT,
132139
})
133140
}
134141
}
@@ -256,6 +263,7 @@ routes:
256263
protocols: vec!["openai_chat_completions".to_string()],
257264
auth: AuthHeader::Bearer,
258265
default_headers: Vec::new(),
266+
timeout: DEFAULT_ROUTE_TIMEOUT,
259267
};
260268
let debug_output = format!("{route:?}");
261269
assert!(

crates/openshell-router/src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ mod backend;
55
pub mod config;
66
mod mock;
77

8-
use std::time::Duration;
9-
108
pub use backend::{
119
ProxyResponse, StreamingProxyResponse, ValidatedEndpoint, ValidationFailure,
1210
ValidationFailureKind, verify_backend_endpoint,
@@ -39,7 +37,6 @@ pub struct Router {
3937
impl Router {
4038
pub fn new() -> Result<Self, RouterError> {
4139
let client = reqwest::Client::builder()
42-
.timeout(Duration::from_secs(60))
4340
.build()
4441
.map_err(|e| RouterError::Internal(format!("failed to build HTTP client: {e}")))?;
4542
Ok(Self {

crates/openshell-router/src/mock.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ mod tests {
131131
protocols: protocols.iter().map(ToString::to_string).collect(),
132132
auth: crate::config::AuthHeader::Bearer,
133133
default_headers: Vec::new(),
134+
timeout: crate::config::DEFAULT_ROUTE_TIMEOUT,
134135
}
135136
}
136137

crates/openshell-router/tests/backend_integration.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ fn mock_candidates(base_url: &str) -> Vec<ResolvedRoute> {
1515
protocols: vec!["openai_chat_completions".to_string()],
1616
auth: AuthHeader::Bearer,
1717
default_headers: Vec::new(),
18+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
1819
}]
1920
}
2021

@@ -117,6 +118,7 @@ async fn proxy_no_compatible_route_returns_error() {
117118
protocols: vec!["anthropic_messages".to_string()],
118119
auth: AuthHeader::Custom("x-api-key"),
119120
default_headers: Vec::new(),
121+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
120122
}];
121123

122124
let err = router
@@ -178,6 +180,7 @@ async fn proxy_mock_route_returns_canned_response() {
178180
protocols: vec!["openai_chat_completions".to_string()],
179181
auth: AuthHeader::Bearer,
180182
default_headers: Vec::new(),
183+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
181184
}];
182185

183186
let body = serde_json::to_vec(&serde_json::json!({
@@ -312,6 +315,7 @@ async fn proxy_uses_x_api_key_for_anthropic_route() {
312315
protocols: vec!["anthropic_messages".to_string()],
313316
auth: AuthHeader::Custom("x-api-key"),
314317
default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())],
318+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
315319
}];
316320

317321
let body = serde_json::to_vec(&serde_json::json!({
@@ -370,6 +374,7 @@ async fn proxy_anthropic_does_not_send_bearer_auth() {
370374
protocols: vec!["anthropic_messages".to_string()],
371375
auth: AuthHeader::Custom("x-api-key"),
372376
default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())],
377+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
373378
}];
374379

375380
let response = router
@@ -414,6 +419,7 @@ async fn proxy_forwards_client_anthropic_version_header() {
414419
protocols: vec!["anthropic_messages".to_string()],
415420
auth: AuthHeader::Custom("x-api-key"),
416421
default_headers: vec![("anthropic-version".to_string(), "2023-06-01".to_string())],
422+
timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT,
417423
}];
418424

419425
let body = serde_json::to_vec(&serde_json::json!({

0 commit comments

Comments
 (0)