diff --git a/codex-rs/app-server/tests/suite/v2/app_list.rs b/codex-rs/app-server/tests/suite/v2/app_list.rs index dbe61524f5e..f56f5b51c29 100644 --- a/codex-rs/app-server/tests/suite/v2/app_list.rs +++ b/codex-rs/app-server/tests/suite/v2/app_list.rs @@ -56,7 +56,7 @@ use tokio::net::TcpListener; use tokio::task::JoinHandle; use tokio::time::timeout; -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20); #[tokio::test] async fn list_apps_returns_empty_when_connectors_disabled() -> Result<()> { diff --git a/codex-rs/codex-api/src/api_bridge.rs b/codex-rs/codex-api/src/api_bridge.rs index 0ad2b139795..740f7cb783d 100644 --- a/codex-rs/codex-api/src/api_bridge.rs +++ b/codex-rs/codex-api/src/api_bridge.rs @@ -178,13 +178,33 @@ struct UsageErrorBody { pub struct CoreAuthProvider { pub token: Option, pub account_id: Option, + authorization_header_override: Option, } impl CoreAuthProvider { + pub fn from_bearer_token(token: Option, account_id: Option) -> Self { + Self { + token, + account_id, + authorization_header_override: None, + } + } + + pub fn from_authorization_header_value( + authorization_header_value: Option, + account_id: Option, + ) -> Self { + Self { + token: None, + account_id, + authorization_header_override: authorization_header_value, + } + } + pub fn auth_header_attached(&self) -> bool { - self.token + self.authorization_header_value() .as_ref() - .is_some_and(|token| http::HeaderValue::from_str(&format!("Bearer {token}")).is_ok()) + .is_some_and(|value| http::HeaderValue::from_str(value).is_ok()) } pub fn auth_header_name(&self) -> Option<&'static str> { @@ -195,8 +215,20 @@ impl CoreAuthProvider { Self { token: token.map(str::to_string), account_id: account_id.map(str::to_string), + authorization_header_override: None, } } + + #[cfg(test)] + pub fn for_test_authorization_header( + authorization_header_value: Option<&str>, + account_id: Option<&str>, + ) -> Self { + Self::from_authorization_header_value( + authorization_header_value.map(str::to_string), + account_id.map(str::to_string), + ) + } } impl ApiAuthProvider for CoreAuthProvider { @@ -204,6 +236,12 @@ impl ApiAuthProvider for CoreAuthProvider { self.token.clone() } + fn authorization_header_value(&self) -> Option { + self.authorization_header_override + .clone() + .or_else(|| self.bearer_token().map(|token| format!("Bearer {token}"))) + } + fn account_id(&self) -> Option { self.account_id.clone() } diff --git a/codex-rs/codex-api/src/api_bridge_tests.rs b/codex-rs/codex-api/src/api_bridge_tests.rs index 71d3889915c..51c7c8d2fba 100644 --- a/codex-rs/codex-api/src/api_bridge_tests.rs +++ b/codex-rs/codex-api/src/api_bridge_tests.rs @@ -133,11 +133,27 @@ fn map_api_error_extracts_identity_auth_details_from_headers() { #[test] fn core_auth_provider_reports_when_auth_header_will_attach() { - let auth = CoreAuthProvider { - token: Some("access-token".to_string()), - account_id: None, - }; + let auth = CoreAuthProvider::from_bearer_token( + Some("access-token".to_string()), + /*account_id*/ None, + ); assert!(auth.auth_header_attached()); assert_eq!(auth.auth_header_name(), Some("authorization")); } + +#[test] +fn core_auth_provider_supports_non_bearer_authorization_headers() { + let auth = CoreAuthProvider::for_test_authorization_header( + Some("AgentAssertion opaque-token"), + /*account_id*/ None, + ); + + assert!(auth.auth_header_attached()); + assert_eq!(auth.auth_header_name(), Some("authorization")); + assert_eq!(auth.bearer_token(), None); + assert_eq!( + auth.authorization_header_value(), + Some("AgentAssertion opaque-token".to_string()) + ); +} diff --git a/codex-rs/codex-api/src/auth.rs b/codex-rs/codex-api/src/auth.rs index f649062db1f..4f27264a97a 100644 --- a/codex-rs/codex-api/src/auth.rs +++ b/codex-rs/codex-api/src/auth.rs @@ -9,14 +9,17 @@ use http::HeaderValue; /// reach this interface. pub trait AuthProvider: Send + Sync { fn bearer_token(&self) -> Option; + fn authorization_header_value(&self) -> Option { + self.bearer_token().map(|token| format!("Bearer {token}")) + } fn account_id(&self) -> Option { None } } pub(crate) fn add_auth_headers_to_header_map(auth: &A, headers: &mut HeaderMap) { - if let Some(token) = auth.bearer_token() - && let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}")) + if let Some(authorization) = auth.authorization_header_value() + && let Ok(header) = HeaderValue::from_str(&authorization) { let _ = headers.insert(http::header::AUTHORIZATION, header); } diff --git a/codex-rs/codex-api/src/files.rs b/codex-rs/codex-api/src/files.rs index 6fad5b62f58..b1c9c31919f 100644 --- a/codex-rs/codex-api/src/files.rs +++ b/codex-rs/codex-api/src/files.rs @@ -5,6 +5,7 @@ use std::time::Duration; use crate::AuthProvider; use codex_client::build_reqwest_client_with_custom_ca; use reqwest::StatusCode; +use reqwest::header::AUTHORIZATION; use reqwest::header::CONTENT_LENGTH; use serde::Deserialize; use tokio::fs::File; @@ -260,8 +261,8 @@ fn authorized_request( let mut request = client .request(method, url) .timeout(OPENAI_FILE_REQUEST_TIMEOUT); - if let Some(token) = auth.bearer_token() { - request = request.bearer_auth(token); + if let Some(authorization) = auth.authorization_header_value() { + request = request.header(AUTHORIZATION, authorization); } if let Some(account_id) = auth.account_id() { request = request.header("chatgpt-account-id", account_id); @@ -307,6 +308,7 @@ mod tests { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/backend-api/files")) + .and(header("authorization", "Bearer token")) .and(header("chatgpt-account-id", "account_id")) .and(body_json(serde_json::json!({ "file_name": "hello.txt", @@ -367,4 +369,56 @@ mod tests { assert_eq!(uploaded.mime_type, Some("text/plain".to_string())); assert_eq!(finalize_attempts.load(Ordering::SeqCst), 2); } + + #[tokio::test] + async fn upload_local_file_uses_authorization_header_value() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/backend-api/files")) + .and(header("authorization", "AgentAssertion test-assertion")) + .and(body_json(serde_json::json!({ + "file_name": "hello.txt", + "file_size": 5, + "use_case": "codex", + }))) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"file_id": "file_123", "upload_url": format!("{}/upload/file_123", server.uri())})), + ) + .mount(&server) + .await; + Mock::given(method("PUT")) + .and(path("/upload/file_123")) + .and(header("content-length", "5")) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/backend-api/files/file_123/uploaded")) + .and(header("authorization", "AgentAssertion test-assertion")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "status": "success", + "download_url": format!("{}/download/file_123", server.uri()), + "file_name": "hello.txt", + "mime_type": "text/plain", + "file_size_bytes": 5 + }))) + .mount(&server) + .await; + + let base_url = base_url_for(&server); + let dir = TempDir::new().expect("temp dir"); + let path = dir.path().join("hello.txt"); + tokio::fs::write(&path, b"hello").await.expect("write file"); + let auth = CoreAuthProvider::for_test_authorization_header( + Some("AgentAssertion test-assertion"), + /*account_id*/ None, + ); + + let uploaded = upload_local_file(&base_url, &auth, &path) + .await + .expect("upload succeeds"); + + assert_eq!(uploaded.file_id, "file_123"); + } } diff --git a/codex-rs/core/src/agent_identity.rs b/codex-rs/core/src/agent_identity.rs index 1fb4509bf1f..5586ecd26ce 100644 --- a/codex-rs/core/src/agent_identity.rs +++ b/codex-rs/core/src/agent_identity.rs @@ -27,12 +27,15 @@ use tracing::debug; use tracing::info; use tracing::warn; +use crate::config::Config; + +mod assertion; mod task_registration; +#[cfg(test)] +pub(crate) use assertion::AgentAssertionEnvelope; pub(crate) use task_registration::RegisteredAgentTask; -use crate::config::Config; - const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15); const AGENT_IDENTITY_BISCUIT_TIMEOUT: Duration = Duration::from_secs(15); @@ -335,7 +338,7 @@ impl AgentIdentityManager { } #[cfg(test)] - fn new_for_tests( + pub(crate) fn new_for_tests( auth_manager: Arc, feature_enabled: bool, chatgpt_base_url: String, @@ -349,6 +352,30 @@ impl AgentIdentityManager { ensure_lock: Arc::new(Mutex::new(())), } } + + #[cfg(test)] + pub(crate) async fn seed_generated_identity_for_tests( + &self, + agent_runtime_id: &str, + ) -> Result { + let (auth, binding) = self + .current_auth_binding() + .await + .context("test agent identity requires ChatGPT auth")?; + let key_material = generate_agent_key_material()?; + let stored_identity = StoredAgentIdentity { + binding_id: binding.binding_id.clone(), + chatgpt_account_id: binding.chatgpt_account_id.clone(), + chatgpt_user_id: binding.chatgpt_user_id, + agent_runtime_id: agent_runtime_id.to_string(), + private_key_pkcs8_base64: key_material.private_key_pkcs8_base64, + public_key_ssh: key_material.public_key_ssh, + registered_at: Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true), + abom: self.abom.clone(), + }; + self.store_identity(&auth, &stored_identity)?; + Ok(stored_identity) + } } impl StoredAgentIdentity { @@ -579,7 +606,7 @@ mod tests { .and(path("/v1/agent/register")) .and(header("x-openai-authorization", "human-biscuit")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "agent_runtime_id": "agent_123", + "agent_runtime_id": "agent-123", }))) .expect(1) .mount(&server) @@ -605,7 +632,7 @@ mod tests { .unwrap() .expect("identity should be reused"); - assert_eq!(first.agent_runtime_id, "agent_123"); + assert_eq!(first.agent_runtime_id, "agent-123"); assert_eq!(first, second); assert_eq!(first.abom.agent_harness_id, "codex-cli"); assert_eq!(first.chatgpt_account_id, "account-123"); @@ -621,7 +648,7 @@ mod tests { .and(path("/v1/agent/register")) .and(header("x-openai-authorization", "human-biscuit")) .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "agent_runtime_id": "agent_456", + "agent_runtime_id": "agent-456", }))) .expect(1) .mount(&server) @@ -653,11 +680,11 @@ mod tests { .unwrap() .expect("identity should be registered"); - assert_eq!(stored.agent_runtime_id, "agent_456"); + assert_eq!(stored.agent_runtime_id, "agent-456"); let persisted = auth .get_agent_identity(&binding.chatgpt_account_id) .expect("stored identity"); - assert_eq!(persisted.agent_runtime_id, "agent_456"); + assert_eq!(persisted.agent_runtime_id, "agent-456"); } #[tokio::test] diff --git a/codex-rs/core/src/agent_identity/assertion.rs b/codex-rs/core/src/agent_identity/assertion.rs new file mode 100644 index 00000000000..df2bb468574 --- /dev/null +++ b/codex-rs/core/src/agent_identity/assertion.rs @@ -0,0 +1,176 @@ +use std::collections::BTreeMap; + +use anyhow::Context; +use anyhow::Result; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use ed25519_dalek::Signer as _; +use serde::Deserialize; +use serde::Serialize; +use tracing::debug; + +use super::*; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub(crate) struct AgentAssertionEnvelope { + pub(crate) agent_runtime_id: String, + pub(crate) task_id: String, + pub(crate) timestamp: String, + pub(crate) signature: String, +} + +impl AgentIdentityManager { + pub(crate) async fn authorization_header_for_task( + &self, + agent_task: &RegisteredAgentTask, + ) -> Result> { + if !self.feature_enabled { + return Ok(None); + } + + let Some(stored_identity) = self.ensure_registered_identity().await? else { + return Ok(None); + }; + anyhow::ensure!( + stored_identity.agent_runtime_id == agent_task.agent_runtime_id, + "agent task runtime {} does not match stored agent identity {}", + agent_task.agent_runtime_id, + stored_identity.agent_runtime_id + ); + + let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true); + let envelope = AgentAssertionEnvelope { + agent_runtime_id: agent_task.agent_runtime_id.clone(), + task_id: agent_task.task_id.clone(), + timestamp: timestamp.clone(), + signature: sign_agent_assertion_payload(&stored_identity, agent_task, ×tamp)?, + }; + let serialized_assertion = serialize_agent_assertion(&envelope)?; + debug!( + agent_runtime_id = %envelope.agent_runtime_id, + task_id = %envelope.task_id, + "attaching agent assertion authorization to downstream request" + ); + Ok(Some(format!("AgentAssertion {serialized_assertion}"))) + } +} + +fn sign_agent_assertion_payload( + stored_identity: &StoredAgentIdentity, + agent_task: &RegisteredAgentTask, + timestamp: &str, +) -> Result { + let signing_key = stored_identity.signing_key()?; + let payload = format!( + "{}:{}:{timestamp}", + agent_task.agent_runtime_id, agent_task.task_id + ); + Ok(BASE64_STANDARD.encode(signing_key.sign(payload.as_bytes()).to_bytes())) +} + +fn serialize_agent_assertion(envelope: &AgentAssertionEnvelope) -> Result { + let payload = serde_json::to_vec(&BTreeMap::from([ + ("agent_runtime_id", envelope.agent_runtime_id.as_str()), + ("signature", envelope.signature.as_str()), + ("task_id", envelope.task_id.as_str()), + ("timestamp", envelope.timestamp.as_str()), + ])) + .context("failed to serialize agent assertion envelope")?; + Ok(URL_SAFE_NO_PAD.encode(payload)) +} + +#[cfg(test)] +mod tests { + use base64::engine::general_purpose::URL_SAFE_NO_PAD; + use ed25519_dalek::Signature; + use ed25519_dalek::Verifier as _; + use pretty_assertions::assert_eq; + + use super::*; + + #[tokio::test] + async fn authorization_header_for_task_skips_when_feature_is_disabled() { + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let manager = AgentIdentityManager::new_for_tests( + auth_manager, + /*feature_enabled*/ false, + "https://chatgpt.com/backend-api/".to_string(), + SessionSource::Cli, + ); + let agent_task = RegisteredAgentTask { + agent_runtime_id: "agent-123".to_string(), + task_id: "task-123".to_string(), + registered_at: "2026-03-23T12:00:00Z".to_string(), + }; + + assert_eq!( + manager + .authorization_header_for_task(&agent_task) + .await + .unwrap(), + None + ); + } + + #[tokio::test] + async fn authorization_header_for_task_serializes_signed_agent_assertion() { + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let manager = AgentIdentityManager::new_for_tests( + auth_manager, + /*feature_enabled*/ true, + "https://chatgpt.com/backend-api/".to_string(), + SessionSource::Cli, + ); + let stored_identity = manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + let agent_task = RegisteredAgentTask { + agent_runtime_id: "agent-123".to_string(), + task_id: "task-123".to_string(), + registered_at: "2026-03-23T12:00:00Z".to_string(), + }; + + let header = manager + .authorization_header_for_task(&agent_task) + .await + .expect("build agent assertion") + .expect("header should exist"); + let token = header + .strip_prefix("AgentAssertion ") + .expect("agent assertion scheme"); + let payload = URL_SAFE_NO_PAD + .decode(token) + .expect("valid base64url payload"); + let envelope: AgentAssertionEnvelope = + serde_json::from_slice(&payload).expect("valid assertion envelope"); + + assert_eq!( + envelope, + AgentAssertionEnvelope { + agent_runtime_id: "agent-123".to_string(), + task_id: "task-123".to_string(), + timestamp: envelope.timestamp.clone(), + signature: envelope.signature.clone(), + } + ); + let signature_bytes = BASE64_STANDARD + .decode(&envelope.signature) + .expect("valid base64 signature"); + let signature = Signature::from_slice(&signature_bytes).expect("valid signature bytes"); + let signing_key = stored_identity.signing_key().expect("signing key"); + signing_key + .verifying_key() + .verify( + format!( + "{}:{}:{}", + envelope.agent_runtime_id, envelope.task_id, envelope.timestamp + ) + .as_bytes(), + &signature, + ) + .expect("signature should verify"); + } +} diff --git a/codex-rs/core/src/arc_monitor.rs b/codex-rs/core/src/arc_monitor.rs index 13f33f6c900..1f9504463bc 100644 --- a/codex-rs/core/src/arc_monitor.rs +++ b/codex-rs/core/src/arc_monitor.rs @@ -13,6 +13,7 @@ use codex_login::CodexAuth; use codex_login::default_client::build_reqwest_client; use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; +use reqwest::header::AUTHORIZATION; const ARC_MONITOR_TIMEOUT: Duration = Duration::from_secs(30); const CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE: &str = "CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE"; @@ -109,13 +110,31 @@ pub(crate) async fn monitor_action( }, None => None, }; - let token = if let Some(token) = read_non_empty_env_var(CODEX_ARC_MONITOR_TOKEN) { - token + let (authorization_header_value, account_id) = if let Some(token) = + read_non_empty_env_var(CODEX_ARC_MONITOR_TOKEN) + { + ( + format!("Bearer {token}"), + auth.as_ref().and_then(CodexAuth::get_account_id), + ) + } else if let Some(authorization_header_value) = + match sess.authorization_header_for_current_agent_task().await { + Ok(authorization_header_value) => authorization_header_value, + Err(err) => { + warn!( + error = %err, + "skipping safety monitor because agent assertion authorization is unavailable" + ); + return ArcMonitorOutcome::Ok; + } + } + { + (authorization_header_value, None) } else { let Some(auth) = auth.as_ref() else { return ArcMonitorOutcome::Ok; }; - match auth.get_token() { + let token = match auth.get_token() { Ok(token) => token, Err(err) => { warn!( @@ -124,7 +143,8 @@ pub(crate) async fn monitor_action( ); return ArcMonitorOutcome::Ok; } - } + }; + (format!("Bearer {token}"), auth.get_account_id()) }; let url = read_non_empty_env_var(CODEX_ARC_MONITOR_ENDPOINT_OVERRIDE).unwrap_or_else(|| { @@ -147,8 +167,8 @@ pub(crate) async fn monitor_action( .post(&url) .timeout(ARC_MONITOR_TIMEOUT) .json(&body) - .bearer_auth(token); - if let Some(account_id) = auth.as_ref().and_then(CodexAuth::get_account_id) { + .header(AUTHORIZATION, authorization_header_value); + if let Some(account_id) = account_id { request = request.header("chatgpt-account-id", account_id); } diff --git a/codex-rs/core/src/arc_monitor_tests.rs b/codex-rs/core/src/arc_monitor_tests.rs index a7ba2399738..b7feda6fdab 100644 --- a/codex-rs/core/src/arc_monitor_tests.rs +++ b/codex-rs/core/src/arc_monitor_tests.rs @@ -9,17 +9,37 @@ use wiremock::MockServer; use wiremock::ResponseTemplate; use wiremock::matchers::body_json; use wiremock::matchers::header; +use wiremock::matchers::header_regex; use wiremock::matchers::method; use wiremock::matchers::path; use super::*; +use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::RegisteredAgentTask; use crate::codex::make_session_and_context; +use chrono::Utc; +use codex_login::AuthCredentialsStoreMode; +use codex_login::AuthDotJson; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_login::save_auth; +use codex_login::token_data::IdTokenInfo; +use codex_login::token_data::TokenData; use codex_protocol::models::ContentItem; use codex_protocol::models::LocalShellAction; use codex_protocol::models::LocalShellExecAction; use codex_protocol::models::LocalShellStatus; use codex_protocol::models::MessagePhase; use codex_protocol::models::ResponseItem; +use codex_protocol::protocol::SessionSource; +use tempfile::tempdir; + +const TEST_ID_TOKEN: &str = concat!( + "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.", + "eyJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF91c2VyX2lk", + "IjpudWxsLCJjaGF0Z3B0X2FjY291bnRfaWQiOiJhY2NvdW50X2lkIn19.", + "c2ln", +); struct EnvVarGuard { key: &'static str, @@ -49,6 +69,57 @@ impl Drop for EnvVarGuard { } } +async fn install_cached_agent_task_auth( + session: &mut Session, + turn_context: &mut TurnContext, + chatgpt_base_url: String, +) { + let auth_dir = tempdir().expect("temp auth dir"); + let auth_json = AuthDotJson { + auth_mode: Some(codex_app_server_protocol::AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: IdTokenInfo { + email: None, + chatgpt_plan_type: None, + chatgpt_user_id: None, + chatgpt_account_id: Some("account_id".to_string()), + raw_jwt: TEST_ID_TOKEN.to_string(), + }, + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + agent_identity: None, + }; + save_auth(auth_dir.path(), &auth_json, AuthCredentialsStoreMode::File).expect("save test auth"); + let auth = CodexAuth::from_auth_storage(auth_dir.path(), AuthCredentialsStoreMode::File) + .expect("load test auth") + .expect("test auth"); + let auth_manager = AuthManager::from_auth_for_testing(auth); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + chatgpt_base_url, + SessionSource::Exec, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + session.services.auth_manager = Arc::clone(&auth_manager); + session.services.agent_identity_manager = agent_identity_manager; + turn_context.auth_manager = Some(auth_manager); + session + .cache_agent_task_for_tests(RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id, + task_id: "task-123".to_string(), + registered_at: "2026-04-15T00:00:00Z".to_string(), + }) + .await; +} + #[tokio::test] async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() { let (session, mut turn_context) = make_session_and_context().await; @@ -247,6 +318,80 @@ async fn build_arc_monitor_request_includes_relevant_history_and_null_policies() ); } +#[tokio::test] +#[serial(arc_monitor_env)] +async fn monitor_action_uses_agent_assertion_for_cached_task() { + let server = MockServer::start().await; + let (mut session, mut turn_context) = make_session_and_context().await; + install_cached_agent_task_auth(&mut session, &mut turn_context, server.uri()).await; + + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = server.uri(); + turn_context.config = Arc::new(config); + + session + .record_into_history( + &[ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "please run the tool".to_string(), + }], + end_turn: None, + phase: None, + }], + &turn_context, + ) + .await; + + Mock::given(method("POST")) + .and(path("/codex/safety/arc")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .and(body_json(serde_json::json!({ + "metadata": { + "codex_thread_id": session.conversation_id.to_string(), + "codex_turn_id": turn_context.sub_id.clone(), + "conversation_id": session.conversation_id.to_string(), + "protection_client_callsite": "normal", + }, + "messages": [{ + "role": "user", + "content": [{ + "type": "input_text", + "text": "please run the tool", + }], + }], + "policies": { + "developer": null, + "user": null, + }, + "action": { + "tool": "mcp_tool_call", + }, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "outcome": "ok", + "short_reason": "", + "rationale": "", + "risk_score": 1, + "risk_level": "low", + "evidence": [], + }))) + .expect(1) + .mount(&server) + .await; + + let outcome = monitor_action( + &session, + &turn_context, + serde_json::json!({ "tool": "mcp_tool_call" }), + "normal", + ) + .await; + + assert_eq!(outcome, ArcMonitorOutcome::Ok); +} + #[tokio::test] #[serial(arc_monitor_env)] async fn monitor_action_posts_expected_arc_request() { diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 16f743943a0..1ec7ca83bbd 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -31,6 +31,8 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; +use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::RegisteredAgentTask; use codex_api::ApiError; use codex_api::CompactClient as ApiCompactClient; use codex_api::CompactionInput as ApiCompactionInput; @@ -92,6 +94,7 @@ use tokio::sync::oneshot; use tokio::sync::oneshot::error::TryRecvError; use tokio_tungstenite::tungstenite::Error; use tokio_tungstenite::tungstenite::Message; +use tracing::debug; use tracing::instrument; use tracing::trace; use tracing::warn; @@ -144,6 +147,7 @@ pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration = #[derive(Debug)] struct ModelClientState { auth_manager: Option>, + agent_identity_manager: Option>, conversation_id: ThreadId, window_generation: AtomicU64, installation_id: String, @@ -211,6 +215,8 @@ pub struct ModelClient { pub struct ModelClientSession { client: ModelClient, websocket_session: WebsocketSession, + agent_task: Option, + cache_websocket_session_on_drop: bool, /// Turn state for sticky routing. /// /// This is an `OnceLock` that stores the turn state value received from the server @@ -306,6 +312,33 @@ impl ModelClient { enable_request_compression: bool, include_timing_metrics: bool, beta_features_header: Option, + ) -> Self { + Self::new_with_agent_identity_manager( + auth_manager, + /*agent_identity_manager*/ None, + conversation_id, + installation_id, + provider, + session_source, + model_verbosity, + enable_request_compression, + include_timing_metrics, + beta_features_header, + ) + } + + #[allow(clippy::too_many_arguments)] + pub(crate) fn new_with_agent_identity_manager( + auth_manager: Option>, + agent_identity_manager: Option>, + conversation_id: ThreadId, + installation_id: String, + provider: ModelProviderInfo, + session_source: SessionSource, + model_verbosity: Option, + enable_request_compression: bool, + include_timing_metrics: bool, + beta_features_header: Option, ) -> Self { let auth_manager = auth_manager_for_provider(auth_manager, &provider); let codex_api_key_env_enabled = auth_manager @@ -315,6 +348,7 @@ impl ModelClient { Self { state: Arc::new(ModelClientState { auth_manager, + agent_identity_manager, conversation_id, window_generation: AtomicU64::new(0), installation_id, @@ -336,9 +370,25 @@ impl ModelClient { /// This constructor does not perform network I/O itself; the session opens a websocket lazily /// when the first stream request is issued. pub fn new_session(&self) -> ModelClientSession { + self.new_session_with_agent_task(/*agent_task*/ None) + } + + pub(crate) fn new_session_with_agent_task( + &self, + agent_task: Option, + ) -> ModelClientSession { + let cache_websocket_session_on_drop = agent_task.is_none(); + let websocket_session = if agent_task.is_some() { + drop(self.take_cached_websocket_session()); + WebsocketSession::default() + } else { + self.take_cached_websocket_session() + }; ModelClientSession { client: self.clone(), - websocket_session: self.take_cached_websocket_session(), + websocket_session, + agent_task, + cache_websocket_session_on_drop, turn_state: Arc::new(OnceLock::new()), } } @@ -421,7 +471,7 @@ impl ModelClient { if prompt.input.is_empty() { return Ok(Vec::new()); } - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_telemetry = Self::build_request_telemetry( session_telemetry, @@ -485,7 +535,7 @@ impl ModelClient { ) -> Result { // Create the media call over HTTP first, then retain matching auth so realtime can attach // the server-side control WebSocket to the call id from that HTTP response. - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let mut sideband_headers = extra_headers.clone(); sideband_headers.extend(sideband_websocket_auth_headers(&client_setup.api_auth)); let transport = ReqwestTransport::new(build_reqwest_client()); @@ -518,7 +568,7 @@ impl ModelClient { return Ok(Vec::new()); } - let client_setup = self.current_client_setup().await?; + let client_setup = self.current_client_setup(/*agent_task*/ None).await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_telemetry = Self::build_request_telemetry( session_telemetry, @@ -659,7 +709,10 @@ impl ModelClient { /// /// This centralizes setup used by both prewarm and normal request paths so they stay in /// lockstep when auth/provider resolution changes. - async fn current_client_setup(&self) -> Result { + async fn current_client_setup( + &self, + agent_task: Option<&RegisteredAgentTask>, + ) -> Result { let auth = match self.state.auth_manager.as_ref() { Some(manager) => manager.auth().await, None => None, @@ -668,7 +721,33 @@ impl ModelClient { .state .provider .to_api_provider(auth.as_ref().map(CodexAuth::auth_mode))?; - let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?; + let api_auth = match (agent_task, self.state.agent_identity_manager.as_ref()) { + (Some(agent_task), Some(agent_identity_manager)) => { + if let Some(authorization_header_value) = agent_identity_manager + .authorization_header_for_task(agent_task) + .await + .map_err(|err| { + CodexErr::Stream( + format!("failed to build agent assertion authorization: {err}"), + None, + ) + })? + { + debug!( + agent_runtime_id = %agent_task.agent_runtime_id, + task_id = %agent_task.task_id, + "using agent assertion authorization for downstream request" + ); + CoreAuthProvider::from_authorization_header_value( + Some(authorization_header_value), + /*account_id*/ None, + ) + } else { + auth_provider_from_auth(auth.clone(), &self.state.provider)? + } + } + _ => auth_provider_from_auth(auth.clone(), &self.state.provider)?, + }; Ok(CurrentClientSetup { auth, api_provider, @@ -802,12 +881,18 @@ impl ModelClient { impl Drop for ModelClientSession { fn drop(&mut self) { let websocket_session = std::mem::take(&mut self.websocket_session); - self.client - .store_cached_websocket_session(websocket_session); + if self.cache_websocket_session_on_drop { + self.client + .store_cached_websocket_session(websocket_session); + } } } impl ModelClientSession { + pub(crate) fn disable_cached_websocket_session_on_drop(&mut self) { + self.cache_websocket_session_on_drop = false; + } + pub(crate) fn reset_websocket_session(&mut self) { self.websocket_session.connection = None; self.websocket_session.last_request = None; @@ -1009,11 +1094,15 @@ impl ModelClientSession { return Ok(()); } - let client_setup = self.client.current_client_setup().await.map_err(|err| { - ApiError::Stream(format!( - "failed to build websocket prewarm client setup: {err}" - )) - })?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await + .map_err(|err| { + ApiError::Stream(format!( + "failed to build websocket prewarm client setup: {err}" + )) + })?; let auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), &client_setup.api_auth, @@ -1167,7 +1256,10 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await?; let transport = ReqwestTransport::new(build_reqwest_client()); let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), @@ -1256,7 +1348,10 @@ impl ModelClientSession { .map(AuthManager::unauthorized_recovery); let mut pending_retry = PendingUnauthorizedRetry::default(); loop { - let client_setup = self.client.current_client_setup().await?; + let client_setup = self + .client + .current_client_setup(self.agent_task.as_ref()) + .await?; let request_auth_context = AuthRequestTelemetryContext::new( client_setup.auth.as_ref().map(CodexAuth::auth_mode), &client_setup.api_auth, diff --git a/codex-rs/core/src/client_tests.rs b/codex-rs/core/src/client_tests.rs index 2fd5f04f951..5b6c99cf599 100644 --- a/codex-rs/core/src/client_tests.rs +++ b/codex-rs/core/src/client_tests.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::AuthRequestTelemetryContext; use super::ModelClient; use super::PendingUnauthorizedRetry; @@ -7,17 +9,36 @@ use super::X_CODEX_PARENT_THREAD_ID_HEADER; use super::X_CODEX_TURN_METADATA_HEADER; use super::X_CODEX_WINDOW_ID_HEADER; use super::X_OPENAI_SUBAGENT_HEADER; +use crate::Prompt; +use crate::ResponseEvent; +use crate::agent_identity::AgentAssertionEnvelope; +use crate::agent_identity::AgentIdentityManager; +use crate::agent_identity::RegisteredAgentTask; +use crate::agent_identity::StoredAgentIdentity; +use base64::Engine as _; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; use codex_api::CoreAuthProvider; use codex_app_server_protocol::AuthMode; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_model_provider_info::ModelProviderInfo; use codex_model_provider_info::WireApi; use codex_model_provider_info::create_oss_provider_with_base_url; use codex_otel::SessionTelemetry; use codex_protocol::ThreadId; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; use codex_protocol::openai_models::ModelInfo; use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::SubAgentSource; +use core_test_support::responses; +use ed25519_dalek::Signature; +use ed25519_dalek::Verifier as _; +use futures::StreamExt; use pretty_assertions::assert_eq; use serde_json::json; +use tempfile::TempDir; fn test_model_client(session_source: SessionSource) -> ModelClient { let provider = create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses); @@ -79,6 +100,111 @@ fn test_session_telemetry() -> SessionTelemetry { ) } +fn test_prompt(text: &str) -> Prompt { + Prompt { + input: vec![ResponseItem::Message { + id: None, + role: "user".into(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + }], + ..Prompt::default() + } +} + +async fn drain_stream_to_completion(stream: &mut crate::ResponseStream) -> anyhow::Result<()> { + while let Some(event) = stream.next().await { + if matches!(event?, ResponseEvent::Completed { .. }) { + break; + } + } + Ok(()) +} + +async fn model_client_with_agent_task( + provider: ModelProviderInfo, +) -> ( + TempDir, + ModelClient, + RegisteredAgentTask, + StoredAgentIdentity, +) { + let codex_home = tempfile::tempdir().expect("tempdir"); + let auth_manager = + AuthManager::from_auth_for_testing(CodexAuth::create_dummy_chatgpt_auth_for_testing()); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + "https://chatgpt.com/backend-api/".to_string(), + SessionSource::Cli, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + let agent_task = RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id.clone(), + task_id: "task-123".to_string(), + registered_at: "2026-03-23T12:00:00Z".to_string(), + }; + let client = ModelClient::new_with_agent_identity_manager( + Some(auth_manager), + Some(agent_identity_manager), + ThreadId::new(), + /*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(), + provider, + SessionSource::Cli, + /*model_verbosity*/ None, + /*enable_request_compression*/ false, + /*include_timing_metrics*/ false, + /*beta_features_header*/ None, + ); + (codex_home, client, agent_task, stored_identity) +} + +fn assert_agent_assertion_header( + authorization_header: &str, + stored_identity: &StoredAgentIdentity, + expected_agent_runtime_id: &str, + expected_task_id: &str, +) { + let token = authorization_header + .strip_prefix("AgentAssertion ") + .expect("agent assertion authorization scheme"); + let envelope: AgentAssertionEnvelope = serde_json::from_slice( + &URL_SAFE_NO_PAD + .decode(token) + .expect("base64url-encoded agent assertion"), + ) + .expect("valid agent assertion envelope"); + + assert_eq!(envelope.agent_runtime_id, expected_agent_runtime_id); + assert_eq!(envelope.task_id, expected_task_id); + + let signature = Signature::from_slice( + &base64::engine::general_purpose::STANDARD + .decode(&envelope.signature) + .expect("base64 signature"), + ) + .expect("signature bytes"); + stored_identity + .signing_key() + .expect("signing key") + .verifying_key() + .verify( + format!( + "{}:{}:{}", + envelope.agent_runtime_id, envelope.task_id, envelope.timestamp + ) + .as_bytes(), + &signature, + ) + .expect("signature should verify"); +} + #[test] fn build_subagent_headers_sets_other_subagent_label() { let client = test_model_client(SessionSource::SubAgent(SubAgentSource::Other( @@ -169,3 +295,130 @@ fn auth_request_telemetry_context_tracks_attached_auth_and_retry_phase() { assert_eq!(auth_context.recovery_mode, Some("managed")); assert_eq!(auth_context.recovery_phase, Some("refresh_token")); } + +#[tokio::test] +async fn responses_http_uses_agent_assertion_when_agent_task_is_present() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_mock_server().await; + let request_recorder = responses::mount_sse_once( + &server, + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + let provider = + create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses); + let (_codex_home, client, agent_task, stored_identity) = + model_client_with_agent_task(provider).await; + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let mut client_session = client.new_session_with_agent_task(Some(agent_task.clone())); + + let mut stream = client_session + .stream( + &test_prompt("hello"), + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("stream request should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("stream should complete"); + + let request = request_recorder.single_request(); + let authorization = request + .header("authorization") + .expect("authorization header should be present"); + assert_agent_assertion_header( + &authorization, + &stored_identity, + &agent_task.agent_runtime_id, + &agent_task.task_id, + ); + assert_eq!(request.header("chatgpt-account-id"), None); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn websocket_agent_task_bypasses_cached_bearer_prewarm() { + core_test_support::skip_if_no_network!(); + + let server = responses::start_websocket_server(vec![ + vec![vec![ + responses::ev_response_created("resp-prewarm"), + responses::ev_completed("resp-prewarm"), + ]], + vec![vec![ + responses::ev_response_created("resp-1"), + responses::ev_completed("resp-1"), + ]], + ]) + .await; + let mut provider = + create_oss_provider_with_base_url(&format!("{}/v1", server.uri()), WireApi::Responses); + provider.supports_websockets = true; + provider.websocket_connect_timeout_ms = Some(5_000); + let (_codex_home, client, agent_task, stored_identity) = + model_client_with_agent_task(provider).await; + let model_info = test_model_info(); + let session_telemetry = test_session_telemetry(); + let prompt = test_prompt("hello"); + + let mut prewarm_session = client.new_session(); + prewarm_session + .prewarm_websocket( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("bearer prewarm should succeed"); + drop(prewarm_session); + + let mut agent_task_session = client.new_session_with_agent_task(Some(agent_task.clone())); + let mut stream = agent_task_session + .stream( + &prompt, + &model_info, + &session_telemetry, + /*effort*/ None, + ReasoningSummary::Auto, + /*service_tier*/ None, + /*turn_metadata_header*/ None, + ) + .await + .expect("agent task stream should succeed"); + drain_stream_to_completion(&mut stream) + .await + .expect("agent task websocket stream should complete"); + + let handshakes = server.handshakes(); + assert_eq!(handshakes.len(), 2); + assert_eq!( + handshakes[0].header("authorization"), + Some("Bearer Access Token".to_string()) + ); + let agent_authorization = handshakes[1] + .header("authorization") + .expect("agent handshake should include authorization"); + assert_agent_assertion_header( + &agent_authorization, + &stored_identity, + &agent_task.agent_runtime_id, + &agent_task.task_id, + ); + assert_eq!(handshakes[1].header("chatgpt-account-id"), None); + + server.shutdown().await; +} diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 613f710ccc7..8a23baef208 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -2088,6 +2088,11 @@ impl Session { config.analytics_enabled, ) }); + let agent_identity_manager = Arc::new(AgentIdentityManager::new( + config.as_ref(), + Arc::clone(&auth_manager), + session_configuration.session_source.clone(), + )); let services = SessionServices { // Initialize the MCP connection manager with an uninitialized // instance. It will be replaced with one created via @@ -2110,11 +2115,7 @@ impl Session { hooks, rollout: Mutex::new(rollout_recorder), user_shell: Arc::new(default_shell), - agent_identity_manager: Arc::new(AgentIdentityManager::new( - config.as_ref(), - Arc::clone(&auth_manager), - session_configuration.session_source.clone(), - )), + agent_identity_manager: Arc::clone(&agent_identity_manager), shell_snapshot_tx, show_raw_agent_reasoning: config.show_raw_agent_reasoning, exec_policy, @@ -2131,8 +2132,9 @@ impl Session { network_proxy, network_approval: Arc::clone(&network_approval), state_db: state_db_ctx.clone(), - model_client: ModelClient::new( + model_client: ModelClient::new_with_agent_identity_manager( Some(Arc::clone(&auth_manager)), + Some(agent_identity_manager), conversation_id, installation_id, session_configuration.provider.clone(), @@ -6394,20 +6396,23 @@ pub(crate) async fn run_turn( })) .await; } - if let Err(error) = sess.ensure_agent_task_registered().await { - warn!(error = %error, "agent task registration failed"); - sess.send_event( - turn_context.as_ref(), - EventMsg::Error(ErrorEvent { - message: format!( - "Agent task registration failed. Please try again; Codex will attempt to register the task again on the next turn: {error}" - ), - codex_error_info: Some(CodexErrorInfo::Other), - }), - ) - .await; - return None; - } + let agent_task = match sess.ensure_agent_task_registered().await { + Ok(agent_task) => agent_task, + Err(error) => { + warn!(error = %error, "agent task registration failed"); + sess.send_event( + turn_context.as_ref(), + EventMsg::Error(ErrorEvent { + message: format!( + "Agent task registration failed. Please try again; Codex will attempt to register the task again on the next turn: {error}" + ), + codex_error_info: Some(CodexErrorInfo::Other), + }), + ) + .await; + return None; + } + }; if !skill_items.is_empty() { sess.record_conversation_items(&turn_context, &skill_items) @@ -6432,8 +6437,21 @@ pub(crate) async fn run_turn( // `ModelClientSession` is turn-scoped and caches WebSocket + sticky routing state, so we reuse // one instance across retries within this turn. - let mut client_session = - prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session()); + let mut prewarmed_client_session = prewarmed_client_session; + if agent_task.is_some() + && let Some(prewarmed_client_session) = prewarmed_client_session.as_mut() + { + prewarmed_client_session.disable_cached_websocket_session_on_drop(); + } + let mut client_session = if let Some(agent_task) = agent_task { + sess.services + .model_client + .new_session_with_agent_task(Some(agent_task)) + } else if let Some(prewarmed_client_session) = prewarmed_client_session.take() { + prewarmed_client_session + } else { + sess.services.model_client.new_session() + }; // Pending input is drained into history before building the next model request. // However, we defer that drain until after sampling in two cases: // 1. At the start of a turn, so the fresh user prompt in `input` gets sampled first. diff --git a/codex-rs/core/src/codex/agent_task_lifecycle.rs b/codex-rs/core/src/codex/agent_task_lifecycle.rs index b45c0f5b04a..6dc5c8e04d5 100644 --- a/codex-rs/core/src/codex/agent_task_lifecycle.rs +++ b/codex-rs/core/src/codex/agent_task_lifecycle.rs @@ -76,6 +76,11 @@ impl Session { agent_task } + #[cfg(test)] + pub(crate) async fn cache_agent_task_for_tests(&self, agent_task: RegisteredAgentTask) { + self.cache_agent_task(agent_task).await; + } + pub(super) async fn cached_agent_task_for_current_identity( &self, ) -> Option { @@ -109,6 +114,28 @@ impl Session { None } + pub(crate) async fn authorization_header_for_current_agent_task( + &self, + ) -> anyhow::Result> { + let Some(agent_task) = self.cached_agent_task_for_current_identity().await else { + return Ok(None); + }; + + let authorization_header_value = self + .services + .agent_identity_manager + .authorization_header_for_task(&agent_task) + .await?; + if authorization_header_value.is_some() { + debug!( + agent_runtime_id = %agent_task.agent_runtime_id, + task_id = %agent_task.task_id, + "using agent assertion authorization for current task request" + ); + } + Ok(authorization_header_value) + } + pub(super) async fn ensure_agent_task_registered( &self, ) -> anyhow::Result> { diff --git a/codex-rs/core/src/mcp_openai_file.rs b/codex-rs/core/src/mcp_openai_file.rs index 587dd4b7701..2bb10249d3e 100644 --- a/codex-rs/core/src/mcp_openai_file.rs +++ b/codex-rs/core/src/mcp_openai_file.rs @@ -40,9 +40,14 @@ pub(crate) async fn rewrite_mcp_tool_arguments_for_openai_files( let Some(value) = arguments.get(field_name) else { continue; }; - let Some(uploaded_value) = - rewrite_argument_value_for_openai_files(turn_context, auth.as_ref(), field_name, value) - .await? + let Some(uploaded_value) = rewrite_argument_value_for_openai_files( + sess, + turn_context, + auth.as_ref(), + field_name, + value, + ) + .await? else { continue; }; @@ -57,6 +62,7 @@ pub(crate) async fn rewrite_mcp_tool_arguments_for_openai_files( } async fn rewrite_argument_value_for_openai_files( + sess: &Session, turn_context: &TurnContext, auth: Option<&CodexAuth>, field_name: &str, @@ -65,6 +71,7 @@ async fn rewrite_argument_value_for_openai_files( match value { JsonValue::String(path_or_file_ref) => { let rewritten = build_uploaded_local_argument_value( + sess, turn_context, auth, field_name, @@ -81,6 +88,7 @@ async fn rewrite_argument_value_for_openai_files( return Ok(None); }; let rewritten = build_uploaded_local_argument_value( + sess, turn_context, auth, field_name, @@ -97,6 +105,7 @@ async fn rewrite_argument_value_for_openai_files( } async fn build_uploaded_local_argument_value( + sess: &Session, turn_context: &TurnContext, auth: Option<&CodexAuth>, field_name: &str, @@ -109,12 +118,20 @@ async fn build_uploaded_local_argument_value( "ChatGPT auth is required to upload local files for Codex Apps tools".to_string(), ); }; - let token_data = auth - .get_token_data() - .map_err(|error| format!("failed to read ChatGPT auth for file upload: {error}"))?; - let upload_auth = CoreAuthProvider { - token: Some(token_data.access_token), - account_id: token_data.account_id, + let upload_auth = if let Some(authorization_header_value) = sess + .authorization_header_for_current_agent_task() + .await + .map_err(|error| format!("failed to build agent assertion authorization: {error}"))? + { + CoreAuthProvider::from_authorization_header_value( + Some(authorization_header_value), + /*account_id*/ None, + ) + } else { + let token_data = auth + .get_token_data() + .map_err(|error| format!("failed to read ChatGPT auth for file upload: {error}"))?; + CoreAuthProvider::from_bearer_token(Some(token_data.access_token), token_data.account_id) }; let uploaded = upload_local_file( turn_context.config.chatgpt_base_url.trim_end_matches('/'), @@ -141,12 +158,81 @@ async fn build_uploaded_local_argument_value( #[cfg(test)] mod tests { use super::*; + use crate::agent_identity::AgentIdentityManager; + use crate::agent_identity::RegisteredAgentTask; use crate::codex::make_session_and_context; + use chrono::Utc; + use codex_login::AuthCredentialsStoreMode; + use codex_login::AuthDotJson; + use codex_login::AuthManager; + use codex_login::save_auth; + use codex_login::token_data::IdTokenInfo; + use codex_login::token_data::TokenData; + use codex_protocol::protocol::SessionSource; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; use std::sync::Arc; use tempfile::tempdir; + const TEST_ID_TOKEN: &str = concat!( + "eyJhbGciOiJub25lIiwidHlwIjoiSldUIn0.", + "eyJodHRwczovL2FwaS5vcGVuYWkuY29tL2F1dGgiOnsiY2hhdGdwdF91c2VyX2lk", + "IjpudWxsLCJjaGF0Z3B0X2FjY291bnRfaWQiOiJhY2NvdW50X2lkIn19.", + "c2ln", + ); + + async fn install_cached_agent_task_auth( + session: &mut Session, + turn_context: &mut TurnContext, + chatgpt_base_url: String, + ) { + let auth_dir = tempdir().expect("temp auth dir"); + let auth_json = AuthDotJson { + auth_mode: Some(codex_app_server_protocol::AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: IdTokenInfo { + email: None, + chatgpt_plan_type: None, + chatgpt_user_id: None, + chatgpt_account_id: Some("account_id".to_string()), + raw_jwt: TEST_ID_TOKEN.to_string(), + }, + access_token: "Access Token".to_string(), + refresh_token: "test".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + agent_identity: None, + }; + save_auth(auth_dir.path(), &auth_json, AuthCredentialsStoreMode::File) + .expect("save test auth"); + let auth = CodexAuth::from_auth_storage(auth_dir.path(), AuthCredentialsStoreMode::File) + .expect("load test auth") + .expect("test auth"); + let auth_manager = AuthManager::from_auth_for_testing(auth); + let agent_identity_manager = Arc::new(AgentIdentityManager::new_for_tests( + Arc::clone(&auth_manager), + /*feature_enabled*/ true, + chatgpt_base_url, + SessionSource::Exec, + )); + let stored_identity = agent_identity_manager + .seed_generated_identity_for_tests("agent-123") + .await + .expect("seed test identity"); + session.services.auth_manager = Arc::clone(&auth_manager); + session.services.agent_identity_manager = agent_identity_manager; + turn_context.auth_manager = Some(auth_manager); + session + .cache_agent_task_for_tests(RegisteredAgentTask { + agent_runtime_id: stored_identity.agent_runtime_id, + task_id: "task-123".to_string(), + registered_at: "2026-04-15T00:00:00Z".to_string(), + }) + .await; + } + #[tokio::test] async fn openai_file_argument_rewrite_requires_declared_file_params() { let (session, turn_context) = make_session_and_context().await; @@ -211,7 +297,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); let local_path = dir.path().join("file_report.csv"); @@ -225,6 +311,7 @@ mod tests { turn_context.config = Arc::new(config); let rewritten = build_uploaded_local_argument_value( + &session, &turn_context, Some(&auth), "file", @@ -292,7 +379,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); let local_path = dir.path().join("file_report.csv"); @@ -305,6 +392,7 @@ mod tests { config.chatgpt_base_url = format!("{}/backend-api", server.uri()); turn_context.config = Arc::new(config); let rewritten = rewrite_argument_value_for_openai_files( + &session, &turn_context, Some(&auth), "file", @@ -404,7 +492,7 @@ mod tests { .mount(&server) .await; - let (_, mut turn_context) = make_session_and_context().await; + let (session, mut turn_context) = make_session_and_context().await; let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); let dir = tempdir().expect("temp dir"); tokio::fs::write(dir.path().join("one.csv"), b"one") @@ -419,6 +507,7 @@ mod tests { config.chatgpt_base_url = format!("{}/backend-api", server.uri()); turn_context.config = Arc::new(config); let rewritten = rewrite_argument_value_for_openai_files( + &session, &turn_context, Some(&auth), "files", @@ -470,4 +559,88 @@ mod tests { assert!(error.contains("failed to upload")); assert!(error.contains("file")); } + + #[tokio::test] + async fn build_uploaded_local_argument_value_uses_agent_assertion_for_cached_task() { + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::body_json; + use wiremock::matchers::header_regex; + use wiremock::matchers::method; + use wiremock::matchers::path; + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/backend-api/files")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .and(body_json(serde_json::json!({ + "file_name": "file_report.csv", + "file_size": 5, + "use_case": "codex", + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "file_id": "file_123", + "upload_url": format!("{}/upload/file_123", server.uri()), + }))) + .expect(1) + .mount(&server) + .await; + Mock::given(method("PUT")) + .and(path("/upload/file_123")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/backend-api/files/file_123/uploaded")) + .and(header_regex("authorization", r"^AgentAssertion .+")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "status": "success", + "download_url": format!("{}/download/file_123", server.uri()), + "file_name": "file_report.csv", + "mime_type": "text/csv", + "file_size_bytes": 5, + }))) + .expect(1) + .mount(&server) + .await; + + let (mut session, mut turn_context) = make_session_and_context().await; + let auth = CodexAuth::create_dummy_chatgpt_auth_for_testing(); + let dir = tempdir().expect("temp dir"); + let local_path = dir.path().join("file_report.csv"); + tokio::fs::write(&local_path, b"hello") + .await + .expect("write local file"); + turn_context.cwd = AbsolutePathBuf::try_from(dir.path()).expect("absolute path"); + + let mut config = (*turn_context.config).clone(); + config.chatgpt_base_url = format!("{}/backend-api", server.uri()); + turn_context.config = Arc::new(config); + install_cached_agent_task_auth(&mut session, &mut turn_context, server.uri()).await; + + let rewritten = build_uploaded_local_argument_value( + &session, + &turn_context, + Some(&auth), + "file", + /*index*/ None, + "file_report.csv", + ) + .await + .expect("rewrite should upload the local file"); + + assert_eq!( + rewritten, + serde_json::json!({ + "download_url": format!("{}/download/file_123", server.uri()), + "file_id": "file_123", + "mime_type": "text/csv", + "file_name": "file_report.csv", + "uri": "sediment://file_123", + "file_size_bytes": 5, + }) + ); + } } diff --git a/codex-rs/login/src/api_bridge.rs b/codex-rs/login/src/api_bridge.rs index d8b9dbb77cb..85b02f6f169 100644 --- a/codex-rs/login/src/api_bridge.rs +++ b/codex-rs/login/src/api_bridge.rs @@ -8,29 +8,28 @@ pub fn auth_provider_from_auth( provider: &ModelProviderInfo, ) -> codex_protocol::error::Result { if let Some(api_key) = provider.api_key()? { - return Ok(CoreAuthProvider { - token: Some(api_key), - account_id: None, - }); + return Ok(CoreAuthProvider::from_bearer_token( + Some(api_key), + /*account_id*/ None, + )); } if let Some(token) = provider.experimental_bearer_token.clone() { - return Ok(CoreAuthProvider { - token: Some(token), - account_id: None, - }); + return Ok(CoreAuthProvider::from_bearer_token( + Some(token), + /*account_id*/ None, + )); } if let Some(auth) = auth { let token = auth.get_token()?; - Ok(CoreAuthProvider { - token: Some(token), - account_id: auth.get_account_id(), - }) + Ok(CoreAuthProvider::from_bearer_token( + Some(token), + auth.get_account_id(), + )) } else { - Ok(CoreAuthProvider { - token: None, - account_id: None, - }) + Ok(CoreAuthProvider::from_bearer_token( + /*token*/ None, /*account_id*/ None, + )) } }