Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 115 additions & 86 deletions codex-rs/app-server/tests/suite/v2/thread_unsubscribe.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use anyhow::Context;
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::create_shell_command_sse_response;
use app_test_support::to_response;
use codex_app_server_protocol::DynamicToolCallOutputContentItem;
use codex_app_server_protocol::DynamicToolCallParams;
use codex_app_server_protocol::DynamicToolCallResponse;
use codex_app_server_protocol::DynamicToolSpec;
use codex_app_server_protocol::ItemStartedNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ServerRequest;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadLoadedListResponse;
Expand All @@ -26,57 +27,15 @@ use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::UserInput as V2UserInput;
use core_test_support::responses;
use core_test_support::streaming_sse::StreamingSseChunk;
use core_test_support::streaming_sse::start_streaming_sse_server;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;

/// Upper bound for any single read/await step in these tests; hitting it fails the test.
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Polls the mock server until exactly `expected_count` POST `/responses`
/// requests have been observed and that count has stayed unchanged for
/// `settle_duration`.
///
/// Bails out immediately if the observed count ever exceeds
/// `expected_count`, and fails via `DEFAULT_READ_TIMEOUT` if the count
/// never reaches and holds the expected value.
async fn wait_for_responses_request_count_to_stabilize(
    server: &wiremock::MockServer,
    expected_count: usize,
    settle_duration: std::time::Duration,
) -> Result<()> {
    timeout(DEFAULT_READ_TIMEOUT, async {
        // Instant at which the count first matched `expected_count`; reset
        // to `None` whenever the count drops back below the target.
        let mut settled_at: Option<tokio::time::Instant> = None;
        loop {
            let received = server
                .received_requests()
                .await
                .context("failed to fetch received requests")?;
            let seen = received
                .iter()
                .filter(|request| {
                    request.method == "POST" && request.url.path().ends_with("/responses")
                })
                .count();

            if seen > expected_count {
                anyhow::bail!(
                    "expected exactly {expected_count} /responses requests, got {seen}"
                );
            }

            if seen < expected_count {
                // Not there yet; any in-progress settling window is void.
                settled_at = None;
            } else if let Some(since) = settled_at {
                // Count has matched before — done once it has held long enough.
                if since.elapsed() >= settle_duration {
                    return Ok::<(), anyhow::Error>(());
                }
            } else {
                // First time the count matches; start the settling window.
                settled_at = Some(tokio::time::Instant::now());
            }

            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    })
    .await??;

    Ok(())
}

#[tokio::test]
async fn thread_unsubscribe_keeps_thread_loaded_until_idle_timeout() -> Result<()> {
let server = create_mock_responses_server_repeating_assistant("Done").await;
Expand Down Expand Up @@ -128,43 +87,72 @@ async fn thread_unsubscribe_keeps_thread_loaded_until_idle_timeout() -> Result<(

#[tokio::test]
async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> {
#[cfg(target_os = "windows")]
let shell_command = vec![
"powershell".to_string(),
"-Command".to_string(),
"Start-Sleep -Seconds 1".to_string(),
];
#[cfg(not(target_os = "windows"))]
let shell_command = vec!["sleep".to_string(), "1".to_string()];
let call_id = "deterministic-wait-call";
let tool_name = "deterministic_wait";
let tool_args = json!({});
let tool_call_arguments = serde_json::to_string(&tool_args)?;

let tmp = TempDir::new()?;
let codex_home = tmp.path().join("codex_home");
std::fs::create_dir(&codex_home)?;
let working_directory = tmp.path().join("workdir");
std::fs::create_dir(&working_directory)?;

let server = create_mock_responses_server_sequence_unchecked(vec![
create_shell_command_sse_response(
shell_command.clone(),
Some(&working_directory),
Some(10_000),
"call_sleep",
)?,
create_final_assistant_message_sse_response("Done")?,
let (server, mut completions) = start_streaming_sse_server(vec![
vec![StreamingSseChunk {
gate: None,
body: responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(call_id, tool_name, &tool_call_arguments),
responses::ev_completed("resp-1"),
]),
}],
vec![StreamingSseChunk {
gate: None,
body: responses::sse(vec![
responses::ev_response_created("resp-2"),
responses::ev_assistant_message("msg-1", "Done"),
responses::ev_completed("resp-2"),
]),
}],
])
.await;
create_config_toml(&codex_home, &server.uri())?;
let first_response_completed = completions.remove(0);
let final_response_completed = completions.remove(0);
create_config_toml(&codex_home, server.uri())?;

let mut mcp = McpProcess::new(&codex_home).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;

let thread_id = start_thread(&mut mcp).await?;
let thread_req = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("mock-model".to_string()),
dynamic_tools: Some(vec![DynamicToolSpec {
name: tool_name.to_string(),
description: "Deterministic wait tool".to_string(),
input_schema: json!({
"type": "object",
"properties": {},
"additionalProperties": false,
}),
defer_loading: false,
}]),
..Default::default()
})
.await?;
let thread_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
)
.await??;
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(thread_resp)?;
let thread_id = thread.id;

let turn_req = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread_id.clone(),
input: vec![V2UserInput::Text {
text: "run sleep".to_string(),
text: "run deterministic tool".to_string(),
text_elements: Vec::new(),
}],
cwd: Some(working_directory),
Expand All @@ -180,9 +168,37 @@ async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> {

timeout(
DEFAULT_READ_TIMEOUT,
wait_for_command_execution_item_started(&mut mcp),
server.wait_for_request_count(/*count*/ 1),
)
.await?;
timeout(DEFAULT_READ_TIMEOUT, first_response_completed).await??;

let started = timeout(
DEFAULT_READ_TIMEOUT,
wait_for_dynamic_tool_started(&mut mcp, call_id),
)
.await??;
assert_eq!(started.thread_id, thread_id);

let request = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_request_message(),
)
.await??;
let (request_id, params) = match request {
ServerRequest::DynamicToolCall { request_id, params } => (request_id, params),
other => panic!("expected DynamicToolCall request, got {other:?}"),
};
assert_eq!(
params,
DynamicToolCallParams {
thread_id: thread_id.clone(),
turn_id: started.turn_id,
call_id: call_id.to_string(),
tool: tool_name.to_string(),
arguments: tool_args,
}
);

let unsubscribe_id = mcp
.send_thread_unsubscribe_request(ThreadUnsubscribeParams {
Expand All @@ -197,21 +213,29 @@ async fn thread_unsubscribe_during_turn_keeps_turn_running() -> Result<()> {
let unsubscribe = to_response::<ThreadUnsubscribeResponse>(unsubscribe_resp)?;
assert_eq!(unsubscribe.status, ThreadUnsubscribeStatus::Unsubscribed);

assert!(
timeout(
std::time::Duration::from_millis(250),
mcp.read_stream_until_notification_message("thread/closed"),
)
.await
.is_err()
let closed_while_tool_call_blocked = timeout(
std::time::Duration::from_millis(250),
mcp.read_stream_until_notification_message("thread/closed"),
);
let closed_while_tool_call_blocked = closed_while_tool_call_blocked.await;
assert!(closed_while_tool_call_blocked.is_err());

let response = DynamicToolCallResponse {
content_items: vec![DynamicToolCallOutputContentItem::InputText {
text: "dynamic-ok".to_string(),
}],
success: true,
};
mcp.send_response(request_id, serde_json::to_value(response)?)
.await?;

wait_for_responses_request_count_to_stabilize(
&server,
/*expected_count*/ 2,
std::time::Duration::from_millis(200),
timeout(
DEFAULT_READ_TIMEOUT,
server.wait_for_request_count(/*count*/ 2),
)
.await?;
timeout(DEFAULT_READ_TIMEOUT, final_response_completed).await??;
server.shutdown().await;

Ok(())
}
Expand Down Expand Up @@ -350,15 +374,20 @@ async fn thread_unsubscribe_reports_not_subscribed_before_idle_unload() -> Resul
Ok(())
}

async fn wait_for_command_execution_item_started(mcp: &mut McpProcess) -> Result<()> {
async fn wait_for_dynamic_tool_started(
mcp: &mut McpProcess,
call_id: &str,
) -> Result<ItemStartedNotification> {
loop {
let started_notif = mcp
let notification = mcp
.read_stream_until_notification_message("item/started")
.await?;
let started_params = started_notif.params.context("item/started params")?;
let started: ItemStartedNotification = serde_json::from_value(started_params)?;
if let ThreadItem::CommandExecution { .. } = started.item {
return Ok(());
let Some(params) = notification.params else {
continue;
};
let started: ItemStartedNotification = serde_json::from_value(params)?;
if matches!(&started.item, ThreadItem::DynamicToolCall { id, .. } if id == call_id) {
return Ok(started);
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions codex-rs/core/tests/common/streaming_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::Notify;
use tokio::sync::oneshot;

/// Streaming SSE chunk payload gated by a per-chunk signal.
Expand All @@ -20,6 +21,7 @@ pub struct StreamingSseChunk {
pub struct StreamingSseServer {
uri: String,
requests: Arc<TokioMutex<Vec<Vec<u8>>>>,
request_notify: Arc<Notify>,
shutdown: oneshot::Sender<()>,
task: tokio::task::JoinHandle<()>,
}
Expand All @@ -33,6 +35,15 @@ impl StreamingSseServer {
self.requests.lock().await.clone()
}

pub async fn wait_for_request_count(&self, count: usize) {
loop {
if self.requests.lock().await.len() >= count {
return;
}
self.request_notify.notified().await;
}
}

pub async fn shutdown(self) {
let _ = self.shutdown.send(());
let _ = self.task.await;
Expand Down Expand Up @@ -67,7 +78,9 @@ pub async fn start_streaming_sse_server(
completions: VecDeque::from(completion_senders),
}));
let requests = Arc::new(TokioMutex::new(Vec::new()));
let request_notify = Arc::new(Notify::new());
let requests_for_task = Arc::clone(&requests);
let request_notify_for_task = Arc::clone(&request_notify);
let (shutdown_tx, mut shutdown_rx) = oneshot::channel();

let task = tokio::spawn(async move {
Expand All @@ -78,6 +91,7 @@ pub async fn start_streaming_sse_server(
let (mut stream, _) = accept_res.expect("accept streaming SSE connection");
let state = Arc::clone(&state);
let requests = Arc::clone(&requests_for_task);
let request_notify = Arc::clone(&request_notify_for_task);
tokio::spawn(async move {
let (request, body_prefix) = read_http_request(&mut stream).await;
let Some((method, path)) = parse_request_line(&request) else {
Expand Down Expand Up @@ -113,6 +127,7 @@ pub async fn start_streaming_sse_server(
}
};
requests.lock().await.push(body);
request_notify.notify_one();
let Some((chunks, completion)) = take_next_stream(&state).await else {
let _ = write_http_response(&mut stream, /*status*/ 500, "no responses queued", "text/plain").await;
return;
Expand Down Expand Up @@ -149,6 +164,7 @@ pub async fn start_streaming_sse_server(
StreamingSseServer {
uri,
requests,
request_notify,
shutdown: shutdown_tx,
task,
},
Expand Down
Loading