Skip to content

Commit a87e178

Browse files
committed
Refresh stale agent tasks lazily
1 parent 42c6583 commit a87e178

7 files changed

Lines changed: 286 additions & 25 deletions

File tree

codex-rs/core/src/agent_identity.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ mod task_registration;
3434

3535
#[cfg(test)]
3636
pub(crate) use assertion::AgentAssertionEnvelope;
37+
pub(crate) use assertion::AgentTaskRuntimeMismatch;
3738
pub(crate) use task_registration::RegisteredAgentTask;
3839

3940
const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15);
@@ -451,6 +452,9 @@ impl AgentIdentityBinding {
451452
}
452453

453454
fn from_auth(auth: &CodexAuth, forced_workspace_id: Option<String>) -> Option<Self> {
455+
// AgentAssertion is currently supported only for ChatGPT-backed Codex sessions. API-key
456+
// sessions keep using their API key until the registration service supports API-key
457+
// identity binding.
454458
if !auth.is_chatgpt_auth() {
455459
return None;
456460
}

codex-rs/core/src/agent_identity/assertion.rs

Lines changed: 52 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,21 @@ use base64::engine::general_purpose::URL_SAFE_NO_PAD;
77
use ed25519_dalek::Signer as _;
88
use serde::Deserialize;
99
use serde::Serialize;
10+
use thiserror::Error;
1011
use tracing::debug;
1112

1213
use super::*;
1314

15+
#[derive(Debug, Error)]
16+
#[error(
17+
"agent task runtime {agent_runtime_id} does not match stored agent identity {stored_agent_runtime_id}"
18+
)]
19+
pub(crate) struct AgentTaskRuntimeMismatch {
20+
pub(crate) agent_runtime_id: String,
21+
pub(crate) task_id: String,
22+
pub(crate) stored_agent_runtime_id: String,
23+
}
24+
1425
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
1526
pub(crate) struct AgentAssertionEnvelope {
1627
pub(crate) agent_runtime_id: String,
@@ -31,12 +42,14 @@ impl AgentIdentityManager {
3142
let Some(stored_identity) = self.ensure_registered_identity().await? else {
3243
return Ok(None);
3344
};
34-
anyhow::ensure!(
35-
stored_identity.agent_runtime_id == agent_task.agent_runtime_id,
36-
"agent task runtime {} does not match stored agent identity {}",
37-
agent_task.agent_runtime_id,
38-
stored_identity.agent_runtime_id
39-
);
45+
if stored_identity.agent_runtime_id != agent_task.agent_runtime_id {
46+
return Err(AgentTaskRuntimeMismatch {
47+
agent_runtime_id: agent_task.agent_runtime_id.clone(),
48+
task_id: agent_task.task_id.clone(),
49+
stored_agent_runtime_id: stored_identity.agent_runtime_id,
50+
}
51+
.into());
52+
}
4053

4154
let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true);
4255
let envelope = AgentAssertionEnvelope {
@@ -176,6 +189,39 @@ mod tests {
176189
.expect("signature should verify");
177190
}
178191

192+
#[tokio::test]
193+
async fn authorization_header_for_task_reports_runtime_mismatch() {
194+
let codex_home = tempfile::tempdir().expect("tempdir");
195+
let auth = make_chatgpt_auth(codex_home.path(), "account-123", Some("user-123"));
196+
let auth_manager = AuthManager::from_auth_for_testing(auth);
197+
let manager = AgentIdentityManager::new_for_tests(
198+
auth_manager,
199+
/*feature_enabled*/ true,
200+
"https://chatgpt.com/backend-api/".to_string(),
201+
SessionSource::Cli,
202+
);
203+
manager
204+
.seed_generated_identity_for_tests("agent-current")
205+
.await
206+
.expect("seed test identity");
207+
let agent_task = RegisteredAgentTask {
208+
agent_runtime_id: "agent-stale".to_string(),
209+
task_id: "task-123".to_string(),
210+
registered_at: "2026-03-23T12:00:00Z".to_string(),
211+
};
212+
213+
let error = manager
214+
.authorization_header_for_task(&agent_task)
215+
.await
216+
.expect_err("stale task should be reported");
217+
let mismatch = error
218+
.downcast_ref::<AgentTaskRuntimeMismatch>()
219+
.expect("runtime mismatch error");
220+
assert_eq!(mismatch.agent_runtime_id, "agent-stale");
221+
assert_eq!(mismatch.task_id, "task-123");
222+
assert_eq!(mismatch.stored_agent_runtime_id, "agent-current");
223+
}
224+
179225
fn make_chatgpt_auth(
180226
codex_home: &std::path::Path,
181227
account_id: &str,

codex-rs/core/src/client.rs

Lines changed: 58 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ use std::sync::atomic::AtomicU64;
3232
use std::sync::atomic::Ordering;
3333

3434
use crate::agent_identity::AgentIdentityManager;
35+
use crate::agent_identity::AgentTaskRuntimeMismatch;
3536
use crate::agent_identity::RegisteredAgentTask;
3637
use codex_api::ApiError;
3738
use codex_api::CompactClient as ApiCompactClient;
@@ -159,7 +160,7 @@ struct ModelClientState {
159160
include_timing_metrics: bool,
160161
beta_features_header: Option<String>,
161162
disable_websockets: AtomicBool,
162-
cached_websocket_session: StdMutex<WebsocketSession>,
163+
cached_websocket_session: StdMutex<CachedWebsocketSession>,
163164
}
164165

165166
/// Resolved API client setup for a single request attempt.
@@ -244,6 +245,12 @@ struct WebsocketSession {
244245
connection_reused: StdMutex<bool>,
245246
}
246247

248+
#[derive(Debug, Default)]
249+
struct CachedWebsocketSession {
250+
agent_task: Option<RegisteredAgentTask>,
251+
websocket_session: WebsocketSession,
252+
}
253+
247254
impl WebsocketSession {
248255
fn set_connection_reused(&self, connection_reused: bool) {
249256
*self
@@ -360,7 +367,7 @@ impl ModelClient {
360367
include_timing_metrics,
361368
beta_features_header,
362369
disable_websockets: AtomicBool::new(false),
363-
cached_websocket_session: StdMutex::new(WebsocketSession::default()),
370+
cached_websocket_session: StdMutex::new(CachedWebsocketSession::default()),
364371
}),
365372
}
366373
}
@@ -377,18 +384,15 @@ impl ModelClient {
377384
&self,
378385
agent_task: Option<RegisteredAgentTask>,
379386
) -> ModelClientSession {
380-
let cache_websocket_session_on_drop = agent_task.is_none();
381-
let websocket_session = if agent_task.is_some() {
382-
drop(self.take_cached_websocket_session());
383-
WebsocketSession::default()
384-
} else {
385-
self.take_cached_websocket_session()
386-
};
387+
// WebSocket auth is bound to the task that opened the connection. Reuse only when the
388+
// cached connection was created for the same task, and drop mismatched taskless/task-scoped
389+
// sessions rather than mixing auth contexts.
390+
let websocket_session = self.take_cached_websocket_session(agent_task.as_ref());
387391
ModelClientSession {
388392
client: self.clone(),
389393
websocket_session,
390394
agent_task,
391-
cache_websocket_session_on_drop,
395+
cache_websocket_session_on_drop: true,
392396
turn_state: Arc::new(OnceLock::new()),
393397
}
394398
}
@@ -401,12 +405,12 @@ impl ModelClient {
401405
self.state
402406
.window_generation
403407
.store(window_generation, Ordering::Relaxed);
404-
self.store_cached_websocket_session(WebsocketSession::default());
408+
self.clear_cached_websocket_session();
405409
}
406410

407411
pub(crate) fn advance_window_generation(&self) {
408412
self.state.window_generation.fetch_add(1, Ordering::Relaxed);
409-
self.store_cached_websocket_session(WebsocketSession::default());
413+
self.clear_cached_websocket_session();
410414
}
411415

412416
fn current_window_id(&self) -> String {
@@ -415,21 +419,44 @@ impl ModelClient {
415419
format!("{conversation_id}:{window_generation}")
416420
}
417421

418-
fn take_cached_websocket_session(&self) -> WebsocketSession {
422+
fn take_cached_websocket_session(
423+
&self,
424+
agent_task: Option<&RegisteredAgentTask>,
425+
) -> WebsocketSession {
419426
let mut cached_websocket_session = self
420427
.state
421428
.cached_websocket_session
422429
.lock()
423430
.unwrap_or_else(std::sync::PoisonError::into_inner);
424-
std::mem::take(&mut *cached_websocket_session)
431+
if cached_websocket_session.agent_task.as_ref() == agent_task {
432+
return std::mem::take(&mut *cached_websocket_session).websocket_session;
433+
}
434+
435+
*cached_websocket_session = CachedWebsocketSession::default();
436+
WebsocketSession::default()
437+
}
438+
439+
fn store_cached_websocket_session(
440+
&self,
441+
agent_task: Option<RegisteredAgentTask>,
442+
websocket_session: WebsocketSession,
443+
) {
444+
*self
445+
.state
446+
.cached_websocket_session
447+
.lock()
448+
.unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession {
449+
agent_task,
450+
websocket_session,
451+
};
425452
}
426453

427-
fn store_cached_websocket_session(&self, websocket_session: WebsocketSession) {
454+
fn clear_cached_websocket_session(&self) {
428455
*self
429456
.state
430457
.cached_websocket_session
431458
.lock()
432-
.unwrap_or_else(std::sync::PoisonError::into_inner) = websocket_session;
459+
.unwrap_or_else(std::sync::PoisonError::into_inner) = CachedWebsocketSession::default();
433460
}
434461

435462
pub(crate) fn force_http_fallback(
@@ -449,7 +476,7 @@ impl ModelClient {
449476
);
450477
}
451478

452-
self.store_cached_websocket_session(WebsocketSession::default());
479+
self.clear_cached_websocket_session();
453480
activated
454481
}
455482

@@ -727,6 +754,15 @@ impl ModelClient {
727754
.authorization_header_for_task(agent_task)
728755
.await
729756
.map_err(|err| {
757+
if let Some(mismatch) = err.downcast_ref::<AgentTaskRuntimeMismatch>() {
758+
debug!(
759+
agent_runtime_id = %mismatch.agent_runtime_id,
760+
task_id = %mismatch.task_id,
761+
stored_agent_runtime_id = %mismatch.stored_agent_runtime_id,
762+
"agent task no longer matches stored identity"
763+
);
764+
return CodexErr::AgentTaskStale;
765+
}
730766
CodexErr::Stream(
731767
format!("failed to build agent assertion authorization: {err}"),
732768
None,
@@ -883,12 +919,16 @@ impl Drop for ModelClientSession {
883919
let websocket_session = std::mem::take(&mut self.websocket_session);
884920
if self.cache_websocket_session_on_drop {
885921
self.client
886-
.store_cached_websocket_session(websocket_session);
922+
.store_cached_websocket_session(self.agent_task.clone(), websocket_session);
887923
}
888924
}
889925
}
890926

891927
impl ModelClientSession {
928+
pub(crate) fn agent_task(&self) -> Option<&RegisteredAgentTask> {
929+
self.agent_task.as_ref()
930+
}
931+
892932
pub(crate) fn disable_cached_websocket_session_on_drop(&mut self) {
893933
self.cache_websocket_session_on_drop = false;
894934
}

codex-rs/core/src/client_tests.rs

Lines changed: 107 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ use codex_model_provider_info::create_oss_provider_with_base_url;
3333
use codex_otel::SessionTelemetry;
3434
use codex_protocol::ThreadId;
3535
use codex_protocol::config_types::ReasoningSummary;
36+
use codex_protocol::error::CodexErr;
3637
use codex_protocol::models::ContentItem;
3738
use codex_protocol::models::ResponseItem;
3839
use codex_protocol::openai_models::ModelInfo;
@@ -393,6 +394,35 @@ async fn responses_http_uses_agent_assertion_when_agent_task_is_present() {
393394
assert_eq!(request.header("chatgpt-account-id"), None);
394395
}
395396

397+
#[tokio::test]
398+
async fn responses_http_reports_stale_agent_task_when_identity_changed() {
399+
let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses);
400+
let (_codex_home, client, mut agent_task, _stored_identity) =
401+
model_client_with_agent_task(provider).await;
402+
agent_task.agent_runtime_id = "agent-stale".to_string();
403+
let model_info = test_model_info();
404+
let session_telemetry = test_session_telemetry();
405+
let mut client_session = client.new_session_with_agent_task(Some(agent_task));
406+
407+
let error = match client_session
408+
.stream(
409+
&test_prompt("hello"),
410+
&model_info,
411+
&session_telemetry,
412+
/*effort*/ None,
413+
ReasoningSummary::Auto,
414+
/*service_tier*/ None,
415+
/*turn_metadata_header*/ None,
416+
)
417+
.await
418+
{
419+
Ok(_) => panic!("stale task should be reported before sending a request"),
420+
Err(error) => error,
421+
};
422+
423+
assert!(matches!(error, CodexErr::AgentTaskStale));
424+
}
425+
396426
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
397427
async fn websocket_agent_task_bypasses_cached_bearer_prewarm() {
398428
core_test_support::skip_if_no_network!();
@@ -469,3 +499,80 @@ async fn websocket_agent_task_bypasses_cached_bearer_prewarm() {
469499

470500
server.shutdown().await;
471501
}
502+
503+
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
504+
async fn websocket_agent_task_reuses_cached_connection_for_same_task() {
505+
core_test_support::skip_if_no_network!();
506+
507+
let server = responses::start_websocket_server(vec![vec![
508+
vec![
509+
responses::ev_response_created("resp-1"),
510+
responses::ev_completed("resp-1"),
511+
],
512+
vec![
513+
responses::ev_response_created("resp-2"),
514+
responses::ev_completed("resp-2"),
515+
],
516+
]])
517+
.await;
518+
let mut provider =
519+
create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses);
520+
provider.supports_websockets = true;
521+
provider.websocket_connect_timeout_ms = Some(5_000);
522+
let (_codex_home, client, agent_task, stored_identity) =
523+
model_client_with_agent_task(provider).await;
524+
let model_info = test_model_info();
525+
let session_telemetry = test_session_telemetry();
526+
let prompt = test_prompt("hello");
527+
528+
{
529+
let mut first_session = client.new_session_with_agent_task(Some(agent_task.clone()));
530+
let mut stream = first_session
531+
.stream(
532+
&prompt,
533+
&model_info,
534+
&session_telemetry,
535+
/*effort*/ None,
536+
ReasoningSummary::Auto,
537+
/*service_tier*/ None,
538+
/*turn_metadata_header*/ None,
539+
)
540+
.await
541+
.expect("first agent task stream should succeed");
542+
drain_stream_to_completion(&mut stream)
543+
.await
544+
.expect("first agent task websocket stream should complete");
545+
}
546+
547+
let mut second_session = client.new_session_with_agent_task(Some(agent_task.clone()));
548+
let mut stream = second_session
549+
.stream(
550+
&prompt,
551+
&model_info,
552+
&session_telemetry,
553+
/*effort*/ None,
554+
ReasoningSummary::Auto,
555+
/*service_tier*/ None,
556+
/*turn_metadata_header*/ None,
557+
)
558+
.await
559+
.expect("second agent task stream should succeed");
560+
drain_stream_to_completion(&mut stream)
561+
.await
562+
.expect("second agent task websocket stream should complete");
563+
564+
let handshakes = server.handshakes();
565+
assert_eq!(handshakes.len(), 1);
566+
let agent_authorization = handshakes[0]
567+
.header("authorization")
568+
.expect("agent handshake should include authorization");
569+
assert_agent_assertion_header(
570+
&agent_authorization,
571+
&stored_identity,
572+
&agent_task.agent_runtime_id,
573+
&agent_task.task_id,
574+
);
575+
assert_eq!(server.single_connection().len(), 2);
576+
577+
server.shutdown().await;
578+
}

0 commit comments

Comments (0)