Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/openhuman/inference/provider/reliable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ fn is_non_retryable(err: &anyhow::Error) -> bool {
if is_context_window_exceeded(err) {
return true;
}
let msg = err.to_string();
// Session-expired is a user-auth-state boundary condition, not a
// transient provider outage. Retrying just burns attempts and delays
// the sign-in prompt.
if crate::core::observability::is_session_expired_message(&msg) {
return true;
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_err.status() {
let code = status.as_u16();
return status.is_client_error() && code != 429 && code != 408;
}
}
let msg = err.to_string();
for word in msg.split(|c: char| !c.is_ascii_digit()) {
if let Ok(code) = word.parse::<u16>() {
if (400..500).contains(&code) {
Expand Down Expand Up @@ -72,6 +78,13 @@ fn is_stream_error_non_retryable(err: &StreamError) -> bool {
false
}
StreamError::Provider(msg) => {
// Mirror the non-streaming classifier: session-expired is a
// user-auth-state boundary, not a transient provider outage —
// fail fast so the streaming caller can prompt sign-in instead
// of burning the retry budget.
if crate::core::observability::is_session_expired_message(msg) {
return true;
}
let lower = msg.to_lowercase();
lower.contains("invalid api key")
|| lower.contains("unauthorized")
Expand Down
165 changes: 165 additions & 0 deletions src/openhuman/inference/provider/reliable_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ fn non_retryable_detects_common_patterns() {
assert!(is_non_retryable(&anyhow::anyhow!(
"OpenAI Codex stream error: Your input exceeds the context window of this model."
)));
assert!(is_non_retryable(&anyhow::anyhow!(
"SESSION_EXPIRED: backend session not active — sign in to resume LLM work"
)));
}

#[tokio::test]
Expand Down Expand Up @@ -253,6 +256,168 @@ async fn context_window_error_aborts_retries_and_model_fallbacks() {
assert_eq!(calls.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn session_expired_aborts_retries() {
let calls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![(
"openhuman".into(),
Box::new(MockProvider {
calls: Arc::clone(&calls),
fail_until_attempt: usize::MAX,
response: "never",
error: "SESSION_EXPIRED: backend session not active — sign in to resume LLM work",
}),
)],
3,
1,
);

let err = provider
.simple_chat("hello", "reasoning-v1", 0.0)
.await
.expect_err("session-expired should fail fast");
let msg = err.to_string();

assert_eq!(
calls.load(Ordering::SeqCst),
1,
"session-expired must skip retry loop"
);
assert!(
msg.contains("non_retryable"),
"aggregate should classify SESSION_EXPIRED as non_retryable: {msg}"
);
assert!(
!msg.contains("attempt 2/4"),
"aggregate should contain only the first attempt for this provider: {msg}"
);
}

/// Streaming-path mock that emits a single configurable `StreamError::Provider`
/// then ends, and tracks how many times the stream was created (`stream_calls`)
/// and how many times the consumer polled it (`polls`). The latter is the
/// signal used by [`session_expired_aborts_retries_streaming`] to prove that
/// `is_stream_error_non_retryable` broke the retry loop after the first error
/// instead of polling for further attempts.
struct StreamingErrorMock {
stream_calls: Arc<AtomicUsize>,
polls: Arc<AtomicUsize>,
error: &'static str,
}

#[async_trait]
impl Provider for StreamingErrorMock {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
anyhow::bail!(self.error)
}

async fn chat_with_history(
&self,
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
anyhow::bail!(self.error)
}

fn supports_streaming(&self) -> bool {
true
}

fn stream_chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> futures_util::stream::BoxStream<'static, StreamResult<StreamChunk>> {
use futures_util::{stream, StreamExt};
self.stream_calls.fetch_add(1, Ordering::SeqCst);
let polls = Arc::clone(&self.polls);
let error = self.error.to_string();
// `unfold` state: `sent` flips to true after the first poll. The
// counter bumps on every poll so the test can prove that the retry
// loop short-circuited after the first error (polls == 1) rather
// than continuing to drain (polls == 2).
stream::unfold(false, move |sent| {
let polls = Arc::clone(&polls);
let error = error.clone();
async move {
polls.fetch_add(1, Ordering::SeqCst);
if sent {
None
} else {
Some((Err(StreamError::Provider(error)), true))
}
}
})
.boxed()
}
}

#[tokio::test]
async fn session_expired_aborts_retries_streaming() {
use futures_util::StreamExt;

let stream_calls = Arc::new(AtomicUsize::new(0));
let polls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![(
"openhuman".into(),
Box::new(StreamingErrorMock {
stream_calls: Arc::clone(&stream_calls),
polls: Arc::clone(&polls),
error: "SESSION_EXPIRED: backend session not active — sign in to resume LLM work",
}),
)],
3,
1,
);

let mut stream = provider.stream_chat_with_system(
None,
"hello",
"reasoning-v1",
0.0,
StreamOptions::new(true),
);

// Drain the consumer-facing stream. ReliableProvider does NOT forward
// candidate errors — the consumer only sees a single terminal
// "All streaming providers/models failed" once retries are exhausted.
let mut terminal: Option<String> = None;
while let Some(item) = stream.next().await {
if let Err(StreamError::Provider(msg)) = item {
terminal = Some(msg);
}
}

assert_eq!(
stream_calls.load(Ordering::SeqCst),
1,
"single candidate (one provider, one model) must build exactly one stream"
);
assert_eq!(
polls.load(Ordering::SeqCst),
1,
"session-expired must abort the streaming retry loop after the first poll; \
a second poll means is_stream_error_non_retryable misclassified it"
);
let terminal = terminal.expect("stream must surface a terminal aggregate error");
assert!(
terminal.contains("All streaming providers/models failed"),
"expected aggregate failure terminal, got: {terminal}"
);
}

#[tokio::test]
async fn aggregated_error_marks_non_retryable_model_mismatch_with_details() {
let calls = Arc::new(AtomicUsize::new(0));
Expand Down
Loading