diff --git a/src/openhuman/inference/provider/reliable.rs b/src/openhuman/inference/provider/reliable.rs index 6c2c93d5c2..76aba618b2 100644 --- a/src/openhuman/inference/provider/reliable.rs +++ b/src/openhuman/inference/provider/reliable.rs @@ -13,6 +13,13 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { if is_context_window_exceeded(err) { return true; } + let msg = err.to_string(); + // Session-expired is a user-auth-state boundary condition, not a + // transient provider outage. Retrying just burns attempts and delays + // the sign-in prompt. + if crate::core::observability::is_session_expired_message(&msg) { + return true; + } if let Some(reqwest_err) = err.downcast_ref::() { if let Some(status) = reqwest_err.status() { @@ -20,7 +27,6 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { return status.is_client_error() && code != 429 && code != 408; } } - let msg = err.to_string(); for word in msg.split(|c: char| !c.is_ascii_digit()) { if let Ok(code) = word.parse::() { if (400..500).contains(&code) { @@ -72,6 +78,13 @@ fn is_stream_error_non_retryable(err: &StreamError) -> bool { false } StreamError::Provider(msg) => { + // Mirror the non-streaming classifier: session-expired is a + // user-auth-state boundary, not a transient provider outage — + // fail fast so the streaming caller can prompt sign-in instead + // of burning the retry budget. + if crate::core::observability::is_session_expired_message(msg) { + return true; + } let lower = msg.to_lowercase(); lower.contains("invalid api key") || lower.contains("unauthorized") diff --git a/src/openhuman/inference/provider/reliable_tests.rs b/src/openhuman/inference/provider/reliable_tests.rs index c683f07505..c7524c1ad2 100644 --- a/src/openhuman/inference/provider/reliable_tests.rs +++ b/src/openhuman/inference/provider/reliable_tests.rs @@ -216,6 +216,9 @@ fn non_retryable_detects_common_patterns() { assert!(is_non_retryable(&anyhow::anyhow!( "OpenAI Codex stream error: Your input exceeds the context window of this model." ))); + assert!(is_non_retryable(&anyhow::anyhow!( + "SESSION_EXPIRED: backend session not active — sign in to resume LLM work" + ))); } #[tokio::test] @@ -253,6 +256,168 @@ async fn context_window_error_aborts_retries_and_model_fallbacks() { assert_eq!(calls.load(Ordering::SeqCst), 1); } +#[tokio::test] +async fn session_expired_aborts_retries() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "openhuman".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "SESSION_EXPIRED: backend session not active — sign in to resume LLM work", + }), + )], + 3, + 1, + ); + + let err = provider + .simple_chat("hello", "reasoning-v1", 0.0) + .await + .expect_err("session-expired should fail fast"); + let msg = err.to_string(); + + assert_eq!( + calls.load(Ordering::SeqCst), + 1, + "session-expired must skip retry loop" + ); + assert!( + msg.contains("non_retryable"), + "aggregate should classify SESSION_EXPIRED as non_retryable: {msg}" + ); + assert!( + !msg.contains("attempt 2/4"), + "aggregate should contain only the first attempt for this provider: {msg}" + ); +} + +/// Streaming-path mock that emits a single configurable `StreamError::Provider` +/// then ends, and tracks how many times the stream was created (`stream_calls`) +/// and how many times the consumer polled it (`polls`). The latter is the +/// signal used by [`session_expired_aborts_retries_streaming`] to prove that +/// `is_stream_error_non_retryable` broke the retry loop after the first error +/// instead of polling for further attempts. +struct StreamingErrorMock { + stream_calls: Arc, + polls: Arc, + error: &'static str, +} + +#[async_trait] +impl Provider for StreamingErrorMock { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + anyhow::bail!(self.error) + } + + async fn chat_with_history( + &self, + _messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + anyhow::bail!(self.error) + } + + fn supports_streaming(&self) -> bool { + true + } + + fn stream_chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + _options: StreamOptions, + ) -> futures_util::stream::BoxStream<'static, StreamResult> { + use futures_util::{stream, StreamExt}; + self.stream_calls.fetch_add(1, Ordering::SeqCst); + let polls = Arc::clone(&self.polls); + let error = self.error.to_string(); + // `unfold` state: `sent` flips to true after the first poll. The + // counter bumps on every poll so the test can prove that the retry + // loop short-circuited after the first error (polls == 1) rather + // than continuing to drain (polls == 2). + stream::unfold(false, move |sent| { + let polls = Arc::clone(&polls); + let error = error.clone(); + async move { + polls.fetch_add(1, Ordering::SeqCst); + if sent { + None + } else { + Some((Err(StreamError::Provider(error)), true)) + } + } + }) + .boxed() + } +} + +#[tokio::test] +async fn session_expired_aborts_retries_streaming() { + use futures_util::StreamExt; + + let stream_calls = Arc::new(AtomicUsize::new(0)); + let polls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "openhuman".into(), + Box::new(StreamingErrorMock { + stream_calls: Arc::clone(&stream_calls), + polls: Arc::clone(&polls), + error: "SESSION_EXPIRED: backend session not active — sign in to resume LLM work", + }), + )], + 3, + 1, + ); + + let mut stream = provider.stream_chat_with_system( + None, + "hello", + "reasoning-v1", + 0.0, + StreamOptions::new(true), + ); + + // Drain the consumer-facing stream. ReliableProvider does NOT forward + // candidate errors — the consumer only sees a single terminal + // "All streaming providers/models failed" once retries are exhausted. + let mut terminal: Option = None; + while let Some(item) = stream.next().await { + if let Err(StreamError::Provider(msg)) = item { + terminal = Some(msg); + } + } + + assert_eq!( + stream_calls.load(Ordering::SeqCst), + 1, + "single candidate (one provider, one model) must build exactly one stream" + ); + assert_eq!( + polls.load(Ordering::SeqCst), + 1, + "session-expired must abort the streaming retry loop after the first poll; \ + a second poll means is_stream_error_non_retryable misclassified it" + ); + let terminal = terminal.expect("stream must surface a terminal aggregate error"); + assert!( + terminal.contains("All streaming providers/models failed"), + "expected aggregate failure terminal, got: {terminal}" + ); +} + #[tokio::test] async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() { let calls = Arc::new(AtomicUsize::new(0));