Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion code-rs/tui/src/chatwidget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17019,8 +17019,22 @@ impl ChatWidget<'_> {
if self.rate_limit_fetch_inflight {
return false;
}
let now = Utc::now();
let reset_due = self
.rate_limit_primary_next_reset_at
.into_iter()
.chain(self.rate_limit_secondary_next_reset_at)
.any(|reset_at| {
now >= reset_at
&& self
.rate_limit_last_fetch_at
.is_none_or(|last_fetch| last_fetch < reset_at)
});
if reset_due {
return true;
}
match self.rate_limit_last_fetch_at {
Some(ts) => Utc::now() - ts > RATE_LIMIT_REFRESH_INTERVAL,
Some(ts) => now - ts > RATE_LIMIT_REFRESH_INTERVAL,
None => true,
}
}
Expand Down Expand Up @@ -32748,6 +32762,39 @@ use code_core::protocol::OrderMeta;
});
}

/// Verifies that `should_refresh_limits` reports a refresh as due once a stored
/// reset timestamp lies in the past and the last fetch predates it, and that it
/// stops doing so after a fetch newer than that reset has been recorded.
///
/// The assertion messages label the primary window "hourly" and the secondary
/// window "weekly".
#[test]
fn limits_refreshes_when_reset_time_has_passed() {
    let mut harness = ChatWidgetHarness::new();
    harness.with_chat(|widget| {
        // All timestamps are derived from a single clock reading.
        let now = Utc::now();
        let passed_reset = now - ChronoDuration::minutes(1);
        let upcoming_reset = now + ChronoDuration::hours(1);
        let fetched_before_reset = passed_reset - ChronoDuration::seconds(1);
        let fetched_after_reset = passed_reset + ChronoDuration::seconds(1);

        widget.rate_limit_fetch_inflight = false;
        widget.rate_limit_last_fetch_at = Some(fetched_before_reset);

        // Case 1: only the primary window's reset has elapsed.
        widget.rate_limit_primary_next_reset_at = Some(passed_reset);
        widget.rate_limit_secondary_next_reset_at = Some(upcoming_reset);
        assert!(
            widget.should_refresh_limits(),
            "expired hourly reset should force refresh even inside normal interval"
        );

        // Case 2: only the secondary window's reset has elapsed.
        widget.rate_limit_primary_next_reset_at = Some(upcoming_reset);
        widget.rate_limit_secondary_next_reset_at = Some(passed_reset);
        assert!(
            widget.should_refresh_limits(),
            "expired weekly reset should force refresh even inside normal interval"
        );

        // Case 3: a fetch recorded after the elapsed reset clears the trigger.
        // NOTE(review): this presumably relies on the regular refresh interval
        // being longer than the ~59s age of this fetch — confirm against
        // RATE_LIMIT_REFRESH_INTERVAL.
        widget.rate_limit_last_fetch_at = Some(fetched_after_reset);
        assert!(
            !widget.should_refresh_limits(),
            "successful fetch after reset should not refresh repeatedly"
        );
    });
}

#[test]
fn apply_context_mode_selection_persists_disabled_override() {
let _runtime_guard = enter_test_runtime_guard();
Expand Down
110 changes: 94 additions & 16 deletions code-rs/tui/src/chatwidget/rate_limit_refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::sync::{Arc, Mutex};
use anyhow::{Context, Result};
use code_core::auth::auth_for_stored_account;
use code_core::auth_accounts::{self, StoredAccount};
use code_core::{AuthManager, ModelClient, Prompt, ResponseEvent};
use code_core::{AuthManager, ModelClient, ModelProviderInfo, Prompt, ResponseEvent, WireApi};
use code_core::account_usage;
use code_core::config::Config;
use code_core::config_types::ReasoningEffort;
use code_core::debug_logger::DebugLogger;
use code_core::model_family::find_family_for_model;
use code_core::model_family::ModelFamily;
use code_core::protocol::{Event, EventMsg, RateLimitSnapshotEvent, TokenCountEvent};
use code_login::AuthMode;
use code_protocol::models::{ContentItem, ResponseItem};
Expand Down Expand Up @@ -142,20 +144,12 @@ fn run_refresh(

let client = build_model_client(&config, auth_mgr, debug_enabled)?;

let mut prompt = Prompt::default();
prompt.store = false;
prompt.user_instructions = config.user_instructions.clone();
prompt.base_instructions_override = config.base_instructions.clone();
prompt.input.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "Yield immediately with only the message \"ok\"".to_string(),
}],
end_turn: None,
phase: None,
});
prompt.set_log_tag("tui/rate_limit_refresh");
let prompt = build_rate_limit_refresh_prompt(
&config.model,
&config.model_family,
config.user_instructions.clone(),
config.base_instructions.clone(),
);

let mut stream = client
.stream(&prompt)
Expand Down Expand Up @@ -236,6 +230,32 @@ fn run_refresh(
})
}

/// Assembles the minimal prompt sent when refreshing rate-limit information.
///
/// * `model` - model name used to look up a model family.
/// * `fallback_family` - family cloned when the lookup for `model` yields nothing.
/// * `user_instructions` - stored on the prompt as its user instructions.
/// * `base_instructions` - stored on the prompt as its base-instructions override.
///
/// The returned prompt has `store` disabled, carries a model-family override
/// whose `prefer_websockets` flag is cleared, contains one user message, and is
/// tagged `tui/rate_limit_refresh` for logging.
fn build_rate_limit_refresh_prompt(
    model: &str,
    fallback_family: &ModelFamily,
    user_instructions: Option<String>,
    base_instructions: Option<String>,
) -> Prompt {
    // Pick the family for this model, falling back to the caller-supplied one,
    // and switch off the websocket preference on the copy used for the refresh.
    let mut family = match find_family_for_model(model) {
        Some(known) => known,
        None => fallback_family.clone(),
    };
    family.prefer_websockets = false;

    // Single user turn asking for the shortest possible reply.
    let probe_message = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: "Yield immediately with only the message \"ok\"".to_string(),
        }],
        end_turn: None,
        phase: None,
    };

    let mut prompt = Prompt::default();
    prompt.store = false;
    prompt.model_family_override = Some(family);
    prompt.user_instructions = user_instructions;
    prompt.base_instructions_override = base_instructions;
    prompt.input.push(probe_message);
    prompt.set_log_tag("tui/rate_limit_refresh");
    prompt
}

fn build_runtime() -> Result<Runtime> {
Ok(
tokio::runtime::Builder::new_multi_thread()
Expand All @@ -258,7 +278,7 @@ fn build_model_client(
Arc::new(config.clone()),
Some(auth_mgr),
None,
config.model_provider.clone(),
rate_limit_refresh_provider(&config.model_provider),
ReasoningEffort::Low,
config.model_reasoning_summary,
config.model_text_verbosity,
Expand All @@ -268,3 +288,61 @@ fn build_model_client(

Ok(client)
}

fn rate_limit_refresh_provider(provider: &ModelProviderInfo) -> ModelProviderInfo {
let mut provider = provider.clone();
if matches!(provider.wire_api, WireApi::ResponsesWebsocket) {
provider.wire_api = WireApi::Responses;
}
provider
}

#[cfg(test)]
mod tests {
    use super::*;
    use code_core::model_family::derive_default_model_family;

    /// The refresh prompt must carry a model-family override with the websocket
    /// preference cleared, even when the fallback family had it enabled.
    #[test]
    fn rate_limit_refresh_prompt_forces_http_transport() {
        // Fallback family that explicitly prefers websockets.
        let websocket_family = {
            let mut family = derive_default_model_family("gpt-5.5");
            family.prefer_websockets = true;
            family
        };

        let prompt =
            build_rate_limit_refresh_prompt("unknown-model", &websocket_family, None, None);
        let override_family = prompt
            .model_family_override
            .expect("refresh prompt should set model family");

        assert!(
            !override_family.prefer_websockets,
            "rate limit refresh depends on HTTP response headers"
        );
    }

    /// A provider configured for the websocket wire API must be rewritten to
    /// the plain `Responses` wire API for the refresh client.
    #[test]
    fn rate_limit_refresh_provider_uses_responses_http() {
        let websocket_provider = ModelProviderInfo {
            name: "test".to_string(),
            base_url: Some("https://example.test/v1".to_string()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            auth: None,
            wire_api: WireApi::ResponsesWebsocket,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            websocket_connect_timeout_ms: None,
            requires_openai_auth: true,
            openrouter: None,
        };

        let refreshed = rate_limit_refresh_provider(&websocket_provider);

        assert_eq!(
            refreshed.wire_api,
            WireApi::Responses,
            "rate limit refresh needs HTTP response headers"
        );
    }
}