diff --git a/codex-rs/codex-mcp/src/mcp_connection_manager.rs b/codex-rs/codex-mcp/src/mcp_connection_manager.rs index 1542c7b30b8..093f8dac9a2 100644 --- a/codex-rs/codex-mcp/src/mcp_connection_manager.rs +++ b/codex-rs/codex-mcp/src/mcp_connection_manager.rs @@ -50,8 +50,10 @@ use codex_protocol::protocol::McpStartupStatus; use codex_protocol::protocol::McpStartupUpdateEvent; use codex_protocol::protocol::SandboxPolicy; use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::LocalStdioServerLauncher; use codex_rmcp_client::RmcpClient; use codex_rmcp_client::SendElicitation; +use codex_rmcp_client::StdioServerLauncher; use futures::future::BoxFuture; use futures::future::FutureExt; use futures::future::Shared; @@ -1499,7 +1501,8 @@ async fn make_rmcp_client( .map(|(key, value)| (key.into(), value.into())) .collect::>() }); - RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd) + let launcher = Arc::new(LocalStdioServerLauncher) as Arc; + RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd, launcher) .await .map_err(|err| StartupOutcomeError::from(anyhow!(err))) } diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index 86460ecc1f8..f02167b6f66 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -5,6 +5,7 @@ mod oauth; mod perform_oauth_login; mod program_resolver; mod rmcp_client; +mod stdio_server_launcher; mod utils; pub use auth_status::StreamableHttpOAuthDiscovery; @@ -29,3 +30,5 @@ pub use rmcp_client::ListToolsWithConnectorIdResult; pub use rmcp_client::RmcpClient; pub use rmcp_client::SendElicitation; pub use rmcp_client::ToolWithConnectorId; +pub use stdio_server_launcher::LocalStdioServerLauncher; +pub use stdio_server_launcher::StdioServerLauncher; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 415354fee4e..270a56ab2e2 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -4,7 +4,6 @@ use std::ffi::OsString; use std::future::Future; use std::io; use std::path::PathBuf; -use std::process::Stdio; use std::sync::Arc; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; @@ -52,7 +51,6 @@ use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; -use rmcp::transport::child_process::TokioChildProcess; use rmcp::transport::streamable_http_client::AuthRequiredError; use rmcp::transport::streamable_http_client::StreamableHttpClient; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; @@ -63,23 +61,22 @@ use serde::Serialize; use serde_json::Value; use sse_stream::Sse; use sse_stream::SseStream; -use tokio::io::AsyncBufReadExt; -use tokio::io::BufReader; -use tokio::process::Command; use tokio::sync::Mutex; use tokio::sync::watch; use tokio::time; -use tracing::info; use tracing::warn; use crate::elicitation_client_service::ElicitationClientService; use crate::load_oauth_tokens; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; -use crate::program_resolver; +use crate::stdio_server_launcher::LaunchedStdioServer; +use crate::stdio_server_launcher::LaunchedStdioServerTransport; +use crate::stdio_server_launcher::ProcessGroupGuard; +use crate::stdio_server_launcher::StdioServerCommand; +use crate::stdio_server_launcher::StdioServerLauncher; use crate::utils::apply_default_headers; use crate::utils::build_default_headers; -use crate::utils::create_env_for_mcp_server; use codex_config::types::OAuthCredentialsStoreMode; const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; @@ -307,9 +304,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient { } enum PendingTransport { - ChildProcess { - transport: TokioChildProcess, - process_group_guard: Option, + Stdio { + server: LaunchedStdioServer, }, StreamableHttp { transport: StreamableHttpClientTransport, @@ -331,73 +327,11 @@ enum ClientState { }, } -#[cfg(unix)] -const PROCESS_GROUP_TERM_GRACE_PERIOD: Duration = Duration::from_secs(2); - -#[cfg(unix)] -struct ProcessGroupGuard { - process_group_id: u32, -} - -#[cfg(not(unix))] -struct ProcessGroupGuard; - -impl ProcessGroupGuard { - fn new(process_group_id: u32) -> Self { - #[cfg(unix)] - { - Self { process_group_id } - } - #[cfg(not(unix))] - { - let _ = process_group_id; - Self - } - } - - #[cfg(unix)] - fn maybe_terminate_process_group(&self) { - let process_group_id = self.process_group_id; - let should_escalate = - match codex_utils_pty::process_group::terminate_process_group(process_group_id) { - Ok(exists) => exists, - Err(error) => { - warn!("Failed to terminate MCP process group {process_group_id}: {error}"); - false - } - }; - if should_escalate { - std::thread::spawn(move || { - std::thread::sleep(PROCESS_GROUP_TERM_GRACE_PERIOD); - if let Err(error) = - codex_utils_pty::process_group::kill_process_group(process_group_id) - { - warn!("Failed to kill MCP process group {process_group_id}: {error}"); - } - }); - } - } - - #[cfg(not(unix))] - fn maybe_terminate_process_group(&self) {} -} - -impl Drop for ProcessGroupGuard { - fn drop(&mut self) { - if cfg!(unix) { - self.maybe_terminate_process_group(); - } - } -} - #[derive(Clone)] enum TransportRecipe { Stdio { - program: OsString, - args: Vec, - env: Option>, - env_vars: Vec, - cwd: Option, + command: StdioServerCommand, + launcher: Arc, }, StreamableHttp { server_name: String, @@ -574,13 +508,11 @@ impl RmcpClient { env: Option>, env_vars: &[String], cwd: Option, + launcher: Arc, ) -> io::Result { let transport_recipe = TransportRecipe::Stdio { - program, - args, - env, - env_vars: env_vars.to_vec(), - cwd, + command: StdioServerCommand::new(program, args, env, env_vars.to_vec(), cwd), + launcher, }; let transport = Self::create_pending_transport(&transport_recipe) .await @@ -954,60 +886,9 @@ impl RmcpClient { transport_recipe: &TransportRecipe, ) -> Result { match transport_recipe { - TransportRecipe::Stdio { - program, - args, - env, - env_vars, - cwd, - } => { - let program_name = program.to_string_lossy().into_owned(); - let envs = create_env_for_mcp_server(env.clone(), env_vars); - let resolved_program = program_resolver::resolve(program.clone(), &envs)?; - - let mut command = Command::new(resolved_program); - command - .kill_on_drop(true) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .env_clear() - .envs(envs) - .args(args); - #[cfg(unix)] - command.process_group(0); - if let Some(cwd) = cwd { - command.current_dir(cwd); - } - - let (transport, stderr) = TokioChildProcess::builder(command) - .stderr(Stdio::piped()) - .spawn()?; - let process_group_guard = transport.id().map(ProcessGroupGuard::new); - - if let Some(stderr) = stderr { - tokio::spawn(async move { - let mut reader = BufReader::new(stderr).lines(); - loop { - match reader.next_line().await { - Ok(Some(line)) => { - info!("MCP server stderr ({program_name}): {line}"); - } - Ok(None) => break, - Err(error) => { - warn!( - "Failed to read MCP server stderr ({program_name}): {error}" - ); - break; - } - } - } - }); - } - - Ok(PendingTransport::ChildProcess { - transport, - process_group_guard, - }) + TransportRecipe::Stdio { command, launcher } => { + let server = launcher.launch(command.clone()).await?; + Ok(PendingTransport::Stdio { server }) } TransportRecipe::StreamableHttp { server_name, @@ -1104,14 +985,16 @@ impl RmcpClient { Option, )> { let (transport, oauth_persistor, process_group_guard) = match pending_transport { - PendingTransport::ChildProcess { - transport, - process_group_guard, - } => ( - service::serve_client(client_service, transport).boxed(), - None, - process_group_guard, - ), + PendingTransport::Stdio { server } => match server.transport { + LaunchedStdioServerTransport::Local { + transport, + process_group_guard, + } => ( + service::serve_client(client_service, transport).boxed(), + None, + process_group_guard, + ), + }, PendingTransport::StreamableHttp { transport } => ( service::serve_client(client_service, transport).boxed(), None, diff --git a/codex-rs/rmcp-client/src/stdio_server_launcher.rs b/codex-rs/rmcp-client/src/stdio_server_launcher.rs new file mode 100644 index 00000000000..7238a2d4171 --- /dev/null +++ b/codex-rs/rmcp-client/src/stdio_server_launcher.rs @@ -0,0 +1,242 @@ +//! Launch MCP stdio servers and return the transport rmcp should use. +//! +//! This module owns the "where does the server process run?" boundary for +//! stdio MCP servers. In this PR there is only the local launcher, which keeps +//! the existing behavior: the orchestrator starts the configured command and +//! rmcp talks to the child process through local stdin/stdout pipes. +//! +//! Later stack entries add an executor-backed launcher without changing +//! `RmcpClient`'s MCP lifecycle code. + +use std::collections::HashMap; +use std::ffi::OsString; +use std::io; +use std::path::PathBuf; +use std::process::Stdio; +#[cfg(unix)] +use std::thread::sleep; +#[cfg(unix)] +use std::thread::spawn; +#[cfg(unix)] +use std::time::Duration; + +#[cfg(unix)] +use codex_utils_pty::process_group::kill_process_group; +#[cfg(unix)] +use codex_utils_pty::process_group::terminate_process_group; +use futures::FutureExt; +use futures::future::BoxFuture; +use rmcp::transport::child_process::TokioChildProcess; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::process::Command; +use tracing::info; +use tracing::warn; + +use crate::program_resolver; +use crate::utils::create_env_for_mcp_server; + +// General purpose public code. + +/// Launches an MCP stdio server and returns the byte transport for rmcp. +/// +/// This trait is the boundary between MCP lifecycle code and process placement. +/// `RmcpClient` owns MCP operations such as `initialize` and `tools/list`; the +/// launcher owns starting the configured command and producing an rmcp transport +/// over the server's stdin/stdout bytes. +pub trait StdioServerLauncher: private::Sealed + Send + Sync { + /// Start the configured stdio server and return its rmcp-facing transport. + fn launch( + &self, + command: StdioServerCommand, + ) -> BoxFuture<'static, io::Result>; +} + +/// Command-line process shape shared by stdio server launchers. +#[derive(Clone)] +pub struct StdioServerCommand { + program: OsString, + args: Vec, + env: Option>, + env_vars: Vec, + cwd: Option, +} + +/// Opaque stdio server handle produced by a [`StdioServerLauncher`]. +/// +/// `RmcpClient` unwraps this only at the final `rmcp::service::serve_client` +/// boundary. Keeping the concrete variants private prevents callers from +/// depending on local-child-process implementation details. +pub struct LaunchedStdioServer { + pub(super) transport: LaunchedStdioServerTransport, +} + +pub(super) enum LaunchedStdioServerTransport { + Local { + transport: TokioChildProcess, + process_group_guard: Option, + }, +} + +impl StdioServerCommand { + /// Build the stdio process parameters before choosing where the process + /// runs. + pub(super) fn new( + program: OsString, + args: Vec, + env: Option>, + env_vars: Vec, + cwd: Option, + ) -> Self { + Self { + program, + args, + env, + env_vars, + cwd, + } + } +} + +// Local public implementation. + +/// Starts MCP stdio servers as local child processes. +/// +/// This is the existing behavior for local MCP servers: the orchestrator +/// process spawns the configured command and rmcp talks to the child's local +/// stdin/stdout pipes directly. +#[derive(Clone)] +pub struct LocalStdioServerLauncher; + +impl StdioServerLauncher for LocalStdioServerLauncher { + fn launch( + &self, + command: StdioServerCommand, + ) -> BoxFuture<'static, io::Result> { + async move { Self::launch_server(command) }.boxed() + } +} + +// Local private implementation. + +#[cfg(unix)] +const PROCESS_GROUP_TERM_GRACE_PERIOD: Duration = Duration::from_secs(2); + +#[cfg(unix)] +pub(super) struct ProcessGroupGuard { + process_group_id: u32, +} + +#[cfg(not(unix))] +pub(super) struct ProcessGroupGuard; + +mod private { + pub trait Sealed {} +} + +impl private::Sealed for LocalStdioServerLauncher {} + +impl LocalStdioServerLauncher { + fn launch_server(command: StdioServerCommand) -> io::Result { + let StdioServerCommand { + program, + args, + env, + env_vars, + cwd, + } = command; + let program_name = program.to_string_lossy().into_owned(); + let envs = create_env_for_mcp_server(env, &env_vars); + let resolved_program = + program_resolver::resolve(program, &envs).map_err(io::Error::other)?; + + let mut command = Command::new(resolved_program); + command + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .env_clear() + .envs(envs) + .args(args); + #[cfg(unix)] + command.process_group(0); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } + + let (transport, stderr) = TokioChildProcess::builder(command) + .stderr(Stdio::piped()) + .spawn()?; + let process_group_guard = transport.id().map(ProcessGroupGuard::new); + + if let Some(stderr) = stderr { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + info!("MCP server stderr ({program_name}): {line}"); + } + Ok(None) => break, + Err(error) => { + warn!("Failed to read MCP server stderr ({program_name}): {error}"); + break; + } + } + } + }); + } + + Ok(LaunchedStdioServer { + transport: LaunchedStdioServerTransport::Local { + transport, + process_group_guard, + }, + }) + } +} + +impl ProcessGroupGuard { + fn new(process_group_id: u32) -> Self { + #[cfg(unix)] + { + Self { process_group_id } + } + #[cfg(not(unix))] + { + let _ = process_group_id; + Self + } + } + + #[cfg(unix)] + fn maybe_terminate_process_group(&self) { + let process_group_id = self.process_group_id; + let should_escalate = match terminate_process_group(process_group_id) { + Ok(exists) => exists, + Err(error) => { + warn!("Failed to terminate MCP process group {process_group_id}: {error}"); + false + } + }; + if should_escalate { + spawn(move || { + sleep(PROCESS_GROUP_TERM_GRACE_PERIOD); + if let Err(error) = kill_process_group(process_group_id) { + warn!("Failed to kill MCP process group {process_group_id}: {error}"); + } + }); + } + } + + #[cfg(not(unix))] + fn maybe_terminate_process_group(&self) {} +} + +impl Drop for ProcessGroupGuard { + fn drop(&mut self) { + if cfg!(unix) { + self.maybe_terminate_process_group(); + } + } +} diff --git a/codex-rs/rmcp-client/tests/process_group_cleanup.rs b/codex-rs/rmcp-client/tests/process_group_cleanup.rs index 5d8a80e1e7e..aad28ac0ac8 100644 --- a/codex-rs/rmcp-client/tests/process_group_cleanup.rs +++ b/codex-rs/rmcp-client/tests/process_group_cleanup.rs @@ -4,10 +4,12 @@ use std::collections::HashMap; use std::ffi::OsString; use std::fs; use std::path::Path; +use std::sync::Arc; use std::time::Duration; use anyhow::Context; use anyhow::Result; +use codex_rmcp_client::LocalStdioServerLauncher; use codex_rmcp_client::RmcpClient; fn process_exists(pid: u32) -> bool { @@ -78,6 +80,7 @@ async fn drop_kills_wrapper_process_group() -> Result<()> { )])), &[], /*cwd*/ None, + Arc::new(LocalStdioServerLauncher), ) .await?; diff --git a/codex-rs/rmcp-client/tests/resources.rs b/codex-rs/rmcp-client/tests/resources.rs index e41b9268a49..b23c34df138 100644 --- a/codex-rs/rmcp-client/tests/resources.rs +++ b/codex-rs/rmcp-client/tests/resources.rs @@ -1,9 +1,11 @@ use std::ffi::OsString; use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; use codex_rmcp_client::ElicitationAction; use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::LocalStdioServerLauncher; use codex_rmcp_client::RmcpClient; use codex_utils_cargo_bin::CargoBinError; use futures::FutureExt as _; @@ -61,6 +63,7 @@ async fn rmcp_client_can_list_and_read_resources() -> anyhow::Result<()> { /*env*/ None, &[], /*cwd*/ None, + Arc::new(LocalStdioServerLauncher), ) .await?;