diff --git a/src/devtools.rs b/src/devtools.rs index af7bb0e..3e9a3c3 100644 --- a/src/devtools.rs +++ b/src/devtools.rs @@ -5,6 +5,44 @@ use tokio::process::{Child, Command}; use tokio::sync::Mutex; use crate::browser::DetectedBrowser; +use crate::state::{DirectToolsConfig, ExternalMcpServer}; + +pub const BROWSER_MCP_SERVER_NAME: &str = "browser"; + +pub fn chrome_devtools_mcp_args(selected_browser: Option<&DetectedBrowser>) -> Vec { + let mut args = vec!["-y".to_string(), "chrome-devtools-mcp@latest".to_string()]; + if let Some(browser) = selected_browser { + if browser.remote_debug_active { + if let Some(target) = browser.remote_debug_target.as_deref() { + if target == "pipe" { + args.push("--executablePath".to_string()); + args.push(browser.path.clone()); + } else { + args.push("--browserUrl".to_string()); + args.push(format!("http://{target}")); + } + } else { + args.push("--executablePath".to_string()); + args.push(browser.path.clone()); + } + } else { + args.push("--executablePath".to_string()); + args.push(browser.path.clone()); + } + } + args +} + +pub fn chrome_devtools_mcp_server(selected_browser: Option<&DetectedBrowser>) -> ExternalMcpServer { + ExternalMcpServer { + command: Some("npx".to_string()), + args: chrome_devtools_mcp_args(selected_browser), + lifecycle: "eager".to_string(), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + unprefixed_tools: true, + ..ExternalMcpServer::default() + } +} /// A running chrome-devtools-mcp child process with stdin/stdout JSON-RPC bridge. pub struct DevtoolsBridge { @@ -16,27 +54,12 @@ pub struct DevtoolsBridge { impl DevtoolsBridge { /// Spawn `npx chrome-devtools-mcp@latest` and set up stdio bridge. + #[allow(dead_code)] pub async fn start( selected_browser: Option<&DetectedBrowser>, ) -> Result>, String> { let mut command = Command::new("npx"); - command.args(["-y", "chrome-devtools-mcp@latest"]); - - if let Some(browser) = selected_browser { - if browser.remote_debug_active { - if let Some(target) = browser.remote_debug_target.as_deref() { - if target == "pipe" { - command.args(["--executablePath", &browser.path]); - } else { - command.args(["--browserUrl", &format!("http://{target}")]); - } - } else { - command.args(["--executablePath", &browser.path]); - } - } else { - command.args(["--executablePath", &browser.path]); - } - } + command.args(chrome_devtools_mcp_args(selected_browser)); let mut child = command .stdin(std::process::Stdio::piped()) @@ -152,3 +175,56 @@ impl DevtoolsBridge { let _ = self.child.kill().await; } } + +#[cfg(test)] +mod tests { + use super::*; + + fn browser_with_target(target: Option<&str>) -> DetectedBrowser { + DetectedBrowser { + name: "Chromium".to_string(), + binary: "chromium".to_string(), + path: "/usr/bin/chromium".to_string(), + remote_debugging: true, + remote_debug_hint: "--remote-debugging-port=".to_string(), + mcp_supported: true, + support_note: "Chromium (supported)".to_string(), + remote_debug_active: target.is_some(), + remote_debug_target: target.map(str::to_string), + remote_debug_pid: Some(42), + } + } + + #[test] + fn chrome_devtools_mcp_server_uses_browser_url_for_remote_debug_target() { + let browser = browser_with_target(Some("127.0.0.1:9222")); + let server = chrome_devtools_mcp_server(Some(&browser)); + assert_eq!(server.command.as_deref(), Some("npx")); + assert_eq!(server.lifecycle, "eager"); + assert!(server.unprefixed_tools); + assert_eq!( + server.args, + vec![ + "-y".to_string(), + "chrome-devtools-mcp@latest".to_string(), + "--browserUrl".to_string(), + "http://127.0.0.1:9222".to_string(), + ] + ); + } + + #[test] + fn chrome_devtools_mcp_server_uses_executable_path_without_remote_target() { + let browser = browser_with_target(None); + let server = chrome_devtools_mcp_server(Some(&browser)); + assert_eq!( + server.args, + vec![ + "-y".to_string(), + "chrome-devtools-mcp@latest".to_string(), + "--executablePath".to_string(), + "/usr/bin/chromium".to_string(), + ] + ); + } +} diff --git a/src/external_mcp.rs b/src/external_mcp.rs new file mode 100644 index 0000000..32dfd9a --- /dev/null +++ b/src/external_mcp.rs @@ -0,0 +1,3039 @@ +//! Downstream MCP gateway support for CatDesk. +//! +//! CatDesk remains the ChatGPT-facing MCP server while this module acts as an +//! MCP client for configured downstream stdio servers. The public surface keeps +//! the ChatGPT tool list compact by exposing one proxy tool named `mcp` by +//! default, while TOML opt-in direct tools can expose selected downstream tools +//! as top-level CatDesk tools. + +use crate::state::{DirectToolsConfig, ExternalMcpConfig, ExternalMcpServer}; +use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue}; +use serde_json::{Value, json}; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::process::{Child, ChildStdin, Command}; +use tokio::sync::{Mutex, oneshot}; + +pub const EXTERNAL_MCP_TOOL_NAME: &str = "mcp"; +const PROTOCOL_VERSION: &str = "2025-03-26"; +#[cfg(not(test))] +const REQUEST_TIMEOUT: Duration = Duration::from_secs(120); +#[cfg(test)] +const TEST_REQUEST_TIMEOUT: Duration = Duration::from_secs(2); +const MAX_HTTP_RESPONSE_BYTES: usize = 10 * 1024 * 1024; + +/// Result payload returned by the CatDesk `mcp` proxy tool. +pub struct ExternalMcpProxyOutput { + pub text: String, + pub structured: Value, +} + +/// Cached metadata for one downstream MCP tool. +#[derive(Clone, Debug)] +pub struct ExternalToolMeta { + pub server_name: String, + pub original_name: String, + pub exposed_name: String, + pub title: Option, + pub description: String, + pub input_schema: Value, + pub annotations: Value, +} + +/// Successful downstream tool invocation. +#[derive(Debug)] +pub struct ExternalMcpCallResult { + pub server_name: String, + pub original_name: String, + pub exposed_name: String, + pub result: Value, +} + +/// Successful downstream resource read. +#[derive(Debug)] +pub struct ExternalMcpReadResourceResult { + pub server_name: String, + pub uri: String, + pub result: Value, +} + +#[derive(Debug, Default)] +struct MetadataRefreshReport { + failures: Vec, +} + +impl MetadataRefreshReport { + fn has_failures(&self) -> bool { + !self.failures.is_empty() + } + + fn failure_summary(&self) -> String { + self.failures.join("; ") + } +} + +struct ExternalMcpConnection { + client: ExternalMcpClient, + last_used: Instant, + in_flight: usize, +} + +fn normalized_lifecycle(server: &ExternalMcpServer) -> String { + let lifecycle = server.lifecycle.trim().to_ascii_lowercase(); + match lifecycle.as_str() { + "eager" => "eager".to_string(), + "keep-alive" | "keep_alive" => "keep-alive".to_string(), + _ => "lazy".to_string(), + } +} + +fn server_is_keep_alive(server: &ExternalMcpServer) -> bool { + normalized_lifecycle(server) == "keep-alive" +} + +/// Owns configured downstream MCP servers, live connections, and cached tool metadata. +pub struct ExternalMcpManager { + config: ExternalMcpConfig, + workspace_root: PathBuf, + connections: HashMap, + tool_metadata: HashMap>, +} + +impl Default for ExternalMcpManager { + fn default() -> Self { + Self::new(ExternalMcpConfig::default()) + } +} + +impl ExternalMcpManager { + pub fn new(config: ExternalMcpConfig) -> Self { + Self::with_workspace( + config, + std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + ) + } + + pub fn with_workspace(config: ExternalMcpConfig, workspace_root: PathBuf) -> Self { + Self { + config, + workspace_root, + connections: HashMap::new(), + tool_metadata: HashMap::new(), + } + } + + pub fn from_workspace_and_app_config( + workspace_root: &str, + app_config: ExternalMcpConfig, + ) -> Self { + Self::with_workspace(app_config, PathBuf::from(workspace_root)) + } + + pub fn configured_server_names(&self) -> Vec { + sorted_keys(&self.config.mcp_servers) + } + + pub fn eager_server_names(&self) -> Vec { + let mut names = self + .config + .mcp_servers + .iter() + .filter_map(|(name, server)| { + let lifecycle = server.lifecycle.trim().to_ascii_lowercase(); + matches!(lifecycle.as_str(), "eager" | "keep-alive" | "keep_alive") + .then_some(name.clone()) + }) + .collect::>(); + names.sort(); + names + } + + pub fn status_payload(&mut self) -> Value { + let _ = self.reap_idle_connections(); + let mut servers = Vec::new(); + let mut connected_count = 0u64; + let mut tool_count = 0u64; + for name in self.configured_server_names() { + let connected = self.connections.contains_key(&name); + if connected { + connected_count = connected_count.saturating_add(1); + } + let tools = self.tool_metadata.get(&name).cloned().unwrap_or_default(); + tool_count = tool_count.saturating_add(tools.len() as u64); + let server = self.config.mcp_servers.get(&name); + let lifecycle = server + .map(normalized_lifecycle) + .unwrap_or_else(|| "lazy".to_string()); + let keep_alive = server.is_some_and(server_is_keep_alive); + let transport = server.map(server_transport_name).unwrap_or("stdio"); + let headers = server + .map(|server| redacted_headers_payload(&server.headers)) + .unwrap_or_default(); + servers.push(json!({ + "name": name, + "transport": transport, + "headers": headers, + "lifecycle": lifecycle, + "keepAlive": keep_alive, + "connected": connected, + "toolCount": tools.len(), + "directToolsEnabled": self.server_has_any_direct_tools(&name), + "tools": tools.into_iter().map(|meta| tool_meta_payload(&meta)).collect::>(), + })); + } + json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "status", + "serverCount": self.config.mcp_servers.len(), + "connectedCount": connected_count, + "toolCount": tool_count, + "idleTimeoutMinutes": self.config.settings.idle_timeout, + "message": if self.config.mcp_servers.is_empty() { + "No downstream MCP servers configured. Add [mcp.mcpServers.] entries to ~/.catdesk/config.toml." + } else { + "External MCP gateway ready" + }, + "servers": servers, + }) + } + + pub fn tui_status_snapshot( + &mut self, + failed_server_count: usize, + browser_gateway_enabled: bool, + ) -> crate::state::ExternalMcpTuiStatus { + let _ = self.reap_idle_connections(); + let tool_count = self.tool_metadata.values().map(Vec::len).sum::(); + crate::state::ExternalMcpTuiStatus { + configured_server_count: self.config.mcp_servers.len(), + connected_server_count: self.connections.len(), + failed_server_count, + tool_count, + browser_gateway_enabled, + } + } + + pub async fn proxy(&mut self, arguments: &Value) -> Result { + self.reap_idle_connections()?; + match ProxyAction::from_arguments(arguments)? { + ProxyAction::Status => Ok(ExternalMcpProxyOutput { + text: "MCP gateway status".to_string(), + structured: self.status_payload(), + }), + ProxyAction::Call { tool, args, server } => { + let call = self.call_tool(&tool, args, server.as_deref()).await?; + let structured = json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "call", + "server": call.server_name, + "tool": call.exposed_name, + "downstreamTool": call.original_name, + "downstreamToolCallCount": 1, + "result": call.result, + }); + Ok(ExternalMcpProxyOutput { + text: format!( + "called {}:{} via MCP gateway", + call.server_name, call.original_name + ), + structured, + }) + } + ProxyAction::Connect { server } => { + let structured = self.connect(&server).await?; + Ok(ExternalMcpProxyOutput { + text: format!("connected downstream MCP server: {server}"), + structured, + }) + } + ProxyAction::Disconnect { server } => { + let structured = self.disconnect(&server).await?; + Ok(ExternalMcpProxyOutput { + text: format!("disconnected downstream MCP server: {server}"), + structured, + }) + } + ProxyAction::Describe { query, server } => { + let refresh = self.refresh_metadata_for_lookup(server.as_deref()).await?; + let matches = self.describe_tools(&query, server.as_deref())?; + let mut text = if matches.is_empty() { + format!("no downstream MCP tool matched: {query}") + } else { + format!("described {} downstream MCP tool(s)", matches.len()) + }; + let partial = refresh.has_failures(); + let refresh_failures = refresh.failures; + if partial { + text = format!( + "{text}; partial metadata refresh with {} failure(s)", + refresh_failures.len() + ); + } + Ok(ExternalMcpProxyOutput { + text, + structured: json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "describe", + "query": query, + "server": server, + "partial": partial, + "refreshFailures": refresh_failures, + "matches": matches, + }), + }) + } + ProxyAction::Search { query, server } => { + let refresh = self.refresh_metadata_for_lookup(server.as_deref()).await?; + let matches = self.search_tools(&query, server.as_deref()); + let partial = refresh.has_failures(); + let refresh_failures = refresh.failures; + let mut text = format!("found {} downstream MCP tool(s)", matches.len()); + if partial { + text = format!( + "{text}; partial metadata refresh with {} failure(s)", + refresh_failures.len() + ); + } + Ok(ExternalMcpProxyOutput { + text, + structured: json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "search", + "query": query, + "server": server, + "partial": partial, + "refreshFailures": refresh_failures, + "matches": matches, + }), + }) + } + ProxyAction::Server { server } => { + let structured = self.connect(&server).await?; + Ok(ExternalMcpProxyOutput { + text: format!("listed downstream MCP server: {server}"), + structured, + }) + } + ProxyAction::ReadResource { uri, server } => { + let result = self.read_resource(&uri, server.as_deref()).await?; + Ok(ExternalMcpProxyOutput { + text: format!( + "read downstream MCP resource {} from {}", + result.uri, result.server_name + ), + structured: json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "readResource", + "server": result.server_name, + "uri": result.uri, + "result": result.result, + }), + }) + } + ProxyAction::ListResources { server } => { + let resources = self.list_resources(server.as_deref()).await?; + Ok(ExternalMcpProxyOutput { + text: format!("listed {} downstream MCP resource(s)", resources.len()), + structured: json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "resources", + "server": server, + "resources": resources, + }), + }) + } + } + } + + pub async fn connect(&mut self, server_name: &str) -> Result { + if self.connections.contains_key(server_name) { + return Ok(self.server_payload(server_name)); + } + let server = self + .config + .mcp_servers + .get(server_name) + .cloned() + .ok_or_else(|| format!("unknown downstream MCP server: {server_name}"))?; + let mut client = if let Some(url) = server + .url + .as_deref() + .filter(|value| !value.trim().is_empty()) + { + ExternalMcpClient::start_http(url, &server.headers)? + } else { + let command = server + .command + .as_deref() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| { + format!( + "downstream MCP server `{server_name}` requires stdio command or HTTP url configuration" + ) + })?; + let cwd = resolve_server_cwd(&self.workspace_root, server.cwd.as_deref()); + ExternalMcpClient::start_stdio(command, &server.args, cwd.as_deref(), &server.env) + .await? + }; + client.initialize().await?; + let raw_tools = client.list_tools().await?; + let metadata = raw_tools + .into_iter() + .filter_map(|tool| tool_meta_from_json(server_name, &server, tool)) + .collect::>(); + self.tool_metadata.insert(server_name.to_string(), metadata); + self.connections.insert( + server_name.to_string(), + ExternalMcpConnection { + client, + last_used: Instant::now(), + in_flight: 0, + }, + ); + Ok(self.server_payload(server_name)) + } + + pub async fn disconnect(&mut self, server_name: &str) -> Result { + if !self.config.mcp_servers.contains_key(server_name) { + return Err(format!("unknown downstream MCP server: {server_name}")); + } + let disconnected = if let Some(mut connection) = self.connections.remove(server_name) { + connection.client.stop().await; + true + } else { + false + }; + Ok(json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "disconnect", + "server": server_name, + "disconnected": disconnected, + "connected": self.connections.contains_key(server_name), + })) + } + + pub async fn shutdown_all(&mut self) -> Value { + let names = self.connections.keys().cloned().collect::>(); + let mut disconnected = Vec::new(); + for name in names { + if let Some(mut connection) = self.connections.remove(&name) { + connection.client.stop().await; + disconnected.push(name); + } + } + disconnected.sort(); + json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "shutdown", + "disconnected": disconnected, + "connectedCount": self.connections.len(), + }) + } + + pub fn reap_idle_connections(&mut self) -> Result, String> { + let timeout_minutes = self.config.settings.idle_timeout; + if timeout_minutes == 0 { + return Ok(Vec::new()); + } + let timeout = Duration::from_secs(timeout_minutes.saturating_mul(60)); + let now = Instant::now(); + let mut reaped = Vec::new(); + for (name, connection) in &self.connections { + let Some(server) = self.config.mcp_servers.get(name) else { + continue; + }; + if server_is_keep_alive(server) || connection.in_flight > 0 { + continue; + } + if now.duration_since(connection.last_used) >= timeout { + reaped.push(name.clone()); + } + } + for name in &reaped { + self.connections.remove(name); + } + reaped.sort(); + Ok(reaped) + } + + #[cfg(test)] + pub(crate) fn mark_connection_idle_for_test(&mut self, server_name: &str, idle_for: Duration) { + if let Some(connection) = self.connections.get_mut(server_name) { + connection.last_used = Instant::now() - idle_for; + } + } + + #[cfg(test)] + pub fn connected_server_count(&self) -> usize { + self.connections.len() + } + + pub async fn call_tool( + &mut self, + tool_name: &str, + arguments: Value, + server_hint: Option<&str>, + ) -> Result { + let refresh = self.refresh_metadata_for_lookup(server_hint).await?; + let meta = match self.resolve_tool(tool_name, server_hint) { + Ok(meta) => meta, + Err(error) if server_hint.is_none() && refresh.has_failures() => { + return Err(format!( + "{error}; metadata refresh failures: {}", + refresh.failure_summary() + )); + } + Err(error) => return Err(error), + }; + if !self.connections.contains_key(&meta.server_name) { + self.connect(&meta.server_name).await?; + } + let connection = self + .connections + .get_mut(&meta.server_name) + .ok_or_else(|| format!("downstream MCP server disappeared: {}", meta.server_name))?; + connection.in_flight = connection.in_flight.saturating_add(1); + let result = connection + .client + .call_tool(&meta.original_name, arguments) + .await; + connection.in_flight = connection.in_flight.saturating_sub(1); + connection.last_used = Instant::now(); + Ok(ExternalMcpCallResult { + server_name: meta.server_name, + original_name: meta.original_name, + exposed_name: meta.exposed_name, + result: result?, + }) + } + + pub async fn call_direct_tool( + &mut self, + tool_name: &str, + arguments: Value, + read_only: bool, + ) -> Result, String> { + self.refresh_metadata_for_direct_tools().await?; + let direct_names = self + .direct_tool_metadata_for_mode(read_only) + .into_iter() + .map(|meta| meta.exposed_name) + .collect::>(); + if !direct_names.contains(tool_name) { + return Ok(None); + } + self.call_tool(tool_name, arguments, None).await.map(Some) + } + + pub async fn direct_tool_descriptors(&mut self, read_only: bool) -> Result, String> { + self.refresh_metadata_for_direct_tools().await?; + let mut descriptors = self + .direct_tool_metadata_for_mode(read_only) + .into_iter() + .map(|meta| { + json!({ + "name": meta.exposed_name, + "title": format!("MCP: {}", meta.original_name), + "description": direct_tool_description(&meta), + "inputSchema": meta.input_schema, + "annotations": direct_tool_annotations(&meta) + }) + }) + .collect::>(); + descriptors.sort_by(|left, right| { + let left = left.get("name").and_then(Value::as_str).unwrap_or_default(); + let right = right + .get("name") + .and_then(Value::as_str) + .unwrap_or_default(); + left.cmp(right) + }); + Ok(descriptors) + } + + pub async fn list_resources( + &mut self, + server_hint: Option<&str>, + ) -> Result, String> { + let server_names = self.resolve_server_names_for_resource_action(server_hint)?; + let mut resources = Vec::new(); + for server_name in server_names { + self.connect(&server_name).await?; + let connection = self + .connections + .get_mut(&server_name) + .ok_or_else(|| format!("downstream MCP server disappeared: {server_name}"))?; + let listed = connection.client.list_resources().await?; + connection.last_used = Instant::now(); + resources.extend( + listed + .into_iter() + .map(|resource| resource_payload(&server_name, resource)), + ); + } + resources.sort_by(|left, right| { + let left_server = left + .get("server") + .and_then(Value::as_str) + .unwrap_or_default(); + let right_server = right + .get("server") + .and_then(Value::as_str) + .unwrap_or_default(); + let left_uri = left.get("uri").and_then(Value::as_str).unwrap_or_default(); + let right_uri = right.get("uri").and_then(Value::as_str).unwrap_or_default(); + (left_server, left_uri).cmp(&(right_server, right_uri)) + }); + Ok(resources) + } + + pub async fn read_resource( + &mut self, + uri: &str, + server_hint: Option<&str>, + ) -> Result { + if uri.trim().is_empty() { + return Err("resource must be a non-empty string".to_string()); + } + let server_names = self.resolve_server_names_for_resource_action(server_hint)?; + let mut failures = Vec::new(); + for server_name in server_names { + self.connect(&server_name).await?; + let connection = self + .connections + .get_mut(&server_name) + .ok_or_else(|| format!("downstream MCP server disappeared: {server_name}"))?; + match connection.client.read_resource(uri).await { + Ok(result) => { + connection.last_used = Instant::now(); + return Ok(ExternalMcpReadResourceResult { + server_name, + uri: uri.to_string(), + result, + }); + } + Err(error) => failures.push(format!("{server_name}: {error}")), + } + } + Err(format!( + "failed to read downstream MCP resource `{uri}`: {}", + failures.join("; ") + )) + } + + async fn refresh_metadata_for_lookup( + &mut self, + server_hint: Option<&str>, + ) -> Result { + if let Some(server_name) = server_hint { + self.connect(server_name).await?; + return Ok(MetadataRefreshReport::default()); + } + let names = self.configured_server_names(); + let mut failures = Vec::new(); + for name in names { + if self.tool_metadata.contains_key(&name) { + continue; + } + if let Err(error) = self.connect(&name).await { + failures.push(format!("{name}: {error}")); + } + } + Ok(MetadataRefreshReport { failures }) + } + + async fn refresh_metadata_for_direct_tools(&mut self) -> Result<(), String> { + let names = self + .configured_server_names() + .into_iter() + .filter(|name| self.server_has_any_direct_tools(name)) + .collect::>(); + let mut failures = Vec::new(); + for name in names { + if self.tool_metadata.contains_key(&name) { + continue; + } + if let Err(error) = self.connect(&name).await { + failures.push(format!("{name}: {error}")); + } + } + if failures.is_empty() { + Ok(()) + } else { + Err(format!( + "failed to refresh direct downstream MCP tools: {}", + failures.join("; ") + )) + } + } + + fn resolve_tool( + &self, + tool_name: &str, + server_hint: Option<&str>, + ) -> Result { + let mut matches = self + .all_tool_metadata(server_hint) + .into_iter() + .filter(|meta| meta.exposed_name == tool_name || meta.original_name == tool_name) + .collect::>(); + if matches.len() == 1 { + return Ok(matches.remove(0)); + } + if matches.is_empty() { + return Err(format!("unknown downstream MCP tool: {tool_name}")); + } + let labels = matches + .iter() + .map(|meta| format!("{}:{}", meta.server_name, meta.original_name)) + .collect::>() + .join(", "); + Err(format!( + "ambiguous downstream MCP tool `{tool_name}`; pass server. Matches: {labels}" + )) + } + + fn search_tools(&self, query: &str, server_hint: Option<&str>) -> Vec { + let query = query.trim().to_ascii_lowercase(); + self.all_tool_metadata(server_hint) + .into_iter() + .filter(|meta| { + query.is_empty() + || meta.exposed_name.to_ascii_lowercase().contains(&query) + || meta.original_name.to_ascii_lowercase().contains(&query) + || meta.description.to_ascii_lowercase().contains(&query) + || meta.server_name.to_ascii_lowercase().contains(&query) + }) + .map(|meta| tool_meta_payload(&meta)) + .collect() + } + + fn describe_tools(&self, query: &str, server_hint: Option<&str>) -> Result, String> { + if query.trim().is_empty() { + return Err("describe must be a non-empty string".to_string()); + } + let exact = self + .all_tool_metadata(server_hint) + .into_iter() + .filter(|meta| meta.exposed_name == query || meta.original_name == query) + .collect::>(); + let matches = if exact.is_empty() { + self.search_tools(query, server_hint) + } else { + exact + .into_iter() + .map(|meta| tool_meta_payload(&meta)) + .collect() + }; + Ok(matches) + } + + fn all_tool_metadata(&self, server_hint: Option<&str>) -> Vec { + let mut metas = Vec::new(); + for (server_name, tools) in &self.tool_metadata { + if server_hint.is_some_and(|hint| hint != server_name) { + continue; + } + metas.extend(tools.iter().cloned()); + } + metas.sort_by(|a, b| a.exposed_name.cmp(&b.exposed_name)); + metas + } + + fn direct_tool_metadata(&self) -> Vec { + let mut metas = Vec::new(); + for (server_name, tools) in &self.tool_metadata { + let Some(server) = self.config.mcp_servers.get(server_name) else { + continue; + }; + metas.extend( + tools + .iter() + .filter(|meta| self.tool_is_direct_for_server(server_name, server, meta)) + .cloned(), + ); + } + metas.sort_by(|left, right| left.exposed_name.cmp(&right.exposed_name)); + metas + } + + fn direct_tool_metadata_for_mode(&self, read_only: bool) -> Vec { + self.direct_tool_metadata() + .into_iter() + .filter(|meta| !read_only || tool_meta_is_read_only(meta)) + .collect() + } + + fn server_has_any_direct_tools(&self, server_name: &str) -> bool { + self.config + .mcp_servers + .get(server_name) + .is_some_and(|server| match &server.direct_tools { + Some(DirectToolsConfig::Enabled(value)) => *value, + Some(DirectToolsConfig::Names(names)) => !names.is_empty(), + None => self.config.settings.direct_tools, + }) + } + + pub fn direct_tool_name_candidate(&self, tool_name: &str) -> bool { + self.config.mcp_servers.iter().any(|(server_name, server)| { + let prefixed_any_name = + tool_name.starts_with(&format!("{}_", sanitize_identifier(server_name))); + match &server.direct_tools { + Some(DirectToolsConfig::Enabled(value)) => *value && prefixed_any_name, + Some(DirectToolsConfig::Names(names)) => names.iter().any(|name| { + tool_name == name || tool_name == exposed_tool_name(server_name, server, name) + }), + None => self.config.settings.direct_tools && prefixed_any_name, + } + }) + } + + #[cfg(test)] + pub(crate) fn set_cached_tools_for_test(&mut self, server_name: &str, tools: Vec) { + let Some(server) = self.config.mcp_servers.get(server_name) else { + return; + }; + let metadata = tools + .into_iter() + .filter_map(|tool| tool_meta_from_json(server_name, server, tool)) + .collect::>(); + self.tool_metadata.insert(server_name.to_string(), metadata); + } + + fn tool_is_direct_for_server( + &self, + server_name: &str, + server: &ExternalMcpServer, + meta: &ExternalToolMeta, + ) -> bool { + match &server.direct_tools { + Some(DirectToolsConfig::Enabled(value)) => *value, + Some(DirectToolsConfig::Names(names)) => names.iter().any(|name| { + name == &meta.original_name + || name == &meta.exposed_name + || exposed_tool_name(server_name, server, name) == meta.exposed_name + }), + None => self.config.settings.direct_tools, + } + } + + fn resolve_server_names_for_resource_action( + &self, + server_hint: Option<&str>, + ) -> Result, String> { + if let Some(server_name) = server_hint { + if self.config.mcp_servers.contains_key(server_name) { + return Ok(vec![server_name.to_string()]); + } + return Err(format!("unknown downstream MCP server: {server_name}")); + } + Ok(self.configured_server_names()) + } + + fn server_payload(&self, server_name: &str) -> Value { + let tools = self + .tool_metadata + .get(server_name) + .cloned() + .unwrap_or_default() + .into_iter() + .map(|meta| tool_meta_payload(&meta)) + .collect::>(); + let server = self.config.mcp_servers.get(server_name); + json!({ + "toolName": EXTERNAL_MCP_TOOL_NAME, + "action": "server", + "server": server_name, + "transport": server.map(server_transport_name).unwrap_or("stdio"), + "headers": server.map(|server| redacted_headers_payload(&server.headers)).unwrap_or_default(), + "connected": self.connections.contains_key(server_name), + "directToolsEnabled": self.server_has_any_direct_tools(server_name), + "toolCount": tools.len(), + "tools": tools, + }) + } +} + +struct ExternalMcpClient { + transport: ExternalMcpClientTransport, + next_id: u64, +} + +struct ExternalMcpStdioTransport { + child: Child, + stdin: BufWriter, + pending: Arc>>>, +} + +struct ExternalMcpHttpTransport { + client: reqwest::Client, + url: String, + headers: HeaderMap, + session_id: Option, +} + +enum ExternalMcpClientTransport { + Stdio(ExternalMcpStdioTransport), + Http(ExternalMcpHttpTransport), +} + +impl Drop for ExternalMcpClient { + fn drop(&mut self) { + if let ExternalMcpClientTransport::Stdio(transport) = &mut self.transport { + let _ = transport.child.start_kill(); + } + } +} + +impl ExternalMcpClient { + async fn start_stdio( + command: &str, + args: &[String], + cwd: Option<&Path>, + env: &HashMap, + ) -> Result { + let mut cmd = Command::new(command); + cmd.args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::null()); + if let Some(cwd) = cwd { + cmd.current_dir(cwd); + } + for (key, value) in env { + cmd.env(key, value); + } + let mut child = cmd + .spawn() + .map_err(|error| format!("spawn {command}: {error}"))?; + let child_stdin = child + .stdin + .take() + .ok_or_else(|| format!("{command} exposed no stdin"))?; + let child_stdout = child + .stdout + .take() + .ok_or_else(|| format!("{command} exposed no stdout"))?; + let pending: Arc>>> = + Arc::new(Mutex::new(HashMap::new())); + let pending_reader = pending.clone(); + tokio::spawn(async move { + let mut reader = BufReader::new(child_stdout); + let mut line = String::new(); + let mut close_reason = "downstream MCP server closed stdout".to_string(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Err(error) => { + close_reason = format!("downstream MCP stdout read error: {error}"); + break; + } + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let message = match serde_json::from_str::(trimmed) { + Ok(message) => message, + Err(error) => { + close_reason = + format!("downstream MCP returned malformed JSON: {error}"); + break; + } + }; + let Some(id) = message.get("id") else { + continue; + }; + let mut pending = pending_reader.lock().await; + if let Some(tx) = pending.remove(&id_key(id)) { + let _ = tx.send(message); + } + } + } + } + + let mut pending = pending_reader.lock().await; + for (_, tx) in pending.drain() { + let _ = tx.send(json!({ + "jsonrpc": "2.0", + "error": {"code": -32000, "message": close_reason}, + })); + } + }); + Ok(Self { + transport: ExternalMcpClientTransport::Stdio(ExternalMcpStdioTransport { + child, + stdin: BufWriter::new(child_stdin), + pending, + }), + next_id: 1, + }) + } + + fn start_http(url: &str, headers: &HashMap) -> Result { + let resolved_headers = resolve_http_headers(headers)?; + Ok(Self { + transport: ExternalMcpClientTransport::Http(ExternalMcpHttpTransport { + client: reqwest::Client::new(), + url: url.trim().to_string(), + headers: resolved_headers, + session_id: None, + }), + next_id: 1, + }) + } + + async fn initialize(&mut self) -> Result<(), String> { + self.request( + "initialize", + json!({ + "protocolVersion": PROTOCOL_VERSION, + "capabilities": {}, + "clientInfo": { + "name": "catdesk-mcp-gateway", + "version": env!("CARGO_PKG_VERSION"), + } + }), + ) + .await?; + self.notify("notifications/initialized", json!({})).await + } + + async fn list_tools(&mut self) -> Result, String> { + let mut tools = Vec::new(); + let mut cursor: Option = None; + loop { + let params = cursor + .as_ref() + .map(|value| json!({ "cursor": value })) + .unwrap_or_else(|| json!({})); + let result = self.request("tools/list", params).await?; + if let Some(entries) = result.get("tools").and_then(Value::as_array) { + tools.extend(entries.iter().cloned()); + } + cursor = result + .get("nextCursor") + .and_then(Value::as_str) + .map(str::to_string); + if cursor.is_none() { + break; + } + } + Ok(tools) + } + + async fn call_tool(&mut self, name: &str, arguments: Value) -> Result { + self.request( + "tools/call", + json!({ + "name": name, + "arguments": arguments, + }), + ) + .await + } + + async fn list_resources(&mut self) -> Result, String> { + let mut resources = Vec::new(); + let mut cursor: Option = None; + loop { + let params = cursor + .as_ref() + .map(|value| json!({ "cursor": value })) + .unwrap_or_else(|| json!({})); + let result = self.request("resources/list", params).await?; + if let Some(entries) = result.get("resources").and_then(Value::as_array) { + resources.extend(entries.iter().cloned()); + } + cursor = result + .get("nextCursor") + .and_then(Value::as_str) + .map(str::to_string); + if cursor.is_none() { + break; + } + } + Ok(resources) + } + + async fn read_resource(&mut self, uri: &str) -> Result { + self.request("resources/read", json!({ "uri": uri })).await + } + + async fn request(&mut self, method: &str, params: Value) -> Result { + let id = json!(self.next_id); + self.next_id = self.next_id.saturating_add(1); + let request = json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + }); + match &mut self.transport { + ExternalMcpClientTransport::Stdio(transport) => { + let line = serde_json::to_string(&request).map_err(|error| error.to_string())?; + let (tx, rx) = oneshot::channel(); + { + let mut pending = transport.pending.lock().await; + pending.insert(id_key(&id), tx); + } + if let Err(error) = transport.stdin.write_all(line.as_bytes()).await { + transport.pending.lock().await.remove(&id_key(&id)); + return Err(format!("stdin write: {error}")); + } + if let Err(error) = transport.stdin.write_all(b"\n").await { + transport.pending.lock().await.remove(&id_key(&id)); + return Err(format!("stdin write newline: {error}")); + } + if let Err(error) = transport.stdin.flush().await { + transport.pending.lock().await.remove(&id_key(&id)); + return Err(format!("stdin flush: {error}")); + } + let request_timeout = current_request_timeout(); + let response = tokio::time::timeout(request_timeout, rx) + .await + .map_err(|_| { + format!( + "downstream MCP request timed out after {}s", + request_timeout.as_secs() + ) + })? + .map_err(|_| "downstream MCP response channel closed".to_string())?; + parse_json_rpc_result(response) + } + ExternalMcpClientTransport::Http(transport) => transport.request(request).await, + } + } + + async fn stop(&mut self) { + if let ExternalMcpClientTransport::Stdio(transport) = &mut self.transport { + let _ = transport.child.kill().await; + } + } + + async fn notify(&mut self, method: &str, params: Value) -> Result<(), String> { + let notification = json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + match &mut self.transport { + ExternalMcpClientTransport::Stdio(transport) => { + let line = + serde_json::to_string(¬ification).map_err(|error| error.to_string())?; + transport + .stdin + .write_all(line.as_bytes()) + .await + .map_err(|error| format!("stdin write: {error}"))?; + transport + .stdin + .write_all(b"\n") + .await + .map_err(|error| format!("stdin write newline: {error}"))?; + transport + .stdin + .flush() + .await + .map_err(|error| format!("stdin flush: {error}")) + } + ExternalMcpClientTransport::Http(transport) => transport.notify(notification).await, + } + } +} + +impl ExternalMcpHttpTransport { + async fn notify(&mut self, notification: Value) -> Result<(), String> { + let mut builder = self + .client + .post(&self.url) + .header(ACCEPT, "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("MCP-Protocol-Version", PROTOCOL_VERSION) + .headers(self.headers.clone()) + .json(¬ification); + if let Some(session_id) = &self.session_id { + builder = builder.header("Mcp-Session-Id", session_id); + } + let request_timeout = current_request_timeout(); + let response = tokio::time::timeout(request_timeout, builder.send()) + .await + .map_err(|_| { + format!( + "downstream MCP HTTP notification timed out after {}s", + request_timeout.as_secs() + ) + })? + .map_err(|error| format!("downstream MCP HTTP notification failed: {error}"))?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(format!( + "downstream MCP HTTP notification error {}: {}", + status.as_u16(), + text + )); + } + Ok(()) + } + + async fn request(&mut self, request: Value) -> Result { + let mut builder = self + .client + .post(&self.url) + .header(ACCEPT, "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("MCP-Protocol-Version", PROTOCOL_VERSION) + .headers(self.headers.clone()) + .json(&request); + if let Some(session_id) = &self.session_id { + builder = builder.header("Mcp-Session-Id", session_id); + } + let request_timeout = current_request_timeout(); + let response = tokio::time::timeout(request_timeout, builder.send()) + .await + .map_err(|_| { + format!( + "downstream MCP HTTP request timed out after {}s", + request_timeout.as_secs() + ) + })? + .map_err(|error| format!("downstream MCP HTTP request failed: {error}"))?; + if let Some(session_id) = response + .headers() + .get("mcp-session-id") + .and_then(|value| value.to_str().ok()) + .filter(|value| !value.trim().is_empty()) + { + self.session_id = Some(session_id.to_string()); + } + let status = response.status(); + let content_type = parse_content_type_header(response.headers().get(CONTENT_TYPE)); + let text = tokio::time::timeout(request_timeout, response.text()) + .await + .map_err(|_| { + format!( + "downstream MCP HTTP response read timed out after {}s", + request_timeout.as_secs() + ) + })? + .map_err(|error| format!("downstream MCP HTTP response read failed: {error}"))?; + if text.len() > MAX_HTTP_RESPONSE_BYTES { + return Err(format!( + "downstream MCP HTTP response exceeded {} bytes", + MAX_HTTP_RESPONSE_BYTES + )); + } + if !status.is_success() { + return Err(format!( + "downstream MCP HTTP error {}: {}", + status.as_u16(), + text + )); + } + let response = parse_http_response_body(&text, content_type.as_deref())?; + parse_json_rpc_result(response) + } +} + +#[derive(Debug, PartialEq)] +enum ProxyAction { + Status, + Call { + tool: String, + args: Value, + server: Option, + }, + Connect { + server: String, + }, + Disconnect { + server: String, + }, + Describe { + query: String, + server: Option, + }, + Search { + query: String, + server: Option, + }, + Server { + server: String, + }, + ListResources { + server: Option, + }, + ReadResource { + uri: String, + server: Option, + }, +} + +impl ProxyAction { + fn from_arguments(arguments: &Value) -> Result { + let server = optional_string_owned(arguments, "server")?; + let tool = optional_string_owned(arguments, "tool")?; + let connect = optional_string_owned(arguments, "connect")?; + let disconnect = optional_string_owned(arguments, "disconnect")?; + let describe = optional_string_owned(arguments, "describe")?; + let search = optional_string_owned(arguments, "search")?; + let resource = optional_string_owned(arguments, "resource")?; + let resources = optional_bool(arguments, "resources")?; + let args_present = arguments.get("args").is_some(); + let primary_count = [ + tool.is_some(), + connect.is_some(), + disconnect.is_some(), + describe.is_some(), + search.is_some(), + resource.is_some(), + resources, + ] + .into_iter() + .filter(|present| *present) + .count(); + + if primary_count > 1 { + return Err( + "mcp proxy accepts exactly one action among tool, connect, disconnect, describe, search, resource, and resources" + .to_string(), + ); + } + if args_present && tool.is_none() { + return Err("args can only be used with tool calls".to_string()); + } + if let Some(tool) = tool { + return Ok(Self::Call { + tool, + args: parse_proxy_args(arguments)?, + server, + }); + } + if let Some(connect) = connect { + return Ok(Self::Connect { server: connect }); + } + if let Some(disconnect) = disconnect { + return Ok(Self::Disconnect { server: disconnect }); + } + if let Some(describe) = describe { + return Ok(Self::Describe { + query: describe, + server, + }); + } + if let Some(search) = search { + return Ok(Self::Search { + query: search, + server, + }); + } + if let Some(resource) = resource { + return Ok(Self::ReadResource { + uri: resource, + server, + }); + } + if resources { + return Ok(Self::ListResources { server }); + } + if let Some(server) = server { + return Ok(Self::Server { server }); + } + Ok(Self::Status) + } +} + +pub fn parse_proxy_args(arguments: &Value) -> Result { + let Some(raw) = arguments.get("args") else { + return Ok(json!({})); + }; + if raw.is_object() { + return Ok(raw.clone()); + } + let Some(raw) = raw.as_str() else { + return Err("args must be a JSON object or JSON object string".to_string()); + }; + if raw.trim().is_empty() { + return Ok(json!({})); + } + let parsed: Value = + serde_json::from_str(raw).map_err(|error| format!("invalid args JSON: {error}"))?; + if !parsed.is_object() { + return Err("args must be a JSON object or parse to a JSON object".to_string()); + } + Ok(parsed) +} + +pub fn sanitize_exposed_tool_name(server_name: &str, tool_name: &str) -> String { + let server = sanitize_identifier(server_name); + let tool = sanitize_identifier(tool_name); + if server.is_empty() { + tool + } else if tool.is_empty() { + server + } else { + format!("{server}_{tool}") + } +} + +fn exposed_tool_name(server_name: &str, server: &ExternalMcpServer, tool_name: &str) -> String { + if server.unprefixed_tools { + sanitize_identifier(tool_name) + } else { + sanitize_exposed_tool_name(server_name, tool_name) + } +} + +fn sanitize_identifier(value: &str) -> String { + let mut output = String::new(); + let mut last_was_separator = false; + for ch in value.chars() { + if ch.is_ascii_alphanumeric() { + output.push(ch.to_ascii_lowercase()); + last_was_separator = false; + } else if !last_was_separator { + output.push('_'); + last_was_separator = true; + } + } + output.trim_matches('_').to_string() +} + +fn optional_string<'a>(arguments: &'a Value, name: &str) -> Result, String> { + match arguments.get(name) { + Some(value) => value + .as_str() + .filter(|value| !value.trim().is_empty()) + .map(Some) + .ok_or_else(|| format!("{name} must be a non-empty string")), + None => Ok(None), + } +} + +fn optional_string_owned(arguments: &Value, name: &str) -> Result, String> { + Ok(optional_string(arguments, name)?.map(str::to_string)) +} + +fn optional_bool(arguments: &Value, name: &str) -> Result { + match arguments.get(name) { + Some(value) => value + .as_bool() + .ok_or_else(|| format!("{name} must be a boolean")), + None => Ok(false), + } +} + +fn id_key(value: &Value) -> String { + match value { + Value::String(text) => format!("s:{text}"), + _ => format!("j:{value}"), + } +} + +fn tool_meta_from_json( + server_name: &str, + server: &ExternalMcpServer, + tool: Value, +) -> Option { + let object = tool.as_object()?; + let original_name = object.get("name")?.as_str()?.to_string(); + let exposed_name = exposed_tool_name(server_name, server, &original_name); + let excludes = server + .exclude_tools + .iter() + .map(|value| value.as_str()) + .collect::>(); + if excludes.contains(original_name.as_str()) || excludes.contains(exposed_name.as_str()) { + return None; + } + let title = object + .get("title") + .and_then(Value::as_str) + .map(str::to_string); + let description = object + .get("description") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + let input_schema = object + .get("inputSchema") + .cloned() + .unwrap_or_else(|| json!({ "type": "object", "properties": {} })); + let annotations = object + .get("annotations") + .cloned() + .unwrap_or_else(|| json!({})); + Some(ExternalToolMeta { + server_name: server_name.to_string(), + original_name, + exposed_name, + title, + description, + input_schema, + annotations, + }) +} + +fn tool_meta_payload(meta: &ExternalToolMeta) -> Value { + json!({ + "server": meta.server_name, + "name": meta.exposed_name, + "downstreamName": meta.original_name, + "title": meta.title, + "description": meta.description, + "inputSchema": meta.input_schema, + "annotations": meta.annotations, + "callExample": { + "tool": meta.exposed_name, + "server": meta.server_name, + "args": "{}" + } + }) +} + +fn server_transport_name(server: &ExternalMcpServer) -> &'static str { + if server + .url + .as_deref() + .is_some_and(|value| !value.trim().is_empty()) + { + "http" + } else { + "stdio" + } +} + +fn redacted_headers_payload(headers: &HashMap) -> Value { + let mut entries = serde_json::Map::new(); + for name in sorted_keys(headers) { + entries.insert(name, json!("")); + } + Value::Object(entries) +} + +fn resolve_http_headers(headers: &HashMap) -> Result { + let mut output = HeaderMap::new(); + for (name, value) in headers { + let name = HeaderName::from_bytes(name.as_bytes()) + .map_err(|error| format!("invalid HTTP header name `{name}`: {error}"))?; + let value = resolve_env_interpolations(value)?; + let value = HeaderValue::from_str(&value) + .map_err(|error| format!("invalid HTTP header value for `{name}`: {error}"))?; + output.insert(name, value); + } + Ok(output) +} + +fn resolve_env_interpolations(value: &str) -> Result { + let mut output = String::new(); + let mut remaining = value; + while let Some(start) = remaining.find("${") { + output.push_str(&remaining[..start]); + let after_start = &remaining[start + 2..]; + let Some(end) = after_start.find('}') else { + output.push_str(&remaining[start..]); + return Ok(output); + }; + let name = &after_start[..end]; + if name.is_empty() { + return Err("empty environment variable interpolation".to_string()); + } + let resolved = std::env::var(name) + .map_err(|_| format!("missing environment variable for HTTP header: {name}"))?; + output.push_str(&resolved); + remaining = &after_start[end + 1..]; + } + output.push_str(remaining); + Ok(output) +} + +fn parse_http_response_body(text: &str, content_type: Option<&str>) -> Result { + let is_event_stream = content_type + .map(|value| value.to_ascii_lowercase().contains("text/event-stream")) + .unwrap_or_else(|| { + text.lines() + .any(|line| line.trim_start().starts_with("data:")) + }); + if is_event_stream { + return parse_sse_json_rpc_event(text); + } + serde_json::from_str(text) + .map_err(|error| format!("downstream MCP HTTP returned invalid JSON: {error}")) +} + +fn parse_sse_json_rpc_event(text: &str) -> Result { + let mut data_lines = Vec::new(); + let mut saw_data = false; + let mut last_error = None; + for line in text.lines().chain(std::iter::once("")) { + let line = line.trim_end_matches('\r'); + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + saw_data = true; + continue; + } + if line.is_empty() && !data_lines.is_empty() { + let payload = data_lines.join("\n"); + data_lines.clear(); + match serde_json::from_str::(&payload) { + Ok(value) if value.get("result").is_some() || value.get("error").is_some() => { + return Ok(value); + } + Ok(_) => {} + Err(error) => last_error = Some(error.to_string()), + } + } + } + if let Some(error) = last_error { + return Err(format!( + "downstream MCP HTTP SSE returned invalid JSON: {error}" + )); + } + if saw_data { + Err("downstream MCP HTTP SSE response contained no JSON-RPC response event".to_string()) + } else { + Err("downstream MCP HTTP SSE response contained no data event".to_string()) + } +} + +fn parse_content_type_header(value: Option<&HeaderValue>) -> Option { + value + .and_then(|value| value.to_str().ok()) + .map(str::to_string) +} + +fn parse_json_rpc_result(response: Value) -> Result { + if let Some(error) = response.get("error") { + return Err(format!("downstream MCP error: {error}")); + } + Ok(response.get("result").cloned().unwrap_or(Value::Null)) +} + +fn tool_meta_is_read_only(meta: &ExternalToolMeta) -> bool { + meta.annotations + .get("readOnlyHint") + .and_then(Value::as_bool) + .unwrap_or(false) +} + +fn direct_tool_annotations(meta: &ExternalToolMeta) -> Value { + if meta.annotations.is_object() + && !meta + .annotations + .as_object() + .is_some_and(|object| object.is_empty()) + { + meta.annotations.clone() + } else { + json!({ "readOnlyHint": false, "openWorldHint": true, "destructiveHint": true }) + } +} + +fn direct_tool_description(meta: &ExternalToolMeta) -> String { + if meta.description.trim().is_empty() { + format!( + "Direct tool from downstream MCP server `{}`. Original tool: `{}`.", + meta.server_name, meta.original_name + ) + } else { + format!( + "{}\n\nDownstream MCP server: `{}`. Original tool: `{}`.", + meta.description, meta.server_name, meta.original_name + ) + } +} + +fn resource_payload(server_name: &str, resource: Value) -> Value { + let mut object = resource.as_object().cloned().unwrap_or_default(); + object.insert("server".to_string(), json!(server_name)); + Value::Object(object) +} + +fn current_request_timeout() -> Duration { + #[cfg(test)] + { + TEST_REQUEST_TIMEOUT + } + #[cfg(not(test))] + { + REQUEST_TIMEOUT + } +} + +fn sorted_keys(map: &HashMap) -> Vec { + let mut keys = map.keys().cloned().collect::>(); + keys.sort(); + keys +} + +fn resolve_server_cwd(workspace_root: &Path, cwd: Option<&str>) -> Option { + let cwd = cwd?.trim(); + if cwd.is_empty() { + return None; + } + let path = PathBuf::from(cwd); + if path.is_absolute() { + Some(path) + } else { + Some(workspace_root.join(path)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_server(command: &str, args: &[&str]) -> ExternalMcpServer { + ExternalMcpServer { + unprefixed_tools: false, + command: Some(command.to_string()), + args: args.iter().map(|arg| arg.to_string()).collect(), + env: HashMap::new(), + cwd: None, + url: None, + headers: HashMap::new(), + lifecycle: "lazy".to_string(), + direct_tools: None, + exclude_tools: Vec::new(), + } + } + + fn mock_server_script() -> &'static str { + r#" +import json +import sys + +TOOLS_PAGE_1 = [ + { + "name": "echo", + "description": "Echo a message", + "inputSchema": { + "type": "object", + "properties": {"message": {"type": "string"}}, + }, + } +] +TOOLS_PAGE_2 = [ + { + "name": "status", + "description": "Read status", + "inputSchema": {"type": "object", "properties": {}}, + } +] +RESOURCES_PAGE_1 = [ + {"uri": "mock://alpha", "name": "Alpha", "description": "Alpha resource"} +] +RESOURCES_PAGE_2 = [ + {"uri": "mock://beta", "name": "Beta", "description": "Beta resource"} +] + +for line in sys.stdin: + message = json.loads(line) + method = message.get("method") + request_id = message.get("id") + if request_id is None: + continue + params = message.get("params", {}) + if method == "initialize": + result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}, "resources": {"listChanged": False}}, + "serverInfo": {"name": "mock", "version": "1.0.0"}, + } + elif method == "tools/list": + if params.get("cursor") == "tools-page-2": + result = {"tools": TOOLS_PAGE_2} + else: + result = {"tools": TOOLS_PAGE_1, "nextCursor": "tools-page-2"} + elif method == "tools/call": + name = params.get("name") + args = params.get("arguments", {}) + if name == "echo": + result = { + "content": [{"type": "text", "text": args.get("message", "")}], + "isError": False, + } + elif name == "status": + result = { + "content": [{"type": "text", "text": "ok"}], + "isError": False, + } + else: + print(json.dumps({ + "jsonrpc": "2.0", + "id": request_id, + "error": {"code": -32602, "message": f"unknown tool: {name}"}, + }), flush=True) + continue + elif method == "resources/list": + if params.get("cursor") == "resources-page-2": + result = {"resources": RESOURCES_PAGE_2} + else: + result = {"resources": RESOURCES_PAGE_1, "nextCursor": "resources-page-2"} + elif method == "resources/read": + uri = params.get("uri") + result = { + "contents": [{ + "uri": uri, + "mimeType": "text/plain", + "text": f"content for {uri}", + }] + } + else: + print(json.dumps({ + "jsonrpc": "2.0", + "id": request_id, + "error": {"code": -32601, "message": f"unknown method: {method}"}, + }), flush=True) + continue + print(json.dumps({"jsonrpc": "2.0", "id": request_id, "result": result}), flush=True) +"# + } + + fn unique_workspace(prefix: &str) -> PathBuf { + let unique = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + std::env::temp_dir().join(format!("{prefix}-{unique}")) + } + + fn write_mock_server(workspace_root: &Path) -> PathBuf { + std::fs::create_dir_all(workspace_root).expect("create workspace"); + let server_path = workspace_root.join("mock_mcp_server.py"); + std::fs::write(&server_path, mock_server_script()).expect("write mock server"); + server_path + } + + fn mock_manager_with_server( + server: ExternalMcpServer, + workspace_root: &Path, + ) -> ExternalMcpManager { + let mut servers = HashMap::new(); + servers.insert("mock".to_string(), server); + ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.to_path_buf(), + ) + } + + fn mock_stdio_server(server_path: &Path) -> ExternalMcpServer { + ExternalMcpServer { + command: Some("python3".to_string()), + args: vec!["-u".to_string(), server_path.to_string_lossy().into_owned()], + ..ExternalMcpServer::default() + } + } + + #[test] + fn sanitize_exposed_tool_names_prefixes_server_and_tool() { + assert_eq!( + sanitize_exposed_tool_name("Chrome DevTools", "take-screenshot"), + "chrome_devtools_take_screenshot" + ); + assert_eq!( + sanitize_exposed_tool_name("serena", "read_file"), + "serena_read_file" + ); + } + + #[test] + fn manager_status_reports_configured_servers() { + let mut servers = HashMap::new(); + servers.insert( + "serena".to_string(), + test_server("serena-mcp-server", &["--project", "."]), + ); + let config = ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }; + let mut manager = ExternalMcpManager::new(config); + let status = manager.status_payload(); + + assert_eq!(status.get("serverCount").and_then(Value::as_u64), Some(1)); + assert_eq!( + status + .get("servers") + .and_then(Value::as_array) + .and_then(|servers| servers.first()) + .and_then(|server| server.get("name")) + .and_then(Value::as_str), + Some("serena") + ); + } + + #[test] + fn parse_proxy_args_accepts_json_object_string() { + let args = + parse_proxy_args(&json!({"args":"{\"path\":\"src/main.rs\"}"})).expect("parse args"); + assert_eq!( + args.get("path").and_then(Value::as_str), + Some("src/main.rs") + ); + } + + #[test] + fn parse_proxy_args_accepts_json_object() { + let args = + parse_proxy_args(&json!({"args":{"path":"src/main.rs"}})).expect("parse object args"); + assert_eq!( + args.get("path").and_then(Value::as_str), + Some("src/main.rs") + ); + } + + #[test] + fn parse_proxy_args_rejects_json_arrays() { + let error = parse_proxy_args(&json!({"args":"[]"})).expect_err("array should fail"); + assert!(error.contains("JSON object")); + } + + #[test] + fn proxy_action_rejects_multiple_actions() { + let error = ProxyAction::from_arguments(&json!({ + "tool": "mock_echo", + "search": "echo", + "args": "{}" + })) + .expect_err("multiple actions should fail"); + assert!(error.contains("exactly one action")); + } + + #[test] + fn proxy_action_rejects_args_without_tool() { + let error = ProxyAction::from_arguments(&json!({ + "search": "echo", + "args": "{}" + })) + .expect_err("args without tool should fail"); + assert!(error.contains("args can only be used with tool")); + } + + #[test] + fn proxy_action_rejects_invalid_args_json() { + let error = ProxyAction::from_arguments(&json!({ + "tool": "mock_echo", + "args": "{not json}" + })) + .expect_err("invalid JSON should fail"); + assert!(error.contains("invalid args JSON")); + } + + #[test] + fn parses_sse_json_rpc_http_response_body() { + let payload = + "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n"; + let parsed = parse_http_response_body(payload, Some("text/event-stream")) + .expect("parse SSE JSON-RPC body"); + assert_eq!( + parsed + .get("result") + .and_then(|result| result.get("ok")) + .and_then(Value::as_bool), + Some(true) + ); + } + + #[test] + fn parses_sse_json_rpc_response_after_notification_event() { + let payload = concat!( + "event: message\n", + "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\",\"params\":{}}\n\n", + "event: message\n", + "data: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n", + ); + let parsed = parse_http_response_body(payload, Some("text/event-stream")) + .expect("parse SSE response after notification"); + assert_eq!( + parsed + .get("result") + .and_then(|result| result.get("ok")) + .and_then(Value::as_bool), + Some(true) + ); + } + + #[tokio::test] + async fn direct_tool_descriptors_read_only_filters_annotations() { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "mock", + vec![ + json!({ + "name": "safe", + "description": "Safe read-only tool", + "inputSchema": {"type":"object", "properties": {}}, + "annotations": {"readOnlyHint": true, "openWorldHint": false, "destructiveHint": false} + }), + json!({ + "name": "unsafe", + "description": "Unsafe tool", + "inputSchema": {"type":"object", "properties": {}}, + "annotations": {"readOnlyHint": false, "destructiveHint": true} + }), + json!({ + "name": "unknown", + "description": "Missing annotations", + "inputSchema": {"type":"object", "properties": {}} + }), + ], + ); + + let read_only = manager + .direct_tool_descriptors(true) + .await + .expect("read-only descriptors"); + let names = read_only + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + assert_eq!(names, vec!["mock_safe"]); + assert_eq!( + read_only + .first() + .and_then(|tool| tool.get("annotations")) + .and_then(|annotations| annotations.get("readOnlyHint")) + .and_then(Value::as_bool), + Some(true) + ); + } + + #[test] + fn tool_metadata_skips_excluded_tools() { + let mut server = test_server("mock", &[]); + server.exclude_tools = vec!["danger".to_string()]; + let danger = tool_meta_from_json( + "mock", + &server, + json!({"name":"danger", "inputSchema": {"type":"object"}}), + ); + let safe = tool_meta_from_json( + "mock", + &server, + json!({"name":"read", "description":"Read something"}), + ); + assert!(danger.is_none()); + assert_eq!(safe.expect("safe tool").exposed_name, "mock_read"); + } + + #[tokio::test] + async fn manager_connects_and_calls_stdio_server() { + let workspace_root = unique_workspace("catdesk-external-mcp-call"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let status = manager.connect("mock").await.expect("connect mock server"); + assert_eq!(status.get("toolCount").and_then(Value::as_u64), Some(2)); + + let call = manager + .call_tool("mock_echo", json!({"message":"hello"}), None) + .await + .expect("call downstream tool"); + assert_eq!(call.server_name, "mock"); + assert_eq!(call.original_name, "echo"); + assert_eq!(call.exposed_name, "mock_echo"); + assert_eq!( + call.result + .get("content") + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("hello") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn proxy_search_returns_partial_results_when_one_server_fails() { + let workspace_root = unique_workspace("catdesk-external-mcp-partial-search"); + let server_path = write_mock_server(&workspace_root); + let mut servers = HashMap::new(); + servers.insert("mock".to_string(), mock_stdio_server(&server_path)); + servers.insert( + "broken".to_string(), + test_server("catdesk-missing-mcp-command-for-test", &[]), + ); + let mut manager = ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.clone(), + ); + + let output = manager + .proxy(&json!({"search": "echo"})) + .await + .expect("partial search should succeed"); + + assert!(output.text.contains("partial metadata refresh")); + assert_eq!( + output.structured.get("partial").and_then(Value::as_bool), + Some(true) + ); + assert_eq!( + output + .structured + .get("matches") + .and_then(Value::as_array) + .and_then(|matches| matches.first()) + .and_then(|tool| tool.get("name")) + .and_then(Value::as_str), + Some("mock_echo") + ); + let refresh_failures = output + .structured + .get("refreshFailures") + .and_then(Value::as_array) + .expect("refresh failures should be present"); + assert_eq!(refresh_failures.len(), 1); + assert!( + refresh_failures + .first() + .and_then(Value::as_str) + .is_some_and(|failure| failure.contains("broken:")) + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn proxy_search_with_server_hint_keeps_refresh_failure_strict() { + let workspace_root = unique_workspace("catdesk-external-mcp-strict-search"); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let mut servers = HashMap::new(); + servers.insert( + "broken".to_string(), + test_server("catdesk-missing-mcp-command-for-test", &[]), + ); + let mut manager = ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.clone(), + ); + + let error = match manager + .proxy(&json!({"search": "echo", "server": "broken"})) + .await + { + Ok(_) => panic!("server-scoped search should fail for the target server"), + Err(error) => error, + }; + + assert!(error.contains("spawn catdesk-missing-mcp-command-for-test")); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn manager_lists_and_reads_downstream_resources() { + let workspace_root = unique_workspace("catdesk-external-mcp-resources"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let resources = manager + .list_resources(Some("mock")) + .await + .expect("list resources"); + assert_eq!(resources.len(), 2); + assert_eq!( + resources + .iter() + .find(|resource| resource.get("uri").and_then(Value::as_str) == Some("mock://beta")) + .and_then(|resource| resource.get("server")) + .and_then(Value::as_str), + Some("mock") + ); + + let read = manager + .read_resource("mock://alpha", Some("mock")) + .await + .expect("read resource"); + assert_eq!(read.server_name, "mock"); + assert_eq!(read.uri, "mock://alpha"); + assert_eq!( + read.result + .get("contents") + .and_then(Value::as_array) + .and_then(|contents| contents.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("content for mock://alpha") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn proxy_lists_and_reads_resources() { + let workspace_root = unique_workspace("catdesk-external-mcp-proxy-resources"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let list = manager + .proxy(&json!({"resources": true, "server": "mock"})) + .await + .expect("proxy list resources"); + assert_eq!( + list.structured.get("action").and_then(Value::as_str), + Some("resources") + ); + assert_eq!( + list.structured + .get("resources") + .and_then(Value::as_array) + .map(Vec::len), + Some(2) + ); + + let read = manager + .proxy(&json!({"resource": "mock://beta", "server": "mock"})) + .await + .expect("proxy read resource"); + assert_eq!( + read.structured.get("action").and_then(Value::as_str), + Some("readResource") + ); + assert_eq!( + read.structured + .get("result") + .and_then(|result| result.get("contents")) + .and_then(Value::as_array) + .and_then(|contents| contents.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("content for mock://beta") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[test] + fn tool_metadata_can_expose_unprefixed_tools_for_browser_gateway() { + let server = ExternalMcpServer { + command: Some("npx".to_string()), + unprefixed_tools: true, + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }; + let meta = tool_meta_from_json( + "browser", + &server, + json!({"name":"take-screenshot", "inputSchema": {"type":"object"}}), + ) + .expect("tool metadata"); + assert_eq!(meta.exposed_name, "take_screenshot"); + } + + #[test] + fn tui_status_snapshot_reports_gateway_counts() { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "mock", + vec![json!({"name":"read", "inputSchema": {"type":"object"}})], + ); + let snapshot = manager.tui_status_snapshot(2, true); + assert_eq!(snapshot.configured_server_count, 1); + assert_eq!(snapshot.connected_server_count, 0); + assert_eq!(snapshot.failed_server_count, 2); + assert_eq!(snapshot.tool_count, 1); + assert!(snapshot.browser_gateway_enabled); + } + + #[test] + fn direct_tool_name_candidate_matches_global_and_server_forms() { + let mut servers = HashMap::new(); + servers.insert( + "global".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + ..ExternalMcpServer::default() + }, + ); + servers.insert( + "server".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + servers.insert( + "allow".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + direct_tools: Some(DirectToolsConfig::Names(vec!["echo".to_string()])), + ..ExternalMcpServer::default() + }, + ); + let manager = ExternalMcpManager::new(ExternalMcpConfig { + settings: crate::state::ExternalMcpSettings { + direct_tools: true, + ..Default::default() + }, + mcp_servers: servers, + }); + + assert!(manager.direct_tool_name_candidate("global_echo")); + assert!(manager.direct_tool_name_candidate("server_status")); + assert!(manager.direct_tool_name_candidate("allow_echo")); + assert!(!manager.direct_tool_name_candidate("allow_status")); + } + + #[tokio::test] + async fn direct_tool_descriptors_keep_duplicate_names_unambiguous() { + let mut servers = HashMap::new(); + servers.insert( + "alpha".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + servers.insert( + "beta".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + let echo_tool = json!({ + "name": "echo", + "description": "Echo a message", + "inputSchema": {"type": "object", "properties": {}} + }); + manager.set_cached_tools_for_test("alpha", vec![echo_tool.clone()]); + manager.set_cached_tools_for_test("beta", vec![echo_tool]); + + let descriptors = manager + .direct_tool_descriptors(false) + .await + .expect("direct descriptors"); + let names = descriptors + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + assert_eq!(names, vec!["alpha_echo", "beta_echo"]); + } + + #[tokio::test] + async fn original_tool_name_requires_server_when_ambiguous() { + let mut servers = HashMap::new(); + servers.insert( + "alpha".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + ..ExternalMcpServer::default() + }, + ); + servers.insert( + "beta".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + let echo_tool = json!({ + "name": "echo", + "description": "Echo a message", + "inputSchema": {"type": "object", "properties": {}} + }); + manager.set_cached_tools_for_test("alpha", vec![echo_tool.clone()]); + manager.set_cached_tools_for_test("beta", vec![echo_tool]); + + let error = manager + .call_tool("echo", json!({}), None) + .await + .expect_err("ambiguous original tool name should fail"); + assert!(error.contains("ambiguous downstream MCP tool `echo`")); + assert!(error.contains("alpha:echo")); + assert!(error.contains("beta:echo")); + } + + #[tokio::test] + async fn manager_connects_and_calls_http_server_with_headers_and_session() { + use axum::extract::State; + use axum::http::{HeaderMap as AxumHeaderMap, HeaderValue as AxumHeaderValue}; + use axum::response::IntoResponse; + use axum::routing::post; + use axum::{Json, Router}; + + async fn handle_http_mcp( + State(requests): State>>>, + headers: AxumHeaderMap, + Json(message): Json, + ) -> impl IntoResponse { + let method = message + .get("method") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + requests.lock().await.push(json!({ + "method": method, + "authorization": headers.get("authorization").and_then(|value| value.to_str().ok()), + "xApiKey": headers.get("x-api-key").and_then(|value| value.to_str().ok()), + "session": headers.get("mcp-session-id").and_then(|value| value.to_str().ok()), + "accept": headers.get("accept").and_then(|value| value.to_str().ok()), + "protocol": headers.get("mcp-protocol-version").and_then(|value| value.to_str().ok()), + })); + let id = message.get("id").cloned().unwrap_or(Value::Null); + let result = match method.as_str() { + "initialize" => json!({ + "protocolVersion": PROTOCOL_VERSION, + "capabilities": {"tools": {"listChanged": false}}, + "serverInfo": {"name": "http-mock", "version": "1.0.0"}, + }), + "tools/list" => json!({ + "tools": [{ + "name": "echo", + "description": "Echo over HTTP", + "inputSchema": {"type":"object", "properties": {"message": {"type":"string"}}}, + }] + }), + "tools/call" => { + let text = message + .get("params") + .and_then(|params| params.get("arguments")) + .and_then(|args| args.get("message")) + .and_then(Value::as_str) + .unwrap_or_default(); + json!({"content": [{"type":"text", "text": text}], "isError": false}) + } + _ => json!({}), + }; + let mut response = + Json(json!({"jsonrpc":"2.0", "id": id, "result": result})).into_response(); + response + .headers_mut() + .insert("mcp-session-id", AxumHeaderValue::from_static("session-1")); + response + } + + let requests = Arc::new(Mutex::new(Vec::new())); + let app = Router::new() + .route("/mcp", post(handle_http_mcp)) + .with_state(requests.clone()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind mock HTTP server"); + let addr = listener.local_addr().expect("local addr"); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), "Bearer test-token".to_string()); + headers.insert("X-Api-Key".to_string(), "${PATH}".to_string()); + let mut servers = HashMap::new(); + servers.insert( + "http".to_string(), + ExternalMcpServer { + url: Some(format!("http://{addr}/mcp")), + headers, + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + + let status = manager.connect("http").await.expect("connect HTTP server"); + assert_eq!( + status.get("transport").and_then(Value::as_str), + Some("http") + ); + assert_eq!(status.get("toolCount").and_then(Value::as_u64), Some(1)); + let call = manager + .call_tool("http_echo", json!({"message":"hello http"}), None) + .await + .expect("call HTTP downstream tool"); + assert_eq!( + call.result + .get("content") + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("hello http") + ); + + let requests = requests.lock().await; + assert!(requests.iter().any(|request| { + request.get("authorization").and_then(Value::as_str) == Some("Bearer test-token") + && request.get("xApiKey").and_then(Value::as_str) + == std::env::var("PATH").ok().as_deref() + })); + assert!(requests.iter().any(|request| { + request.get("method").and_then(Value::as_str) == Some("notifications/initialized") + && request.get("session").and_then(Value::as_str) == Some("session-1") + })); + assert!(requests.iter().any(|request| { + request.get("method").and_then(Value::as_str) == Some("tools/list") + && request.get("session").and_then(Value::as_str) == Some("session-1") + })); + assert!(requests.iter().any(|request| { + request.get("accept").and_then(Value::as_str) + == Some("application/json, text/event-stream") + && request.get("protocol").and_then(Value::as_str) == Some(PROTOCOL_VERSION) + })); + } + + #[tokio::test] + async fn http_transport_accepts_sse_json_rpc_responses() { + use axum::extract::State; + use axum::http::HeaderValue as AxumHeaderValue; + use axum::response::{IntoResponse, Response}; + use axum::routing::post; + use axum::{Json, Router}; + + async fn handle_sse_http_mcp( + State(requests): State>>>, + Json(message): Json, + ) -> Response { + let method = message + .get("method") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + requests.lock().await.push(json!({"method": method})); + let id = message.get("id").cloned().unwrap_or(Value::Null); + let result = match method.as_str() { + "initialize" => json!({ + "protocolVersion": PROTOCOL_VERSION, + "capabilities": {"tools": {"listChanged": false}}, + "serverInfo": {"name": "sse-mock", "version": "1.0.0"}, + }), + "tools/list" => json!({ + "tools": [{ + "name": "echo", + "description": "Echo through SSE", + "inputSchema": {"type":"object", "properties": {"message": {"type":"string"}}}, + }] + }), + "tools/call" => { + json!({"content": [{"type":"text", "text": "sse-ok"}], "isError": false}) + } + _ => json!({}), + }; + let body = format!( + "event: message\ndata: {}\n\n", + json!({"jsonrpc":"2.0", "id": id, "result": result}) + ); + let mut response = body.into_response(); + response.headers_mut().insert( + "content-type", + AxumHeaderValue::from_static("text/event-stream"), + ); + response.headers_mut().insert( + "mcp-session-id", + AxumHeaderValue::from_static("sse-session"), + ); + response + } + + let requests = Arc::new(Mutex::new(Vec::new())); + let app = Router::new() + .route("/mcp", post(handle_sse_http_mcp)) + .with_state(requests.clone()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind SSE HTTP server"); + let addr = listener.local_addr().expect("local addr"); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + + let mut servers = HashMap::new(); + servers.insert( + "sse".to_string(), + ExternalMcpServer { + url: Some(format!("http://{addr}/mcp")), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager + .connect("sse") + .await + .expect("connect SSE HTTP server"); + let call = manager + .call_tool("sse_echo", json!({"message":"ignored"}), None) + .await + .expect("call SSE HTTP tool"); + assert_eq!( + call.result + .get("content") + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("sse-ok") + ); + let requests = requests.lock().await; + assert!( + requests + .iter() + .any(|request| request.get("method").and_then(Value::as_str) + == Some("notifications/initialized")) + ); + } + + #[tokio::test] + async fn http_transport_reports_json_rpc_and_http_errors() { + use axum::extract::State; + use axum::http::StatusCode; + use axum::routing::post; + use axum::{Json, Router}; + + async fn handle_http_error_mcp( + State(mode): State, + Json(message): Json, + ) -> (StatusCode, Json) { + let id = message.get("id").cloned().unwrap_or(Value::Null); + if mode == "http" { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"message":"boom"})), + ); + } + ( + StatusCode::OK, + Json(json!({ + "jsonrpc":"2.0", + "id": id, + "error": {"code": -32001, "message": "json rpc boom"}, + })), + ) + } + + async fn spawn_error_server(mode: &str) -> String { + let app = Router::new() + .route("/mcp", post(handle_http_error_mcp)) + .with_state(mode.to_string()); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind error HTTP server"); + let addr = listener.local_addr().expect("local addr"); + tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + format!("http://{addr}/mcp") + } + + for (name, mode, expected) in [ + ("json", "json", "downstream MCP error"), + ("http", "http", "downstream MCP HTTP error 500"), + ] { + let mut servers = HashMap::new(); + servers.insert( + name.to_string(), + ExternalMcpServer { + url: Some(spawn_error_server(mode).await), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + let error = manager + .connect(name) + .await + .expect_err("HTTP error should fail"); + assert!(error.contains(expected), "unexpected error: {error}"); + } + } + + #[test] + fn http_header_interpolation_reports_missing_environment_variables() { + let mut headers = HashMap::new(); + headers.insert( + "Authorization".to_string(), + "Bearer ${CATDESK_MISSING_ENV_FOR_TEST}".to_string(), + ); + let error = resolve_http_headers(&headers).expect_err("missing env should fail"); + assert!(error.contains("CATDESK_MISSING_ENV_FOR_TEST")); + } + + #[test] + fn status_payload_redacts_http_headers() { + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), "Bearer secret".to_string()); + let mut servers = HashMap::new(); + servers.insert( + "http".to_string(), + ExternalMcpServer { + url: Some("http://127.0.0.1:3000/mcp".to_string()), + headers, + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + let status = manager.status_payload(); + assert_eq!( + status + .get("servers") + .and_then(Value::as_array) + .and_then(|servers| servers.first()) + .and_then(|server| server.get("headers")) + .and_then(|headers| headers.get("Authorization")) + .and_then(Value::as_str), + Some("") + ); + } + + #[test] + fn status_reports_empty_gateway_message_and_lifecycle_fields() { + let mut empty_manager = ExternalMcpManager::new(ExternalMcpConfig::default()); + let empty_status = empty_manager.status_payload(); + assert_eq!( + empty_status.get("message").and_then(Value::as_str), + Some( + "No downstream MCP servers configured. Add [mcp.mcpServers.] entries to ~/.catdesk/config.toml." + ) + ); + + let mut servers = HashMap::new(); + servers.insert( + "keep".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + lifecycle: "keep_alive".to_string(), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + let status = manager.status_payload(); + let server = status + .get("servers") + .and_then(Value::as_array) + .and_then(|servers| servers.first()) + .expect("missing server status"); + assert_eq!( + server.get("lifecycle").and_then(Value::as_str), + Some("keep-alive") + ); + assert_eq!(server.get("keepAlive").and_then(Value::as_bool), Some(true)); + assert_eq!( + status.get("idleTimeoutMinutes").and_then(Value::as_u64), + Some(10) + ); + } + + #[tokio::test] + async fn proxy_disconnect_removes_connected_server() { + let workspace_root = unique_workspace("catdesk-external-mcp-disconnect"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + manager.connect("mock").await.expect("connect mock server"); + assert_eq!(manager.connected_server_count(), 1); + let output = manager + .proxy(&json!({"disconnect": "mock"})) + .await + .expect("disconnect through proxy"); + assert_eq!( + output.structured.get("action").and_then(Value::as_str), + Some("disconnect") + ); + assert_eq!( + output + .structured + .get("disconnected") + .and_then(Value::as_bool), + Some(true) + ); + assert_eq!(manager.connected_server_count(), 0); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn idle_reaping_disconnects_lazy_servers_and_keeps_keep_alive_servers() { + let workspace_root = unique_workspace("catdesk-external-mcp-idle"); + let server_path = write_mock_server(&workspace_root); + let mut lazy_server = mock_stdio_server(&server_path); + lazy_server.lifecycle = "lazy".to_string(); + let mut keep_server = mock_stdio_server(&server_path); + keep_server.lifecycle = "keep-alive".to_string(); + + let mut servers = HashMap::new(); + servers.insert("lazy".to_string(), lazy_server); + servers.insert("keep".to_string(), keep_server); + let mut manager = ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.clone(), + ); + + manager.connect("lazy").await.expect("connect lazy server"); + manager.connect("keep").await.expect("connect keep server"); + manager.mark_connection_idle_for_test("lazy", Duration::from_secs(11 * 60)); + manager.mark_connection_idle_for_test("keep", Duration::from_secs(11 * 60)); + let reaped = manager + .reap_idle_connections() + .expect("reap idle connections"); + + assert_eq!(reaped, vec!["lazy".to_string()]); + assert_eq!(manager.connected_server_count(), 1); + assert!( + manager + .status_payload() + .get("servers") + .and_then(Value::as_array) + .and_then(|servers| servers + .iter() + .find(|server| server.get("name").and_then(Value::as_str) == Some("keep"))) + .and_then(|server| server.get("connected")) + .and_then(Value::as_bool) + .unwrap_or(false) + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn shutdown_all_disconnects_all_servers() { + let workspace_root = unique_workspace("catdesk-external-mcp-shutdown"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + manager.connect("mock").await.expect("connect mock server"); + assert_eq!(manager.connected_server_count(), 1); + let output = manager.shutdown_all().await; + assert_eq!( + output.get("action").and_then(Value::as_str), + Some("shutdown") + ); + assert_eq!( + output.get("connectedCount").and_then(Value::as_u64), + Some(0) + ); + assert_eq!(manager.connected_server_count(), 0); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn describe_output_includes_proxy_call_example() { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock".to_string()), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "mock", + vec![json!({ + "name": "echo", + "description": "Echo a message", + "inputSchema": {"type": "object", "properties": {"message": {"type": "string"}}} + })], + ); + + let output = manager + .proxy(&json!({"describe": "mock_echo"})) + .await + .expect("describe direct tool"); + let call_example = output + .structured + .get("matches") + .and_then(Value::as_array) + .and_then(|matches| matches.first()) + .and_then(|tool| tool.get("callExample")) + .expect("missing call example"); + assert_eq!( + call_example.get("tool").and_then(Value::as_str), + Some("mock_echo") + ); + assert_eq!( + call_example.get("server").and_then(Value::as_str), + Some("mock") + ); + } + + #[tokio::test] + async fn downstream_json_rpc_error_is_reported() { + let workspace_root = unique_workspace("catdesk-external-mcp-downstream-error"); + let server_path = write_mock_server(&workspace_root); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let error = manager + .call_tool("missing", json!({}), Some("mock")) + .await + .expect_err("missing downstream tool should fail"); + assert!( + error.contains("unknown downstream MCP tool: missing") + || error.contains("downstream MCP error") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn malformed_downstream_json_response_is_reported_as_timeout() { + let workspace_root = unique_workspace("catdesk-external-mcp-malformed"); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let server_path = workspace_root.join("malformed_mcp_server.py"); + std::fs::write( + &server_path, + r#" +import sys +for line in sys.stdin: + print("not-json", flush=True) + break +"#, + ) + .expect("write malformed server"); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let error = manager + .connect("mock") + .await + .expect_err("malformed response should fail"); + assert!( + error.contains("malformed JSON") + || error.contains("timed out") + || error.contains("closed") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn downstream_process_exit_before_response_is_reported() { + let workspace_root = unique_workspace("catdesk-external-mcp-exit"); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let server_path = workspace_root.join("exit_mcp_server.py"); + std::fs::write(&server_path, "import sys\nsys.exit(0)\n").expect("write exit server"); + let mut manager = + mock_manager_with_server(mock_stdio_server(&server_path), &workspace_root); + + let error = manager + .connect("mock") + .await + .expect_err("early exit should fail"); + assert!( + error.contains("closed") + || error.contains("timed out") + || error.contains("Broken pipe") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn direct_tool_descriptors_obey_global_and_server_allowlists() { + let workspace_root = unique_workspace("catdesk-external-mcp-direct"); + let server_path = write_mock_server(&workspace_root); + let mut server = mock_stdio_server(&server_path); + server.direct_tools = Some(DirectToolsConfig::Names(vec!["echo".to_string()])); + let mut manager = mock_manager_with_server(server, &workspace_root); + + let descriptors = manager + .direct_tool_descriptors(false) + .await + .expect("direct descriptors"); + let names = descriptors + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + assert_eq!(names, vec!["mock_echo"]); + + let call = manager + .call_direct_tool("mock_echo", json!({"message":"direct"}), false) + .await + .expect("direct call result") + .expect("direct tool matched"); + assert_eq!(call.original_name, "echo"); + assert_eq!( + call.result + .get("content") + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("direct") + ); + + let no_match = manager + .call_direct_tool("mock_status", json!({}), false) + .await + .expect("direct call lookup"); + assert!(no_match.is_none()); + + let _ = std::fs::remove_dir_all(workspace_root); + } +} diff --git a/src/main.rs b/src/main.rs index e887327..9796322 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,11 +2,13 @@ mod binagotchy_gen; mod browser; mod command; mod devtools; +mod external_mcp; mod macos_terminal; mod mascot; mod mcp; mod ngrok; mod server; +mod skills; mod state; mod theme; mod workspace_tools; @@ -2396,42 +2398,113 @@ async fn start_services( } } - // Start MCP HTTP server - let devtools_bridge = if mode.browser_enabled() { + // Browser tools are injected into the generic external MCP gateway below. + let devtools_bridge = None; + if mode.browser_enabled() { + let mut app = state.lock().await; if selected_browser.is_none() { - state.lock().await.log( + app.log( "ERROR", "Browser mode requires selecting a supported Chromium browser".into(), ); - None } else { - state - .lock() - .await - .log("INFO", "Starting chrome-devtools-mcp...".into()); - match DevtoolsBridge::start(selected_browser.as_ref()).await { - Ok(bridge) => { - let mut app = state.lock().await; + app.log( + "INFO", + "Browser tools will be served through the external MCP gateway".into(), + ); + } + } + + let external_mcp = { + let workspace_root = { + let app = state.lock().await; + app.workspace_root.clone() + }; + let mut app_mcp_config = state::load_app_config() + .map(|config| config.mcp) + .unwrap_or_default(); + let browser_gateway_enabled = mode.browser_enabled() && selected_browser.is_some(); + if browser_gateway_enabled { + app_mcp_config.mcp_servers.insert( + crate::devtools::BROWSER_MCP_SERVER_NAME.to_string(), + crate::devtools::chrome_devtools_mcp_server(selected_browser.as_ref()), + ); + } + let mut manager = crate::external_mcp::ExternalMcpManager::from_workspace_and_app_config( + &workspace_root, + app_mcp_config, + ); + let configured_servers = manager.configured_server_names(); + let initial_status = manager.tui_status_snapshot(0, browser_gateway_enabled); + { + let mut app = state.lock().await; + app.external_mcp_status = initial_status; + if configured_servers.is_empty() { + app.log( + "INFO", + "External MCP gateway ready with no configured downstream servers".into(), + ); + } else { + app.log( + "INFO", + format!( + "External MCP gateway configured servers: {}", + configured_servers.join(", ") + ), + ); + } + } + Arc::new(Mutex::new(manager)) + }; + + let browser_gateway_enabled = mode.browser_enabled() && selected_browser.is_some(); + let mut external_mcp_failed_count = 0usize; + let eager_server_names = { external_mcp.lock().await.eager_server_names() }; + for server_name in eager_server_names { + let result = { + let mut manager = external_mcp.lock().await; + manager.connect(&server_name).await + }; + let mut app = state.lock().await; + match result { + Ok(_) => { + if server_name == crate::devtools::BROWSER_MCP_SERVER_NAME { app.devtools_running = true; - app.log("INFO", "chrome-devtools-mcp started".into()); - Some(bridge) } - Err(e) => { - let mut app = state.lock().await; - app.log("ERROR", format!("chrome-devtools-mcp: {e}")); - None + app.log( + "INFO", + format!("External MCP server connected: {server_name}"), + ); + } + Err(error) => { + external_mcp_failed_count = external_mcp_failed_count.saturating_add(1); + if server_name == crate::devtools::BROWSER_MCP_SERVER_NAME { + app.devtools_running = false; } + app.log( + "WARN", + format!("External MCP server {server_name} connection failed: {error}"), + ); } } - } else { - None - }; + let snapshot = { + let mut manager = external_mcp.lock().await; + manager.tui_status_snapshot(external_mcp_failed_count, browser_gateway_enabled) + }; + app.external_mcp_status = snapshot; + } let mcp_path = { let app = state.lock().await; app.mcp_path() }; - let router = server::router(state.clone(), devtools_bridge.clone(), mcp_path, ui_events); + let router = server::router( + state.clone(), + devtools_bridge.clone(), + Some(external_mcp.clone()), + mcp_path, + ui_events, + ); let listener = match tokio::net::TcpListener::bind(format!("0.0.0.0:{port}")).await { Ok(l) => l, Err(e) => { @@ -2443,8 +2516,10 @@ async fn start_services( } }; + let external_mcp_for_shutdown = external_mcp.clone(); let handle = tokio::spawn(async move { let _ = axum::serve(listener, router).await; + let _ = external_mcp_for_shutdown.lock().await.shutdown_all().await; }); { @@ -2790,6 +2865,7 @@ fn draw_ui( "N/A" } }; + let external_mcp_summary = app.external_mcp_status.render_summary(); let mcp_url: String = app.public_mcp_url().unwrap_or_else(|| "--".into()); let browser_summary = browser::format_browser_names(&app.detected_browsers); let remote_support_summary = browser::format_remote_debug_names(&app.detected_browsers); @@ -2891,7 +2967,7 @@ fn draw_ui( ), ]), Line::from(vec![ - status_label("DevTools:"), + status_label("Browser MCP:"), Span::styled( devtools_status, Style::default().fg(if app.devtools_running { @@ -2901,6 +2977,17 @@ fn draw_ui( }), ), ]), + Line::from(vec![ + status_label("MCP gateway:"), + Span::styled( + external_mcp_summary, + Style::default().fg(if app.external_mcp_status.failed_server_count == 0 { + palette.success_fg + } else { + palette.warning_fg + }), + ), + ]), Line::from(vec![ status_label("MCP Server URL:"), Span::styled( diff --git a/src/mcp.rs b/src/mcp.rs index d1d19df..6a3c5bc 100644 --- a/src/mcp.rs +++ b/src/mcp.rs @@ -10,7 +10,9 @@ use tokio::sync::Mutex; use crate::command; use crate::devtools::DevtoolsBridge; +use crate::external_mcp::{EXTERNAL_MCP_TOOL_NAME, ExternalMcpManager}; use crate::mascot; +use crate::skills; use crate::state::{ AgentsPathMode, Mode, ShowDetailMode, TokenStatsLayout, ToolMode, app_config_path, load_app_config, user_home_dir, @@ -147,6 +149,7 @@ pub async fn handle_request( tool_mode: ToolMode, set_catdesk_as_co_author: bool, devtools: &Option>>, + external_mcp: &Option>>, ) -> Option { match req.method.as_str() { "initialize" => { @@ -172,7 +175,7 @@ pub async fn handle_request( Some(handle_initialize(req)) } m if m.starts_with("notifications/") => None, - "tools/list" => Some(handle_tools_list(req, mode, tool_mode, devtools).await), + "tools/list" => Some(handle_tools_list(req, mode, tool_mode, devtools, external_mcp).await), "tools/call" => Some( handle_tools_call( req, @@ -182,6 +185,7 @@ pub async fn handle_request( tool_mode, set_catdesk_as_co_author, devtools, + external_mcp, ) .await, ), @@ -324,6 +328,7 @@ async fn handle_tools_list( mode: Mode, tool_mode: ToolMode, devtools: &Option>>, + external_mcp: &Option>>, ) -> JsonRpcResponse { let mut tools: Vec = Vec::new(); @@ -395,6 +400,8 @@ async fn handle_tools_list( "annotations": { "readOnlyHint": true, "openWorldHint": false, "destructiveHint": false } })); + tools.extend(skill_tool_descriptors()); + if tool_mode.write_tools_enabled() { tools.push(json!({ "name": "write", @@ -444,6 +451,49 @@ async fn handle_tools_list( } } + if external_mcp.is_some() && tool_mode.write_tools_enabled() { + tools.push(json!({ + "name": EXTERNAL_MCP_TOOL_NAME, + "title": "MCP gateway", + "description": "Search, describe, connect to, and call tools from configured downstream MCP servers.", + "inputSchema": { + "type": "object", + "properties": { + "tool": { "type": "string", "description": "Downstream MCP tool name to call. Use the exposed server-prefixed name when possible." }, + "args": { + "description": "Downstream tool arguments as a JSON object. JSON object strings are accepted for compatibility.", + "oneOf": [ + { "type": "object", "additionalProperties": true }, + { "type": "string" } + ] + }, + "connect": { "type": "string", "description": "Downstream MCP server name to connect and list." }, + "describe": { "type": "string", "description": "Downstream MCP tool name to describe." }, + "search": { "type": "string", "description": "Search downstream tools by name, server, or description." }, + "resource": { "type": "string", "description": "Downstream MCP resource URI to read." }, + "resources": { "type": "boolean", "description": "List downstream MCP resources." }, + "server": { "type": "string", "description": "Server filter or disambiguation." } + } + }, + "annotations": { "readOnlyHint": false, "openWorldHint": true, "destructiveHint": true } + })); + if let Some(manager) = external_mcp { + let mut manager = manager.lock().await; + if let Ok(direct_tools) = manager.direct_tool_descriptors(tool_mode.read_only()).await { + tools.extend(direct_tools); + } + } + } + + if external_mcp.is_some() && tool_mode.read_only() { + if let Some(manager) = external_mcp { + let mut manager = manager.lock().await; + if let Ok(direct_tools) = manager.direct_tool_descriptors(true).await { + tools.extend(direct_tools); + } + } + } + // Browser tools — get from devtools bridge if mode.browser_enabled() { if let Some(bridge) = devtools { @@ -474,6 +524,7 @@ async fn handle_tools_call( tool_mode: ToolMode, set_catdesk_as_co_author: bool, devtools: &Option>>, + external_mcp: &Option>>, ) -> JsonRpcResponse { let params = &req.params; let tool_name = params @@ -486,8 +537,16 @@ async fn handle_tools_call( let before_snapshot = collect_watched_snapshot(&watch_targets, workspace_root); let mut response = { + if tool_name == EXTERNAL_MCP_TOOL_NAME { + if tool_mode.read_only() { + read_only_blocked_response(req, &tool_name) + } else if let Some(manager) = external_mcp { + handle_external_mcp_proxy(req, manager).await + } else { + tool_error_response(req, format!("Unknown tool: {tool_name}")) + } // Local computer tools - if mode.computer_enabled() { + } else if mode.computer_enabled() { if tool_name == "run_command" { if tool_mode.run_command_enabled() { handle_run_command(req, workspace_root, set_catdesk_as_co_author).await @@ -507,6 +566,10 @@ async fn handle_tools_call( ), "read" => handle_read_file(req, workspace_root), "search" => handle_search_text(req, workspace_root), + "list_skills" => handle_list_skills(req, workspace_root), + "search_skills" => handle_search_skills(req, workspace_root), + "read_skill" => handle_read_skill(req, workspace_root), + "read_skill_resource" => handle_read_skill_resource(req, workspace_root), _ => { if tool_mode.write_tools_enabled() { match tool_name.as_str() { @@ -514,7 +577,16 @@ async fn handle_tools_call( "edit" => handle_edit_file(req, workspace_root), "delete" => handle_delete_path(req, workspace_root), _ => { - if mode.browser_enabled() { + if let Some(response) = try_external_mcp_direct_tool( + req, + &tool_name, + tool_mode, + external_mcp, + ) + .await + { + response + } else if mode.browser_enabled() { forward_to_devtools(req, &tool_name, tool_mode, devtools) .await } else { @@ -525,8 +597,23 @@ async fn handle_tools_call( } } } - } else if tool_mode.read_only() && is_local_destructive_tool(&tool_name) { - read_only_blocked_response(req, &tool_name) + } else if tool_mode.read_only() { + if let Some(response) = try_external_mcp_direct_tool( + req, + &tool_name, + tool_mode, + external_mcp, + ) + .await + { + response + } else if is_local_destructive_tool(&tool_name) { + read_only_blocked_response(req, &tool_name) + } else if mode.browser_enabled() { + forward_to_devtools(req, &tool_name, tool_mode, devtools).await + } else { + tool_error_response(req, format!("Unknown tool: {tool_name}")) + } } else if mode.browser_enabled() { forward_to_devtools(req, &tool_name, tool_mode, devtools).await } else { @@ -930,6 +1017,74 @@ fn tool_response( JsonRpcResponse::success(req.id.clone(), result) } +async fn handle_external_mcp_proxy( + req: &JsonRpcRequest, + manager: &Arc>, +) -> JsonRpcResponse { + let args = tool_arguments(req); + let mut manager = manager.lock().await; + match manager.proxy(&args).await { + Ok(output) => tool_success_response_with_structured(req, output.text, output.structured), + Err(error) => tool_error_response(req, error), + } +} + +async fn try_external_mcp_direct_tool( + req: &JsonRpcRequest, + tool_name: &str, + tool_mode: ToolMode, + external_mcp: &Option>>, +) -> Option { + let manager = external_mcp.as_ref()?; + let arguments = tool_arguments(req); + let mut manager = manager.lock().await; + if tool_mode.read_only() { + match manager.call_direct_tool(tool_name, arguments, true).await { + Ok(Some(call)) => { + return Some(tool_success_response_with_structured( + req, + format!( + "called {}:{} via MCP gateway", + call.server_name, call.original_name + ), + json!({ + "toolName": tool_name, + "server": call.server_name, + "downstreamTool": call.original_name, + "downstreamToolCallCount": 1, + "result": call.result, + }), + )); + } + Ok(None) => { + if manager.direct_tool_name_candidate(tool_name) { + return Some(read_only_blocked_response(req, tool_name)); + } + return None; + } + Err(error) => return Some(tool_error_response(req, error)), + } + } + match manager.call_direct_tool(tool_name, arguments, false).await { + Ok(Some(call)) => Some(tool_success_response_with_structured( + req, + format!( + "called {}:{} via MCP gateway", + call.server_name, call.original_name + ), + json!({ + "toolName": tool_name, + "server": call.server_name, + "downstreamTool": call.original_name, + "downstreamToolCallCount": 1, + "result": call.result, + }), + )), + Ok(None) => None, + Err(error) => Some(tool_error_response(req, error)), + } +} + fn tool_message_structured(req: &JsonRpcRequest, message: String, is_error: bool) -> Value { json!({ "toolName": tool_name_from_request(req), @@ -1135,6 +1290,144 @@ fn widget_path_strings(path: &Path) -> (String, String) { ) } +fn skill_tool_descriptors() -> Vec { + vec![ + json!({ + "name": "list_skills", + "title": "List skills", + "description": "List all available local CatDesk skills. Skills are directories with SKILL.md under configured skill roots.", + "inputSchema": {"type": "object", "properties": {}}, + "annotations": { "readOnlyHint": true, "openWorldHint": false, "destructiveHint": false } + }), + json!({ + "name": "search_skills", + "title": "Search skills", + "description": "Search available skills by task, domain, trigger phrase, skill name, description, and SKILL.md body.", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Task, domain, trigger phrase, or keyword to search for"} + }, + "required": ["query"] + }, + "annotations": { "readOnlyHint": true, "openWorldHint": false, "destructiveHint": false } + }), + json!({ + "name": "read_skill", + "title": "Read skill", + "description": "Read the complete SKILL.md for one local CatDesk skill.", + "inputSchema": { + "type": "object", + "properties": { + "id": {"type": "string", "description": "Skill id, usually the skill directory name"} + }, + "required": ["id"] + }, + "annotations": { "readOnlyHint": true, "openWorldHint": false, "destructiveHint": false } + }), + json!({ + "name": "read_skill_resource", + "title": "Read skill resource", + "description": "Read a template, reference file, or other text resource inside one skill directory.", + "inputSchema": { + "type": "object", + "properties": { + "skill_id": {"type": "string", "description": "Skill id, usually the skill directory name"}, + "path": {"type": "string", "description": "Resource path relative to the skill directory"} + }, + "required": ["skill_id", "path"] + }, + "annotations": { "readOnlyHint": true, "openWorldHint": false, "destructiveHint": false } + }), + ] +} + +fn handle_list_skills(req: &JsonRpcRequest, workspace_root: &str) -> JsonRpcResponse { + let workspace_root = Path::new(workspace_root); + match skills::list_skills(workspace_root) { + Ok(skill_list) => { + let structured = skills::skill_summaries_payload(&skill_list); + let text = render_skill_list(&skill_list); + tool_success_response_with_structured(req, text, structured) + } + Err(error) => tool_error_response(req, error), + } +} + +fn handle_search_skills(req: &JsonRpcRequest, workspace_root: &str) -> JsonRpcResponse { + let arguments = tool_arguments(req); + let query = match required_string_argument(&arguments, "query") { + Ok(value) => value, + Err(error) => return tool_error_response(req, error), + }; + match skills::search_skills(Path::new(workspace_root), query) { + Ok(skill_list) => { + let mut structured = skills::skill_summaries_payload(&skill_list); + if let Some(object) = structured.as_object_mut() { + object.insert("query".to_string(), json!(query)); + object.insert("toolName".to_string(), json!("search_skills")); + } + let text = render_skill_list(&skill_list); + tool_success_response_with_structured(req, text, structured) + } + Err(error) => tool_error_response(req, error), + } +} + +fn handle_read_skill(req: &JsonRpcRequest, workspace_root: &str) -> JsonRpcResponse { + let arguments = tool_arguments(req); + let skill_id = match required_string_argument(&arguments, "id") { + Ok(value) => value, + Err(error) => return tool_error_response(req, error), + }; + match skills::read_skill(Path::new(workspace_root), skill_id) { + Ok(document) => { + let structured = json!({ + "toolName": "read_skill", + "skill": skills::skill_summary_payload(&document.summary), + "content": document.content, + }); + tool_success_response_with_structured(req, document.content, structured) + } + Err(error) => tool_error_response(req, error), + } +} + +fn handle_read_skill_resource(req: &JsonRpcRequest, workspace_root: &str) -> JsonRpcResponse { + let arguments = tool_arguments(req); + let skill_id = match required_string_argument(&arguments, "skill_id") { + Ok(value) => value, + Err(error) => return tool_error_response(req, error), + }; + let path = match required_string_argument(&arguments, "path") { + Ok(value) => value, + Err(error) => return tool_error_response(req, error), + }; + match skills::read_skill_resource(Path::new(workspace_root), skill_id, path) { + Ok(resource) => { + let structured = json!({ + "toolName": "read_skill_resource", + "skillId": resource.skill_id, + "path": resource.path, + "content": resource.content, + }); + tool_success_response_with_structured(req, resource.content, structured) + } + Err(error) => tool_error_response(req, error), + } +} + +fn render_skill_list(skill_list: &[skills::SkillSummary]) -> String { + if skill_list.is_empty() { + return "No skills found.".to_string(); + } + skill_list + .iter() + .map(|skill| format!("{} - {}\n{}", skill.id, skill.name, skill.description)) + .collect::>() + .join("\n\n") +} + fn catdesk_instruction_text( workspace_root: &str, mode: Mode, @@ -1662,9 +1955,16 @@ fn current_token_stats_layout() -> TokenStatsLayout { } fn current_show_detail_mode() -> ShowDetailMode { - load_app_config() - .map(|config| config.show_detail_mode) - .unwrap_or_default() + #[cfg(test)] + { + ShowDetailMode::Expanded + } + #[cfg(not(test))] + { + crate::state::load_app_config() + .map(|config| config.show_detail_mode) + .unwrap_or_default() + } } fn attach_widget_changed_files( @@ -2521,7 +2821,10 @@ fn diff_changed_files(before: &WatchedSnapshot, after: &WatchedSnapshot) -> Vec< } fn is_local_destructive_tool(tool_name: &str) -> bool { - matches!(tool_name, "run_command" | "write" | "edit" | "delete") + matches!( + tool_name, + "run_command" | "write" | "edit" | "delete" | EXTERNAL_MCP_TOOL_NAME + ) } fn tool_is_read_only(tool: &Value) -> bool { @@ -2815,8 +3118,89 @@ fn handle_delete_path(req: &JsonRpcRequest, workspace_root: &str) -> JsonRpcResp #[cfg(test)] mod tests { use super::*; + use crate::state::{DirectToolsConfig, ExternalMcpConfig, ExternalMcpServer}; use uuid::Uuid; + fn external_mcp_for_test() -> Option>> { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock-mcp-server".to_string()), + ..ExternalMcpServer::default() + }, + ); + Some(Arc::new(Mutex::new(ExternalMcpManager::new( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + )))) + } + + fn external_mcp_with_cached_direct_tool_for_test() -> Option>> { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock-mcp-server".to_string()), + direct_tools: Some(DirectToolsConfig::Names(vec!["echo".to_string()])), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "mock", + vec![json!({ + "name": "echo", + "description": "Echo a message", + "inputSchema": { + "type": "object", + "properties": {"message": {"type": "string"}} + } + })], + ); + Some(Arc::new(Mutex::new(manager))) + } + + fn external_mcp_with_cached_read_only_direct_tool_for_test() + -> Option>> { + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("mock-mcp-server".to_string()), + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "mock", + vec![ + json!({ + "name": "safe", + "description": "Safe read-only direct tool", + "inputSchema": {"type": "object", "properties": {}}, + "annotations": {"readOnlyHint": true, "openWorldHint": false, "destructiveHint": false} + }), + json!({ + "name": "unsafe", + "description": "Unsafe direct tool", + "inputSchema": {"type": "object", "properties": {}}, + "annotations": {"readOnlyHint": false, "destructiveHint": true} + }), + ], + ); + Some(Arc::new(Mutex::new(manager))) + } + fn resources_read_request(uri: &str) -> JsonRpcRequest { JsonRpcRequest { jsonrpc: "2.0".into(), @@ -2879,7 +3263,8 @@ mod tests { params: json!({}), }; - let response = handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None).await; + let response = + handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None, &None).await; let names = response .result .as_ref() @@ -2897,6 +3282,10 @@ mod tests { "catdesk_instruction", "read", "search", + "list_skills", + "search_skills", + "read_skill", + "read_skill_resource", "write", "edit", "delete", @@ -2913,7 +3302,8 @@ mod tests { params: json!({}), }; - let response = handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None).await; + let response = + handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None, &None).await; let tools = response .result .as_ref() @@ -2950,7 +3340,41 @@ mod tests { params: json!({}), }; - let response = handle_tools_list(&req, Mode::Both, ToolMode::ReadOnly, &None).await; + let response = handle_tools_list(&req, Mode::Both, ToolMode::ReadOnly, &None, &None).await; + let names = response + .result + .as_ref() + .and_then(|result| result.get("tools")) + .and_then(Value::as_array) + .expect("missing tools") + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + + assert_eq!( + names, + vec![ + "catdesk_instruction", + "read", + "search", + "list_skills", + "search_skills", + "read_skill", + "read_skill_resource", + ] + ); + } + + #[tokio::test] + async fn tools_list_exposes_skills_tools_in_read_only_mode() { + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + id: Some(json!("req-tools-list")), + method: "tools/list".into(), + params: json!({}), + }; + + let response = handle_tools_list(&req, Mode::Both, ToolMode::ReadOnly, &None, &None).await; let names = response .result .as_ref() @@ -2961,7 +3385,511 @@ mod tests { .filter_map(|tool| tool.get("name").and_then(Value::as_str)) .collect::>(); - assert_eq!(names, vec!["catdesk_instruction", "read", "search"]); + assert!(names.contains(&"list_skills")); + assert!(names.contains(&"search_skills")); + assert!(names.contains(&"read_skill")); + assert!(names.contains(&"read_skill_resource")); + } + + #[tokio::test] + async fn list_skills_tool_returns_workspace_skills() { + let workspace_root = + std::env::temp_dir().join(format!("catdesk-mcp-list-skills-{}", Uuid::new_v4())); + let skill_root = workspace_root.join(".catdesk/skills/slides"); + std::fs::create_dir_all(&skill_root).expect("create skill root"); + std::fs::write( + skill_root.join("SKILL.md"), + "---\nname: Slides\ndescription: Create slide decks.\n---\nUse this skill for presentations.\n", + ) + .expect("write skill"); + let req = tool_call_request("list_skills", json!({})); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 1, + Mode::Both, + ToolMode::ReadOnly, + false, + &None, + &None, + ) + .await; + let skills = response + .result + .as_ref() + .and_then(|result| result.get("structuredContent")) + .and_then(|structured| structured.get("skills")) + .and_then(Value::as_array) + .expect("missing skills"); + assert!( + skills + .iter() + .any(|skill| skill.get("id").and_then(Value::as_str) == Some("slides")) + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn read_skill_resource_tool_rejects_traversal() { + let workspace_root = + std::env::temp_dir().join(format!("catdesk-mcp-skill-resource-{}", Uuid::new_v4())); + let skill_root = workspace_root.join(".catdesk/skills/slides"); + std::fs::create_dir_all(&skill_root).expect("create skill root"); + std::fs::write( + skill_root.join("SKILL.md"), + "---\nname: Slides\ndescription: Create decks.\n---\nUse this skill for slides.\n", + ) + .expect("write skill"); + let req = tool_call_request( + "read_skill_resource", + json!({"skill_id": "slides", "path": "../secret.txt"}), + ); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 1, + Mode::Both, + ToolMode::ReadOnly, + false, + &None, + &None, + ) + .await; + let is_error = response + .result + .as_ref() + .and_then(|result| result.get("isError")) + .and_then(Value::as_bool); + assert_eq!(is_error, Some(true)); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn tools_list_exposes_mcp_proxy_when_external_manager_present() { + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + id: Some(json!("req-tools-list")), + method: "tools/list".into(), + params: json!({}), + }; + let external_mcp = external_mcp_for_test(); + + let response = + handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None, &external_mcp).await; + let names = response + .result + .as_ref() + .and_then(|result| result.get("tools")) + .and_then(Value::as_array) + .expect("missing tools") + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + + assert!(names.contains(&EXTERNAL_MCP_TOOL_NAME)); + } + + #[tokio::test] + async fn browser_mode_exposes_unprefixed_gateway_tools_without_devtools_bridge() { + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + id: Some(json!("req-browser-tools-list")), + method: "tools/list".into(), + params: json!({}), + }; + let mut servers = HashMap::new(); + servers.insert( + "browser".to_string(), + ExternalMcpServer { + command: Some("npx".to_string()), + unprefixed_tools: true, + direct_tools: Some(DirectToolsConfig::Enabled(true)), + ..ExternalMcpServer::default() + }, + ); + let mut manager = ExternalMcpManager::new(ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }); + manager.set_cached_tools_for_test( + "browser", + vec![json!({ + "name": "take-screenshot", + "description": "Take a browser screenshot", + "inputSchema": {"type":"object", "properties": {}} + })], + ); + let external_mcp = Some(Arc::new(Mutex::new(manager))); + + let response = handle_tools_list( + &req, + Mode::Browser, + ToolMode::MultiTools, + &None, + &external_mcp, + ) + .await; + let names = response + .result + .as_ref() + .and_then(|result| result.get("tools")) + .and_then(Value::as_array) + .expect("missing tools") + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + + assert!(names.contains(&"take_screenshot")); + } + + #[tokio::test] + async fn tools_list_exposes_direct_downstream_tools_when_enabled() { + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + id: Some(json!("req-tools-list")), + method: "tools/list".into(), + params: json!({}), + }; + let external_mcp = external_mcp_with_cached_direct_tool_for_test(); + + let response = + handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None, &external_mcp).await; + let tools = response + .result + .as_ref() + .and_then(|result| result.get("tools")) + .and_then(Value::as_array) + .expect("missing tools"); + let direct_tool = tools + .iter() + .find(|tool| tool.get("name").and_then(Value::as_str) == Some("mock_echo")) + .expect("missing direct tool"); + assert_eq!( + direct_tool + .get("inputSchema") + .and_then(|schema| schema.get("properties")) + .and_then(|properties| properties.get("message")) + .and_then(|property| property.get("type")) + .and_then(Value::as_str), + Some("string") + ); + } + + #[tokio::test] + async fn direct_downstream_tool_call_routes_through_external_manager() { + let workspace_root = + std::env::temp_dir().join(format!("catdesk-mcp-direct-tool-{}", Uuid::new_v4())); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let server_path = workspace_root.join("mock_mcp_server.py"); + std::fs::write( + &server_path, + r#" +import json +import sys + +for line in sys.stdin: + message = json.loads(line) + request_id = message.get("id") + if request_id is None: + continue + method = message.get("method") + if method == "initialize": + result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "mock", "version": "1.0.0"}, + } + elif method == "tools/list": + result = {"tools": [{ + "name": "echo", + "description": "Echo a message", + "inputSchema": {"type": "object", "properties": {"message": {"type": "string"}}}, + }]} + elif method == "tools/call": + args = message.get("params", {}).get("arguments", {}) + result = {"content": [{"type": "text", "text": args.get("message", "")}], "isError": False} + else: + print(json.dumps({"jsonrpc": "2.0", "id": request_id, "error": {"code": -32601, "message": "unknown method"}}), flush=True) + continue + print(json.dumps({"jsonrpc": "2.0", "id": request_id, "result": result}), flush=True) +"#, + ) + .expect("write mock server"); + + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("python3".to_string()), + args: vec!["-u".to_string(), server_path.to_string_lossy().into_owned()], + direct_tools: Some(DirectToolsConfig::Names(vec!["echo".to_string()])), + ..ExternalMcpServer::default() + }, + ); + let external_mcp = Some(Arc::new(Mutex::new(ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.clone(), + )))); + let req = tool_call_request("mock_echo", json!({ "message": "hello direct" })); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 0xff, + Mode::Both, + ToolMode::MultiTools, + false, + &None, + &external_mcp, + ) + .await; + let structured = response + .result + .as_ref() + .and_then(|result| result.get("structuredContent")) + .expect("missing structured content"); + + assert_eq!( + structured.get("toolName").and_then(Value::as_str), + Some("mock_echo") + ); + assert_eq!( + structured + .get("downstreamToolCallCount") + .and_then(Value::as_u64), + Some(1) + ); + assert_eq!( + structured + .get("result") + .and_then(|result| result.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("hello direct") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn read_only_tools_list_exposes_only_annotated_downstream_direct_tools() { + let req = JsonRpcRequest { + jsonrpc: "2.0".into(), + id: Some(json!("req-tools-list")), + method: "tools/list".into(), + params: json!({}), + }; + let external_mcp = external_mcp_with_cached_read_only_direct_tool_for_test(); + + let response = + handle_tools_list(&req, Mode::Both, ToolMode::ReadOnly, &None, &external_mcp).await; + let names = response + .result + .as_ref() + .and_then(|result| result.get("tools")) + .and_then(Value::as_array) + .expect("missing tools") + .iter() + .filter_map(|tool| tool.get("name").and_then(Value::as_str)) + .collect::>(); + + assert!(names.contains(&"mock_safe")); + assert!(!names.contains(&"mock_unsafe")); + } + + #[tokio::test] + async fn read_only_direct_downstream_tool_call_routes_when_annotated_read_only() { + let workspace_root = std::env::temp_dir().join(format!( + "catdesk-mcp-read-only-direct-tool-{}", + Uuid::new_v4() + )); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let server_path = workspace_root.join("mock_mcp_server.py"); + std::fs::write( + &server_path, + r#" +import json +import sys + +for line in sys.stdin: + message = json.loads(line) + request_id = message.get("id") + if request_id is None: + continue + method = message.get("method") + if method == "initialize": + result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "mock", "version": "1.0.0"}, + } + elif method == "tools/list": + result = {"tools": [{ + "name": "status", + "description": "Read status", + "inputSchema": {"type": "object", "properties": {}}, + "annotations": {"readOnlyHint": True, "openWorldHint": False, "destructiveHint": False}, + }]} + elif method == "tools/call": + result = {"content": [{"type": "text", "text": "ok"}], "isError": False} + else: + print(json.dumps({"jsonrpc": "2.0", "id": request_id, "error": {"code": -32601, "message": "unknown method"}}), flush=True) + continue + print(json.dumps({"jsonrpc": "2.0", "id": request_id, "result": result}), flush=True) +"#, + ) + .expect("write mock server"); + + let mut servers = HashMap::new(); + servers.insert( + "mock".to_string(), + ExternalMcpServer { + command: Some("python3".to_string()), + args: vec!["-u".to_string(), server_path.to_string_lossy().into_owned()], + direct_tools: Some(DirectToolsConfig::Names(vec!["status".to_string()])), + ..ExternalMcpServer::default() + }, + ); + let external_mcp = Some(Arc::new(Mutex::new(ExternalMcpManager::with_workspace( + ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + workspace_root.clone(), + )))); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + let req = tool_call_request("mock_status", json!({})); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 0xff, + Mode::Both, + ToolMode::ReadOnly, + false, + &None, + &external_mcp, + ) + .await; + let structured = response + .result + .as_ref() + .and_then(|result| result.get("structuredContent")) + .expect("missing structured content"); + assert_eq!( + structured + .get("downstreamToolCallCount") + .and_then(Value::as_u64), + Some(1) + ); + assert_eq!( + structured + .get("result") + .and_then(|result| result.get("content")) + .and_then(Value::as_array) + .and_then(|content| content.first()) + .and_then(|entry| entry.get("text")) + .and_then(Value::as_str), + Some("ok") + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn direct_downstream_tool_call_is_blocked_in_read_only_mode() { + let external_mcp = external_mcp_with_cached_direct_tool_for_test(); + let workspace_root = + std::env::temp_dir().join(format!("catdesk-mcp-direct-read-only-{}", Uuid::new_v4())); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + let req = tool_call_request("mock_echo", json!({ "message": "blocked" })); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 0xff, + Mode::Both, + ToolMode::ReadOnly, + false, + &None, + &external_mcp, + ) + .await; + + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.get("isError")) + .and_then(Value::as_bool), + Some(true) + ); + assert_eq!( + result_text(&response), + "Tool 'mock_echo' is disabled in read-only mode" + ); + + let _ = std::fs::remove_dir_all(workspace_root); + } + + #[tokio::test] + async fn mcp_proxy_status_returns_configured_servers() { + let external_mcp = external_mcp_for_test(); + let workspace_root = + std::env::temp_dir().join(format!("catdesk-mcp-gateway-{}", Uuid::new_v4())); + std::fs::create_dir_all(&workspace_root).expect("create workspace"); + let workspace_root_str = workspace_root.to_string_lossy().into_owned(); + let req = tool_call_request(EXTERNAL_MCP_TOOL_NAME, json!({})); + + let response = handle_tools_call( + &req, + &workspace_root_str, + 0xff, + Mode::Both, + ToolMode::MultiTools, + false, + &None, + &external_mcp, + ) + .await; + let structured = response + .result + .as_ref() + .and_then(|result| result.get("structuredContent")) + .expect("missing structured content"); + + assert_eq!( + structured.get("action").and_then(Value::as_str), + Some("status") + ); + assert_eq!( + structured.get("serverCount").and_then(Value::as_u64), + Some(1) + ); + assert_eq!( + structured + .get("servers") + .and_then(Value::as_array) + .and_then(|servers| servers.first()) + .and_then(|server| server.get("name")) + .and_then(Value::as_str), + Some("mock") + ); + + let _ = std::fs::remove_dir_all(workspace_root); } #[tokio::test] @@ -2973,7 +3901,8 @@ mod tests { params: json!({}), }; - let response = handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None).await; + let response = + handle_tools_list(&req, Mode::Both, ToolMode::MultiTools, &None, &None).await; let search_tool = response .result .as_ref() @@ -3030,6 +3959,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3072,6 +4002,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3104,6 +4035,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3154,6 +4086,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3253,6 +4186,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3325,6 +4259,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3404,6 +4339,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3453,6 +4389,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3528,6 +4465,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3601,6 +4539,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3684,6 +4623,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3746,6 +4686,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3782,6 +4723,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; @@ -3817,6 +4759,7 @@ mod tests { ToolMode::MultiTools, false, &None, + &None, ) .await; diff --git a/src/server.rs b/src/server.rs index 0a05263..cf5a49f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use tokio::sync::{Mutex, mpsc::UnboundedSender}; use crate::devtools::DevtoolsBridge; +use crate::external_mcp::ExternalMcpManager; use crate::mcp::{self, JsonRpcRequest, WIDGET_PAYLOAD_META_KEY}; use crate::state::{ AgentsPathMode, FlowDirection, ServerUiEvent, SharedState, ShowDetailMode, TokenStatsLayout, @@ -26,6 +27,7 @@ const STATELESS_FLOW_LABEL: &str = "stateless"; struct ServerState { app: SharedState, devtools: Option>>, + external_mcp: Option>>, ui_events: UnboundedSender, } @@ -33,12 +35,14 @@ struct ServerState { pub fn router( app_state: SharedState, devtools: Option>>, + external_mcp: Option>>, mcp_path: String, ui_events: UnboundedSender, ) -> Router { let state = ServerState { app: app_state, devtools, + external_mcp, ui_events, }; Router::new() @@ -226,6 +230,54 @@ fn attach_history_usage(result: &mut Option, usage_totals: &UsageTotals) } } +fn external_mcp_failure_log_summary( + req: &JsonRpcRequest, + response: &mcp::JsonRpcResponse, +) -> Option { + let result = response.result.as_ref()?; + if result.get("isError").and_then(Value::as_bool) != Some(true) { + return None; + } + let tool_name = req + .params + .get("name") + .and_then(Value::as_str) + .unwrap_or_default(); + if tool_name.is_empty() { + return None; + } + let structured = result.get("structuredContent").and_then(Value::as_object); + let is_external = tool_name == crate::external_mcp::EXTERNAL_MCP_TOOL_NAME + || structured + .and_then(|content| content.get("downstreamTool")) + .is_some(); + if !is_external { + return None; + } + let action = req + .params + .get("arguments") + .and_then(Value::as_object) + .map(|args| { + [ + "tool", + "connect", + "disconnect", + "describe", + "search", + "resource", + "resources", + ] + .into_iter() + .find(|key| args.contains_key(*key)) + .unwrap_or("status") + }) + .unwrap_or("call"); + Some(format!( + "External MCP tool failure: tool={tool_name} action={action} args=" + )) +} + // ── GET / — health ────────────────────────────────────────── async fn health(State(s): State) -> Json { @@ -932,6 +984,36 @@ mod tests { ); } + #[test] + fn external_mcp_failure_log_summary_redacts_arguments() { + let req = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(json!(1)), + method: "tools/call".to_string(), + params: json!({ + "name": "mcp", + "arguments": { + "tool": "danger", + "args": "{\"secret\":\"value\"}" + } + }), + }; + let response = mcp::JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Some(json!(1)), + result: Some(json!({ + "isError": true, + "structuredContent": {"toolName": "mcp", "message": "failed"} + })), + error: None, + }; + let message = external_mcp_failure_log_summary(&req, &response).expect("log summary"); + assert!(message.contains("tool=mcp")); + assert!(message.contains("action=tool")); + assert!(message.contains("args=")); + assert!(!message.contains("secret")); + } + #[tokio::test] async fn post_mcp_accumulates_usage_from_widget_payload_meta() { let workspace_root = unique_temp_path("catdesk-post-mcp-workspace"); @@ -952,6 +1034,7 @@ mod tests { let server_state = ServerState { app: app_state.clone(), devtools: None, + external_mcp: None, ui_events: ui_tx, }; @@ -1098,6 +1181,7 @@ async fn post_mcp(State(s): State, body_bytes: Bytes) -> Response, body_bytes: Bytes) -> Response Result, String> { + let mut seen = HashSet::new(); + let mut skills = Vec::new(); + for root in skill_roots(workspace_root) { + let Ok(entries) = std::fs::read_dir(&root) else { + continue; + }; + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() || !path.join(SKILL_FILE).is_file() { + continue; + } + let Some(id) = path + .file_name() + .and_then(|name| name.to_str()) + .map(str::to_string) + else { + continue; + }; + if !seen.insert(id.clone()) { + continue; + } + skills.push(load_skill_summary(id, path)?); + } + } + skills.sort_by(|left, right| left.id.cmp(&right.id)); + Ok(skills) +} + +pub fn search_skills(workspace_root: &Path, query: &str) -> Result, String> { + let query = query.trim(); + if query.is_empty() { + return list_skills(workspace_root); + } + let query_lower = query.to_ascii_lowercase(); + let mut scored = Vec::new(); + for summary in list_skills(workspace_root)? { + let content = read_limited_text(&summary.root.join(SKILL_FILE))?.to_ascii_lowercase(); + let mut haystack = format!("{}\n{}\n{}", summary.id, summary.name, summary.description) + .to_ascii_lowercase(); + haystack.push('\n'); + haystack.push_str(&content); + if !haystack.contains(&query_lower) { + continue; + } + let score = skill_match_score(&summary, &content, &query_lower); + scored.push((score, summary)); + } + scored.sort_by(|(left_score, left), (right_score, right)| { + right_score + .cmp(left_score) + .then_with(|| left.id.cmp(&right.id)) + }); + Ok(scored.into_iter().map(|(_, summary)| summary).collect()) +} + +pub fn read_skill(workspace_root: &Path, skill_id: &str) -> Result { + let summary = find_skill(workspace_root, skill_id)?; + let content = read_limited_text(&summary.root.join(SKILL_FILE))?; + Ok(SkillDocument { summary, content }) +} + +pub fn read_skill_resource( + workspace_root: &Path, + skill_id: &str, + resource_path: &str, +) -> Result { + let summary = find_skill(workspace_root, skill_id)?; + let safe_path = safe_relative_path(resource_path)?; + if safe_path == PathBuf::from(SKILL_FILE) { + return Err( + "read_skill_resource is for resource files; use read_skill for SKILL.md".to_string(), + ); + } + let path = summary.root.join(&safe_path); + let canonical_root = summary + .root + .canonicalize() + .map_err(|error| format!("failed to resolve skill root: {error}"))?; + let canonical_path = path + .canonicalize() + .map_err(|error| format!("skill resource not found: {error}"))?; + if !canonical_path.starts_with(&canonical_root) { + return Err("skill resource path escapes the skill directory".to_string()); + } + let content = read_limited_text(&canonical_path)?; + Ok(SkillResource { + skill_id: summary.id, + path: safe_path.to_string_lossy().replace('\\', "/"), + content, + }) +} + +pub fn skill_summaries_payload(skills: &[SkillSummary]) -> Value { + json!({ + "skillCount": skills.len(), + "skills": skills.iter().map(skill_summary_payload).collect::>() + }) +} + +pub fn skill_summary_payload(summary: &SkillSummary) -> Value { + json!({ + "id": summary.id, + "name": summary.name, + "description": summary.description, + }) +} + +pub fn skill_roots(workspace_root: &Path) -> Vec { + let mut roots = Vec::new(); + if let Ok(value) = std::env::var("CATDESK_SKILLS_DIR") { + for part in std::env::split_paths(&value) { + push_existing_dir(&mut roots, part); + } + } + push_existing_dir(&mut roots, workspace_root.join(".catdesk").join("skills")); + push_existing_dir(&mut roots, workspace_root.join("skills")); + if let Some(home) = std::env::var_os("HOME") { + push_existing_dir( + &mut roots, + PathBuf::from(home).join(".catdesk").join("skills"), + ); + } + push_existing_dir(&mut roots, PathBuf::from("/home/oai/skills")); + roots +} + +fn push_existing_dir(roots: &mut Vec, path: PathBuf) { + if path.is_dir() && !roots.iter().any(|existing| existing == &path) { + roots.push(path); + } +} + +fn find_skill(workspace_root: &Path, skill_id: &str) -> Result { + list_skills(workspace_root)? + .into_iter() + .find(|skill| skill.id == skill_id) + .ok_or_else(|| format!("skill not found: {skill_id}")) +} + +fn load_skill_summary(id: String, root: PathBuf) -> Result { + let content = read_limited_text(&root.join(SKILL_FILE))?; + let frontmatter = parse_skill_frontmatter(&content) + .ok_or_else(|| format!("skill {id} is missing required YAML frontmatter"))?; + Ok(SkillSummary { + id, + name: frontmatter.name, + description: frontmatter.description, + root, + }) +} + +fn read_limited_text(path: &Path) -> Result { + let metadata = path + .metadata() + .map_err(|error| format!("failed to read metadata for {}: {error}", path.display()))?; + if metadata.len() > MAX_SKILL_TEXT_BYTES { + return Err(format!( + "skill file exceeds {} bytes: {}", + MAX_SKILL_TEXT_BYTES, + path.display() + )); + } + std::fs::read_to_string(path) + .map_err(|error| format!("failed to read {}: {error}", path.display())) +} + +#[derive(Debug, PartialEq, Eq)] +struct SkillFrontmatter { + name: String, + description: String, +} + +fn parse_skill_frontmatter(content: &str) -> Option { + let mut lines = content.lines(); + if lines.next()?.trim() != "---" { + return None; + } + let mut fields = HashMap::::new(); + let mut multiline_key: Option = None; + let mut multiline_value = String::new(); + for line in lines { + let trimmed_end = line.trim_end(); + if trimmed_end == "---" { + flush_multiline_field(&mut fields, &mut multiline_key, &mut multiline_value); + let name = fields.remove("name")?.trim().to_string(); + let description = fields.remove("description")?.trim().to_string(); + if name.is_empty() || description.is_empty() { + return None; + } + return Some(SkillFrontmatter { name, description }); + } + if let Some(key) = multiline_key.as_ref() { + if starts_unclosed_quote(&multiline_value) + || line.starts_with(' ') + || line.starts_with('\t') + || trimmed_end.is_empty() + { + if !multiline_value.is_empty() { + multiline_value.push('\n'); + } + multiline_value.push_str(trimmed_end.trim()); + continue; + } + let key = key.clone(); + fields.insert(key, unquote_yaml_scalar(multiline_value.trim())); + multiline_key = None; + multiline_value.clear(); + } + let Some((key, value)) = trimmed_end.split_once(':') else { + continue; + }; + let key = key.trim().to_string(); + let value = value.trim(); + if value == "|" || value == ">" { + multiline_key = Some(key); + multiline_value.clear(); + } else if starts_unclosed_quote(value) { + multiline_key = Some(key); + multiline_value.push_str(value); + } else { + fields.insert(key, unquote_yaml_scalar(value)); + } + } + None +} + +fn flush_multiline_field( + fields: &mut HashMap, + multiline_key: &mut Option, + multiline_value: &mut String, +) { + if let Some(key) = multiline_key.take() { + fields.insert(key, unquote_yaml_scalar(multiline_value.trim())); + multiline_value.clear(); + } +} + +fn starts_unclosed_quote(value: &str) -> bool { + let trimmed = value.trim(); + if trimmed.len() < 2 { + return false; + } + let first = trimmed.as_bytes()[0]; + if first != b'\'' && first != b'"' { + return false; + } + trimmed.as_bytes()[trimmed.len() - 1] != first +} + +fn unquote_yaml_scalar(value: &str) -> String { + let trimmed = value.trim(); + if trimmed.len() >= 2 { + let bytes = trimmed.as_bytes(); + if (bytes[0] == b'"' && bytes[trimmed.len() - 1] == b'"') + || (bytes[0] == b'\'' && bytes[trimmed.len() - 1] == b'\'') + { + return trimmed[1..trimmed.len() - 1].to_string(); + } + } + trimmed.to_string() +} + +fn skill_match_score(summary: &SkillSummary, content_lower: &str, query_lower: &str) -> i32 { + let mut score = 0; + if summary.id.eq_ignore_ascii_case(query_lower) { + score += 100; + } + if summary.name.to_ascii_lowercase().contains(query_lower) { + score += 40; + } + if summary + .description + .to_ascii_lowercase() + .contains(query_lower) + { + score += 25; + } + if content_lower.contains(query_lower) { + score += 5; + } + score +} + +fn safe_relative_path(resource_path: &str) -> Result { + let trimmed = resource_path.trim(); + if trimmed.is_empty() { + return Err("resource path is required".to_string()); + } + let path = Path::new(trimmed); + if path.is_absolute() { + return Err("resource path must be relative to the skill directory".to_string()); + } + let mut output = PathBuf::new(); + for component in path.components() { + match component { + Component::Normal(part) => output.push(part), + Component::CurDir => {} + Component::ParentDir => { + return Err("resource path cannot contain parent directory traversal".to_string()); + } + Component::RootDir | Component::Prefix(_) => { + return Err("resource path must be relative to the skill directory".to_string()); + } + } + } + if output.as_os_str().is_empty() { + return Err("resource path is required".to_string()); + } + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + fn temp_workspace() -> PathBuf { + let path = std::env::temp_dir().join(format!("catdesk-skills-{}", Uuid::new_v4())); + std::fs::create_dir_all(&path).expect("create temp workspace"); + path + } + + fn write_skill(workspace: &Path, id: &str, skill_md: &str) { + let root = workspace.join(".catdesk").join("skills").join(id); + std::fs::create_dir_all(&root).expect("create skill root"); + std::fs::write(root.join(SKILL_FILE), skill_md).expect("write skill file"); + } + + fn skill_doc(name: &str, description: &str, body: &str) -> String { + format!("---\nname: {name}\ndescription: {description}\n---\n{body}\n") + } + + #[test] + fn parse_skill_frontmatter_reads_required_name_and_description() { + let parsed = parse_skill_frontmatter( + "---\nname: Docs\ndescription: Use when drafting documents.\n---\n# Body\n", + ) + .expect("parse frontmatter"); + assert_eq!(parsed.name, "Docs"); + assert_eq!(parsed.description, "Use when drafting documents."); + } + + #[test] + fn parse_skill_frontmatter_rejects_missing_description() { + assert!(parse_skill_frontmatter("---\nname: Docs\n---\n").is_none()); + } + + #[test] + fn parse_skill_frontmatter_reads_multiline_description() { + let parsed = parse_skill_frontmatter( + "---\nname: Docs\ndescription: |\n Use when drafting documents.\n Handles reports.\n---\n", + ) + .expect("parse frontmatter"); + assert_eq!( + parsed.description, + "Use when drafting documents.\nHandles reports." + ); + } + + #[test] + fn parse_skill_frontmatter_reads_quoted_multiline_description() { + let parsed = parse_skill_frontmatter( + "---\nname: spreadsheets\ndescription: \"Create and edit spreadsheets.\nThis skill applies when workbook work is requested.\"\n---\n", + ) + .expect("parse frontmatter"); + assert_eq!( + parsed.description, + "Create and edit spreadsheets.\nThis skill applies when workbook work is requested." + ); + } + + #[test] + fn list_skills_reads_workspace_skill_dirs() { + let workspace = temp_workspace(); + write_skill( + &workspace, + "slides", + &skill_doc( + "Slides", + "Create slide decks.", + "Use this skill for presentations.", + ), + ); + let skills = list_skills(&workspace).expect("list skills"); + assert!(skills.iter().any(|skill| skill.id == "slides")); + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn search_skills_matches_description_and_body() { + let workspace = temp_workspace(); + write_skill( + &workspace, + "docs", + &skill_doc( + "Documents", + "Create polished reports.", + "Use for writing reports.", + ), + ); + let skills = search_skills(&workspace, "reports").expect("search skills"); + assert_eq!(skills.first().map(|skill| skill.id.as_str()), Some("docs")); + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn read_skill_returns_full_skill_markdown() { + let workspace = temp_workspace(); + write_skill( + &workspace, + "pdf", + &skill_doc("PDFs", "Render and inspect PDFs.", "Use this for PDFs."), + ); + let skill = read_skill(&workspace, "pdf").expect("read skill"); + assert_eq!(skill.summary.name, "PDFs"); + assert_eq!(skill.summary.description, "Render and inspect PDFs."); + assert!(skill.content.contains("Use this for PDFs.")); + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn read_skill_resource_rejects_parent_traversal() { + let workspace = temp_workspace(); + write_skill( + &workspace, + "pdf", + &skill_doc("PDFs", "Render PDFs.", "Use this for PDFs."), + ); + let error = read_skill_resource(&workspace, "pdf", "../secret.txt") + .expect_err("traversal should fail"); + assert!(error.contains("traversal")); + let _ = std::fs::remove_dir_all(workspace); + } + + #[test] + fn read_skill_resource_reads_text_resource() { + let workspace = temp_workspace(); + write_skill( + &workspace, + "pdf", + &skill_doc("PDFs", "Render PDFs.", "Use this for PDFs."), + ); + let root = workspace.join(".catdesk").join("skills").join("pdf"); + std::fs::create_dir_all(root.join("templates")).expect("create templates"); + std::fs::write(root.join("templates/basic.txt"), "template body").expect("write template"); + let resource = + read_skill_resource(&workspace, "pdf", "templates/basic.txt").expect("read resource"); + assert_eq!(resource.content, "template body"); + let _ = std::fs::remove_dir_all(workspace); + } +} diff --git a/src/state.rs b/src/state.rs index 79ab801..436f132 100644 --- a/src/state.rs +++ b/src/state.rs @@ -68,6 +68,124 @@ impl UsageTotals { } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum DirectToolsConfig { + Enabled(bool), + Names(Vec), +} + +fn default_external_mcp_tool_prefix() -> String { + "server".to_string() +} + +fn default_external_mcp_idle_timeout_minutes() -> u64 { + 10 +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalMcpSettings { + #[serde(default = "default_external_mcp_tool_prefix")] + pub tool_prefix: String, + #[serde(default)] + pub direct_tools: bool, + #[serde(default = "default_external_mcp_idle_timeout_minutes")] + pub idle_timeout: u64, +} + +impl Default for ExternalMcpSettings { + fn default() -> Self { + Self { + tool_prefix: default_external_mcp_tool_prefix(), + direct_tools: false, + idle_timeout: default_external_mcp_idle_timeout_minutes(), + } + } +} + +fn default_external_mcp_lifecycle() -> String { + "lazy".to_string() +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalMcpServer { + #[serde(default)] + pub unprefixed_tools: bool, + #[serde(default)] + pub command: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub env: HashMap, + #[serde(default)] + pub cwd: Option, + #[serde(default)] + pub url: Option, + #[serde(default)] + pub headers: HashMap, + #[serde(default = "default_external_mcp_lifecycle")] + pub lifecycle: String, + #[serde(default)] + pub direct_tools: Option, + #[serde(default)] + pub exclude_tools: Vec, +} + +impl Default for ExternalMcpServer { + fn default() -> Self { + Self { + unprefixed_tools: false, + command: None, + args: Vec::new(), + env: HashMap::new(), + cwd: None, + url: None, + headers: HashMap::new(), + lifecycle: default_external_mcp_lifecycle(), + direct_tools: None, + exclude_tools: Vec::new(), + } + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct ExternalMcpTuiStatus { + pub configured_server_count: usize, + pub connected_server_count: usize, + pub failed_server_count: usize, + pub tool_count: usize, + pub browser_gateway_enabled: bool, +} + +impl ExternalMcpTuiStatus { + pub fn render_summary(&self) -> String { + let browser = if self.browser_gateway_enabled { + ", browser gateway" + } else { + "" + }; + format!( + "{} configured, {} connected, {} failed, {} tools{}", + self.configured_server_count, + self.connected_server_count, + self.failed_server_count, + self.tool_count, + browser + ) + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalMcpConfig { + #[serde(default)] + pub settings: ExternalMcpSettings, + #[serde(default)] + pub mcp_servers: HashMap, +} + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AgentsPathMode { @@ -128,6 +246,8 @@ pub struct AppConfig { #[serde(default)] pub show_detail_mode: ShowDetailMode, #[serde(default)] + pub mcp: ExternalMcpConfig, + #[serde(default)] pub partner_binagotchy_seed: Option, #[serde(default)] pub set_catdesk_as_co_author: bool, @@ -145,6 +265,7 @@ impl Default for AppConfig { agents_path_mode: AgentsPathMode::Default, token_stats_layout: TokenStatsLayout::Right, show_detail_mode: ShowDetailMode::Expanded, + mcp: ExternalMcpConfig::default(), partner_binagotchy_seed: None, set_catdesk_as_co_author: false, theme: theme::DEFAULT_THEME_ID.to_string(), @@ -506,6 +627,7 @@ pub struct AppState { pub remote_connected: bool, pub last_remote_activity_ms: Option, pub devtools_running: bool, + pub external_mcp_status: ExternalMcpTuiStatus, pub port: u16, pub workspace_root: String, pub mascot_seed: u64, @@ -888,6 +1010,7 @@ impl AppState { remote_connected: false, last_remote_activity_ms: None, devtools_running: false, + external_mcp_status: ExternalMcpTuiStatus::default(), port, mascot_seed, partner_binagotchy_seed, @@ -1279,6 +1402,21 @@ toolCallCount = 7 let _ = std::fs::remove_dir(workspace); } + #[test] + fn external_mcp_tui_status_summary_formats_counts() { + let status = ExternalMcpTuiStatus { + configured_server_count: 3, + connected_server_count: 2, + failed_server_count: 1, + tool_count: 12, + browser_gateway_enabled: true, + }; + assert_eq!( + status.render_summary(), + "3 configured, 2 connected, 1 failed, 12 tools, browser gateway" + ); + } + #[test] fn app_config_round_trips_agents_path_mode() { let unique = SystemTime::now() @@ -1348,6 +1486,48 @@ toolCallCount = 7 let _ = std::fs::remove_dir(workspace); } + #[test] + fn app_config_round_trips_external_mcp_servers() { + let unique = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + let workspace = std::env::temp_dir().join(format!("catdesk-config-external-mcp-{unique}")); + std::fs::create_dir_all(&workspace).expect("create temp config dir"); + let config_path = workspace.join(APP_CONFIG_FILE_NAME); + + let mut servers = HashMap::new(); + servers.insert( + "serena".to_string(), + ExternalMcpServer { + command: Some("serena-mcp-server".to_string()), + args: vec!["--project".to_string(), ".".to_string()], + lifecycle: "lazy".to_string(), + ..ExternalMcpServer::default() + }, + ); + let config = AppConfig { + mcp: ExternalMcpConfig { + mcp_servers: servers, + ..ExternalMcpConfig::default() + }, + ..AppConfig::default() + }; + config.save_to_path(&config_path).expect("save config"); + + let saved = AppConfig::load_from_path(&config_path).expect("load config"); + let server = saved + .mcp + .mcp_servers + .get("serena") + .expect("missing serena server"); + assert_eq!(server.command.as_deref(), Some("serena-mcp-server")); + assert_eq!(server.args, vec!["--project".to_string(), ".".to_string()]); + + let _ = std::fs::remove_file(config_path); + let _ = std::fs::remove_dir(workspace); + } + #[test] fn app_state_loads_partner_binagotchy_seed() { let unique = SystemTime::now()