diff --git a/codex-rs/app-server/tests/suite/v2/thread_unsubscribe.rs b/codex-rs/app-server/tests/suite/v2/thread_unsubscribe.rs index 6aab3d186f9..1e051492cdd 100644 --- a/codex-rs/app-server/tests/suite/v2/thread_unsubscribe.rs +++ b/codex-rs/app-server/tests/suite/v2/thread_unsubscribe.rs @@ -1,14 +1,15 @@ -use anyhow::Context; use anyhow::Result; use app_test_support::McpProcess; -use app_test_support::create_final_assistant_message_sse_response; use app_test_support::create_mock_responses_server_repeating_assistant; -use app_test_support::create_mock_responses_server_sequence_unchecked; -use app_test_support::create_shell_command_sse_response; use app_test_support::to_response; +use codex_app_server_protocol::DynamicToolCallOutputContentItem; +use codex_app_server_protocol::DynamicToolCallParams; +use codex_app_server_protocol::DynamicToolCallResponse; +use codex_app_server_protocol::DynamicToolSpec; use codex_app_server_protocol::ItemStartedNotification; use codex_app_server_protocol::JSONRPCResponse; use codex_app_server_protocol::RequestId; +use codex_app_server_protocol::ServerRequest; use codex_app_server_protocol::ThreadItem; use codex_app_server_protocol::ThreadLoadedListParams; use codex_app_server_protocol::ThreadLoadedListResponse; @@ -26,57 +27,15 @@ use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; use codex_app_server_protocol::UserInput as V2UserInput; use core_test_support::responses; +use core_test_support::streaming_sse::StreamingSseChunk; +use core_test_support::streaming_sse::start_streaming_sse_server; use pretty_assertions::assert_eq; +use serde_json::json; use tempfile::TempDir; use tokio::time::timeout; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); -async fn wait_for_responses_request_count_to_stabilize( - server: &wiremock::MockServer, - expected_count: usize, - settle_duration: std::time::Duration, -) -> Result<()> { - timeout(DEFAULT_READ_TIMEOUT, 
async { - let mut stable_since: Option<tokio::time::Instant> = None; - loop { - let requests = server - .received_requests() - .await - .context("failed to fetch received requests")?; - let responses_request_count = requests - .iter() - .filter(|request| { - request.method == "POST" && request.url.path().ends_with("/responses") - }) - .count(); - - if responses_request_count > expected_count { - anyhow::bail!( - "expected exactly {expected_count} /responses requests, got {responses_request_count}" - ); - } - - if responses_request_count == expected_count { - match stable_since { - Some(stable_since) if stable_since.elapsed() >= settle_duration => { - return Ok::<(), anyhow::Error>(()); - } - None => stable_since = Some(tokio::time::Instant::now()), - Some(_) => {} - } - } else { - stable_since = None; - } - - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - }) - .await??; - - Ok(()) -} - #[tokio::test] async fn thread_unsubscribe_keeps_thread_loaded_until_idle_timeout() -> Result<()> { let server = create_mock_responses_server_repeating_assistant("Done").await; @@ -128,14 +87,10 @@ async fn thread_unsubscribe_keeps_thread_loaded_until_idle_timeout() -> Result<( #[tokio::test] async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> { - #[cfg(target_os = "windows")] - let shell_command = vec![ - "powershell".to_string(), - "-Command".to_string(), - "Start-Sleep -Seconds 1".to_string(), - ]; - #[cfg(not(target_os = "windows"))] - let shell_command = vec!["sleep".to_string(), "1".to_string()]; + let call_id = "deterministic-wait-call"; + let tool_name = "deterministic_wait"; + let tool_args = json!({}); + let tool_call_arguments = serde_json::to_string(&tool_args)?; let tmp = TempDir::new()?; let codex_home = tmp.path().join("codex_home"); @@ -143,28 +98,61 @@ async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> { let working_directory = tmp.path().join("workdir"); std::fs::create_dir(&working_directory)?; - let server =
create_mock_responses_server_sequence_unchecked(vec![ - create_shell_command_sse_response( - shell_command.clone(), - Some(&working_directory), - Some(10_000), - "call_sleep", - )?, - create_final_assistant_message_sse_response("Done")?, + let (server, mut completions) = start_streaming_sse_server(vec![ + vec![StreamingSseChunk { + gate: None, + body: responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, tool_name, &tool_call_arguments), + responses::ev_completed("resp-1"), + ]), + }], + vec![StreamingSseChunk { + gate: None, + body: responses::sse(vec![ + responses::ev_response_created("resp-2"), + responses::ev_assistant_message("msg-1", "Done"), + responses::ev_completed("resp-2"), + ]), + }], ]) .await; - create_config_toml(&codex_home, &server.uri())?; + let first_response_completed = completions.remove(0); + let final_response_completed = completions.remove(0); + create_config_toml(&codex_home, server.uri())?; let mut mcp = McpProcess::new(&codex_home).await?; timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; - let thread_id = start_thread(&mut mcp).await?; + let thread_req = mcp + .send_thread_start_request(ThreadStartParams { + model: Some("mock-model".to_string()), + dynamic_tools: Some(vec![DynamicToolSpec { + name: tool_name.to_string(), + description: "Deterministic wait tool".to_string(), + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false, + }), + defer_loading: false, + }]), + ..Default::default() + }) + .await?; + let thread_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(thread_req)), + ) + .await??; + let ThreadStartResponse { thread, .. 
} = to_response::<ThreadStartResponse>(thread_resp)?; + let thread_id = thread.id; let turn_req = mcp .send_turn_start_request(TurnStartParams { thread_id: thread_id.clone(), input: vec![V2UserInput::Text { - text: "run sleep".to_string(), + text: "run deterministic tool".to_string(), text_elements: Vec::new(), }], cwd: Some(working_directory), @@ -180,9 +168,37 @@ async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> { timeout( DEFAULT_READ_TIMEOUT, - wait_for_command_execution_item_started(&mut mcp), + server.wait_for_request_count(/*count*/ 1), + ) + .await?; + timeout(DEFAULT_READ_TIMEOUT, first_response_completed).await??; + + let started = timeout( + DEFAULT_READ_TIMEOUT, + wait_for_dynamic_tool_started(&mut mcp, call_id), ) .await??; + assert_eq!(started.thread_id, thread_id); + + let request = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_request_message(), + ) + .await??; + let (request_id, params) = match request { + ServerRequest::DynamicToolCall { request_id, params } => (request_id, params), + other => panic!("expected DynamicToolCall request, got {other:?}"), + }; + assert_eq!( + params, + DynamicToolCallParams { + thread_id: thread_id.clone(), + turn_id: started.turn_id, + call_id: call_id.to_string(), + tool: tool_name.to_string(), + arguments: tool_args, + } + ); let unsubscribe_id = mcp .send_thread_unsubscribe_request(ThreadUnsubscribeParams { @@ -197,21 +213,29 @@ async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> { let unsubscribe = to_response::<ThreadUnsubscribeResponse>(unsubscribe_resp)?; assert_eq!(unsubscribe.status, ThreadUnsubscribeStatus::Unsubscribed); - assert!( - timeout( - std::time::Duration::from_millis(250), - mcp.read_stream_until_notification_message("thread/closed"), - ) - .await - .is_err() + let closed_while_tool_call_blocked = timeout( + std::time::Duration::from_millis(250), + mcp.read_stream_until_notification_message("thread/closed"), ); + let closed_while_tool_call_blocked =
closed_while_tool_call_blocked.await; + assert!(closed_while_tool_call_blocked.is_err()); + + let response = DynamicToolCallResponse { + content_items: vec![DynamicToolCallOutputContentItem::InputText { + text: "dynamic-ok".to_string(), + }], + success: true, + }; + mcp.send_response(request_id, serde_json::to_value(response)?) + .await?; - wait_for_responses_request_count_to_stabilize( - &server, - /*expected_count*/ 2, - std::time::Duration::from_millis(200), + timeout( + DEFAULT_READ_TIMEOUT, + server.wait_for_request_count(/*count*/ 2), ) .await?; + timeout(DEFAULT_READ_TIMEOUT, final_response_completed).await??; + server.shutdown().await; Ok(()) } @@ -350,15 +374,20 @@ async fn thread_unsubscribe_reports_not_subscribed_before_idle_unload() -> Resul Ok(()) } -async fn wait_for_command_execution_item_started(mcp: &mut McpProcess) -> Result<()> { +async fn wait_for_dynamic_tool_started( + mcp: &mut McpProcess, + call_id: &str, +) -> Result<ItemStartedNotification> { loop { - let started_notif = mcp + let notification = mcp .read_stream_until_notification_message("item/started") .await?; - let started_params = started_notif.params.context("item/started params")?; - let started: ItemStartedNotification = serde_json::from_value(started_params)?; - if let ThreadItem::CommandExecution { .. } = started.item { - return Ok(()); + let Some(params) = notification.params else { + continue; + }; + let started: ItemStartedNotification = serde_json::from_value(params)?; + if matches!(&started.item, ThreadItem::DynamicToolCall { id, ..
} if id == call_id) { + return Ok(started); } } } diff --git a/codex-rs/core/tests/common/streaming_sse.rs b/codex-rs/core/tests/common/streaming_sse.rs index 82edcd39d77..a86256ed5a9 100644 --- a/codex-rs/core/tests/common/streaming_sse.rs +++ b/codex-rs/core/tests/common/streaming_sse.rs @@ -7,6 +7,7 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; use tokio::sync::Mutex as TokioMutex; +use tokio::sync::Notify; use tokio::sync::oneshot; /// Streaming SSE chunk payload gated by a per-chunk signal. @@ -20,6 +21,7 @@ pub struct StreamingSseChunk { pub struct StreamingSseServer { uri: String, requests: Arc<TokioMutex<Vec<serde_json::Value>>>, + request_notify: Arc<Notify>, shutdown: oneshot::Sender<()>, task: tokio::task::JoinHandle<()>, } @@ -33,6 +35,15 @@ impl StreamingSseServer { self.requests.lock().await.clone() } + pub async fn wait_for_request_count(&self, count: usize) { + loop { + if self.requests.lock().await.len() >= count { + return; + } + self.request_notify.notified().await; + } + } + pub async fn shutdown(self) { let _ = self.shutdown.send(()); let _ = self.task.await; @@ -67,7 +78,9 @@ pub async fn start_streaming_sse_server( completions: VecDeque::from(completion_senders), })); let requests = Arc::new(TokioMutex::new(Vec::new())); + let request_notify = Arc::new(Notify::new()); let requests_for_task = Arc::clone(&requests); + let request_notify_for_task = Arc::clone(&request_notify); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); let task = tokio::spawn(async move { @@ -78,6 +91,7 @@ pub async fn start_streaming_sse_server( let (mut stream, _) = accept_res.expect("accept streaming SSE connection"); let state = Arc::clone(&state); let requests = Arc::clone(&requests_for_task); + let request_notify = Arc::clone(&request_notify_for_task); tokio::spawn(async move { let (request, body_prefix) = read_http_request(&mut stream).await; let Some((method, path)) = parse_request_line(&request) else { @@ -113,6 +127,7 @@ pub async fn
start_streaming_sse_server( } }; requests.lock().await.push(body); + request_notify.notify_one(); let Some((chunks, completion)) = take_next_stream(&state).await else { let _ = write_http_response(&mut stream, /*status*/ 500, "no responses queued", "text/plain").await; return; @@ -149,6 +164,7 @@ pub async fn start_streaming_sse_server( StreamingSseServer { uri, requests, + request_notify, shutdown: shutdown_tx, task, },