Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion codex-rs/codex-mcp/src/mcp_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ use codex_protocol::protocol::McpStartupStatus;
use codex_protocol::protocol::McpStartupUpdateEvent;
use codex_protocol::protocol::SandboxPolicy;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::LocalStdioServerLauncher;
use codex_rmcp_client::RmcpClient;
use codex_rmcp_client::SendElicitation;
use codex_rmcp_client::StdioServerLauncher;
use futures::future::BoxFuture;
use futures::future::FutureExt;
use futures::future::Shared;
Expand Down Expand Up @@ -1499,7 +1501,8 @@ async fn make_rmcp_client(
.map(|(key, value)| (key.into(), value.into()))
.collect::<HashMap<_, _>>()
});
RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd)
let launcher = Arc::new(LocalStdioServerLauncher) as Arc<dyn StdioServerLauncher>;
RmcpClient::new_stdio_client(command_os, args_os, env_os, &env_vars, cwd, launcher)
.await
.map_err(|err| StartupOutcomeError::from(anyhow!(err)))
}
Expand Down
3 changes: 3 additions & 0 deletions codex-rs/rmcp-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod oauth;
mod perform_oauth_login;
mod program_resolver;
mod rmcp_client;
mod stdio_server_launcher;
mod utils;

pub use auth_status::StreamableHttpOAuthDiscovery;
Expand All @@ -29,3 +30,5 @@ pub use rmcp_client::ListToolsWithConnectorIdResult;
pub use rmcp_client::RmcpClient;
pub use rmcp_client::SendElicitation;
pub use rmcp_client::ToolWithConnectorId;
pub use stdio_server_launcher::LocalStdioServerLauncher;
pub use stdio_server_launcher::StdioServerLauncher;
167 changes: 25 additions & 142 deletions codex-rs/rmcp-client/src/rmcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::ffi::OsString;
use std::future::Future;
use std::io;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
Expand Down Expand Up @@ -52,7 +51,6 @@ use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::AuthRequiredError;
use rmcp::transport::streamable_http_client::StreamableHttpClient;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
Expand All @@ -63,23 +61,22 @@ use serde::Serialize;
use serde_json::Value;
use sse_stream::Sse;
use sse_stream::SseStream;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio::time;
use tracing::info;
use tracing::warn;

use crate::elicitation_client_service::ElicitationClientService;
use crate::load_oauth_tokens;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::program_resolver;
use crate::stdio_server_launcher::LaunchedStdioServer;
use crate::stdio_server_launcher::LaunchedStdioServerTransport;
use crate::stdio_server_launcher::ProcessGroupGuard;
use crate::stdio_server_launcher::StdioServerCommand;
use crate::stdio_server_launcher::StdioServerLauncher;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use crate::utils::create_env_for_mcp_server;
use codex_config::types::OAuthCredentialsStoreMode;

const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
Expand Down Expand Up @@ -307,9 +304,8 @@ impl StreamableHttpClient for StreamableHttpResponseClient {
}

enum PendingTransport {
ChildProcess {
transport: TokioChildProcess,
process_group_guard: Option<ProcessGroupGuard>,
Stdio {
server: LaunchedStdioServer,
},
StreamableHttp {
transport: StreamableHttpClientTransport<StreamableHttpResponseClient>,
Expand All @@ -331,73 +327,11 @@ enum ClientState {
},
}

#[cfg(unix)]
const PROCESS_GROUP_TERM_GRACE_PERIOD: Duration = Duration::from_secs(2);

#[cfg(unix)]
struct ProcessGroupGuard {
process_group_id: u32,
}

#[cfg(not(unix))]
struct ProcessGroupGuard;

impl ProcessGroupGuard {
fn new(process_group_id: u32) -> Self {
#[cfg(unix)]
{
Self { process_group_id }
}
#[cfg(not(unix))]
{
let _ = process_group_id;
Self
}
}

#[cfg(unix)]
fn maybe_terminate_process_group(&self) {
let process_group_id = self.process_group_id;
let should_escalate =
match codex_utils_pty::process_group::terminate_process_group(process_group_id) {
Ok(exists) => exists,
Err(error) => {
warn!("Failed to terminate MCP process group {process_group_id}: {error}");
false
}
};
if should_escalate {
std::thread::spawn(move || {
std::thread::sleep(PROCESS_GROUP_TERM_GRACE_PERIOD);
if let Err(error) =
codex_utils_pty::process_group::kill_process_group(process_group_id)
{
warn!("Failed to kill MCP process group {process_group_id}: {error}");
}
});
}
}

#[cfg(not(unix))]
fn maybe_terminate_process_group(&self) {}
}

impl Drop for ProcessGroupGuard {
fn drop(&mut self) {
if cfg!(unix) {
self.maybe_terminate_process_group();
}
}
}

#[derive(Clone)]
enum TransportRecipe {
Stdio {
program: OsString,
args: Vec<OsString>,
env: Option<HashMap<OsString, OsString>>,
env_vars: Vec<String>,
cwd: Option<PathBuf>,
command: StdioServerCommand,
launcher: Arc<dyn StdioServerLauncher>,
},
StreamableHttp {
server_name: String,
Expand Down Expand Up @@ -574,13 +508,11 @@ impl RmcpClient {
env: Option<HashMap<OsString, OsString>>,
env_vars: &[String],
cwd: Option<PathBuf>,
launcher: Arc<dyn StdioServerLauncher>,
) -> io::Result<Self> {
let transport_recipe = TransportRecipe::Stdio {
program,
args,
env,
env_vars: env_vars.to_vec(),
cwd,
command: StdioServerCommand::new(program, args, env, env_vars.to_vec(), cwd),
launcher,
};
let transport = Self::create_pending_transport(&transport_recipe)
.await
Expand Down Expand Up @@ -954,60 +886,9 @@ impl RmcpClient {
transport_recipe: &TransportRecipe,
) -> Result<PendingTransport> {
match transport_recipe {
TransportRecipe::Stdio {
program,
args,
env,
env_vars,
cwd,
} => {
let program_name = program.to_string_lossy().into_owned();
let envs = create_env_for_mcp_server(env.clone(), env_vars);
let resolved_program = program_resolver::resolve(program.clone(), &envs)?;

let mut command = Command::new(resolved_program);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.envs(envs)
.args(args);
#[cfg(unix)]
command.process_group(0);
if let Some(cwd) = cwd {
command.current_dir(cwd);
}

let (transport, stderr) = TokioChildProcess::builder(command)
.stderr(Stdio::piped())
.spawn()?;
let process_group_guard = transport.id().map(ProcessGroupGuard::new);

if let Some(stderr) = stderr {
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
loop {
match reader.next_line().await {
Ok(Some(line)) => {
info!("MCP server stderr ({program_name}): {line}");
}
Ok(None) => break,
Err(error) => {
warn!(
"Failed to read MCP server stderr ({program_name}): {error}"
);
break;
}
}
}
});
}

Ok(PendingTransport::ChildProcess {
transport,
process_group_guard,
})
TransportRecipe::Stdio { command, launcher } => {
let server = launcher.launch(command.clone()).await?;
Ok(PendingTransport::Stdio { server })
}
TransportRecipe::StreamableHttp {
server_name,
Expand Down Expand Up @@ -1104,14 +985,16 @@ impl RmcpClient {
Option<ProcessGroupGuard>,
)> {
let (transport, oauth_persistor, process_group_guard) = match pending_transport {
PendingTransport::ChildProcess {
transport,
process_group_guard,
} => (
service::serve_client(client_service, transport).boxed(),
None,
process_group_guard,
),
PendingTransport::Stdio { server } => match server.transport {
LaunchedStdioServerTransport::Local {
transport,
process_group_guard,
} => (
service::serve_client(client_service, transport).boxed(),
None,
process_group_guard,
),
},
PendingTransport::StreamableHttp { transport } => (
service::serve_client(client_service, transport).boxed(),
None,
Expand Down
Loading
Loading