diff --git a/src/tools/exec.rs b/src/tools/exec.rs index dd451d1..a92b823 100644 --- a/src/tools/exec.rs +++ b/src/tools/exec.rs @@ -923,10 +923,12 @@ mod tests { assert!(bg_output.unwrap().contains("background")); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_handler_spawn_doesnt_block() { // This test verifies that spawning handlers doesn't block - multiple tools - // can have their handlers running concurrently + // can have their handlers running concurrently. + // Uses multi_thread runtime because shell handlers use spawn_blocking, + // which needs worker threads to poll JoinHandle completion promptly. let mut registry = ToolRegistry::empty(); registry.register(std::sync::Arc::new(ShellTool::new())); let mut executor = ToolExecutor::new(registry); @@ -937,7 +939,7 @@ mod tests { agent_id: 0, call_id: "slow1".to_string(), name: "mcp_shell".to_string(), - params: serde_json::json!({ "command": "sleep 0.1 && echo slow1_done" }), + params: serde_json::json!({ "command": "sleep 0.5 && echo slow1_done" }), decision: ToolDecision::Approve, background: true, }, @@ -945,7 +947,7 @@ mod tests { agent_id: 0, call_id: "slow2".to_string(), name: "mcp_shell".to_string(), - params: serde_json::json!({ "command": "sleep 0.1 && echo slow2_done" }), + params: serde_json::json!({ "command": "sleep 0.5 && echo slow2_done" }), decision: ToolDecision::Approve, background: true, }, @@ -954,17 +956,16 @@ mod tests { let start = std::time::Instant::now(); let events = collect_events(&mut executor).await; let elapsed = start.elapsed(); - - // If running concurrently, should take ~0.1s. If sequential, ~0.2s - // Allow some margin for test flakiness - assert!(elapsed.as_millis() < 180, "Tools should run concurrently, took {:?}", elapsed); - + + // If running concurrently, should take ~0.5s. 
If sequential, ~1.0s+ + assert!(elapsed.as_millis() < 800, "Tools should run concurrently, took {:?}", elapsed); + // Both should complete let completed: HashSet<_> = events.iter().filter_map(|e| { if let ToolEvent::BackgroundCompleted { call_id, .. } = e { Some(call_id.clone()) } else { None } }).collect(); assert_eq!(completed.len(), 2); - + // Both outputs should be present assert!(executor.get_background_output("slow1").unwrap().contains("slow1_done")); assert!(executor.get_background_output("slow2").unwrap().contains("slow2_done")); diff --git a/src/tools/handlers.rs b/src/tools/handlers.rs index 9e043b8..4f6c2d1 100644 --- a/src/tools/handlers.rs +++ b/src/tools/handlers.rs @@ -213,7 +213,10 @@ impl EffectHandler for WriteFile { // Shell handler // ============================================================================= -/// Execute a shell command +/// Execute a shell command. +/// +/// Runs on tokio's blocking thread pool via `spawn_blocking` so that shell +/// I/O doesn't occupy async worker threads. 
pub struct Shell { pub command: String, pub working_dir: Option<String>, @@ -223,11 +226,14 @@ pub struct Shell { #[async_trait::async_trait] impl EffectHandler for Shell { async fn call(self: Box<Self>) -> Step { - match io::execute_shell(&self.command, self.working_dir.as_deref(), self.timeout_secs).await + match tokio::task::spawn_blocking(move || { + io::execute_shell(&self.command, self.working_dir.as_deref(), self.timeout_secs) + }) + .await { - Ok(result) if result.success => Step::Output(result.output), - Ok(result) => Step::Output(result.output), // Still output, but includes exit code - Err(e) => Step::Error(e), + Ok(Ok(result)) => Step::Output(result.output), + Ok(Err(e)) => Step::Error(e), + Err(e) => Step::Error(format!("Shell task failed: {}", e)), } } } diff --git a/src/tools/io.rs b/src/tools/io.rs index 3978977..774a174 100644 --- a/src/tools/io.rs +++ b/src/tools/io.rs @@ -6,31 +6,36 @@ use std::fs; use std::os::unix::process::CommandExt; use std::path::Path; -use std::process::Stdio; +use std::process::{Command, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::Command; - -/// Wrapper that kills the entire process group on drop. -/// -/// `tokio::process::Child` does NOT kill on drop — it orphans the process. -/// We spawn bash with `setpgid(0, 0)` so it becomes a process group leader. +/// Guard that kills an entire process group on drop. +/// +/// Spawned bash processes use `setpgid(0, 0)` to become process group leaders. /// On drop, we send SIGKILL to the negative PID (the entire group), killing -/// bash and all its children (sleep, find, etc.). -struct KillOnDrop(tokio::process::Child); +/// bash and all its children. Call `disarm()` after the process exits normally +/// to prevent killing an already-exited group.
+struct ProcessGroupGuard(Option<u32>); -impl Drop for KillOnDrop { +impl ProcessGroupGuard { + fn new(pid: u32) -> Self { + Self(Some(pid)) + } + + /// Disarm the guard so it won't kill the process group on drop. + fn disarm(&mut self) { + self.0 = None; + } +} + +impl Drop for ProcessGroupGuard { fn drop(&mut self) { - if let Some(pid) = self.0.id() { - // Kill the entire process group (negative PID = group kill). - // This is sync and safe to call in Drop. + if let Some(pid) = self.0 { unsafe { let pgid = -(pid as i32); libc_kill(pgid, 9); // SIGKILL = 9 } - } else { - // Process already exited or no PID available, try direct kill. - let _ = self.0.start_kill(); } } } @@ -143,8 +148,12 @@ pub fn read_file( Ok(output) } -/// Execute a shell command -pub async fn execute_shell( +/// Execute a shell command (blocking). +/// +/// Blocks the calling thread (no async runtime); a helper thread enforces the timeout. +/// Uses `std::process::Command` and synchronous I/O so this can be called +/// from `spawn_blocking` to keep shell work off the tokio worker threads. +pub fn execute_shell( command: &str, working_dir: Option<&str>, timeout_secs: u64, @@ -153,7 +162,7 @@ pub async fn execute_shell( cmd.arg("-c").arg(command); cmd.stdout(Stdio::piped()); cmd.stderr(Stdio::piped()); - // Spawn in its own process group so KillOnDrop can kill all children. + // Spawn in its own process group so the guard can kill all children. unsafe { cmd.pre_exec(|| { if setpgid(0, 0) != 0 { @@ -174,50 +183,49 @@ pub async fn execute_shell( cmd.current_dir(dir); } - let mut child = KillOnDrop(cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?); - - let stdout = child.0.stdout.take(); - let stderr = child.0.stderr.take(); + let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; + let pid = child.id(); + let mut guard = ProcessGroupGuard::new(pid); + + // Spawn a thread that kills the process group after the timeout.
+ let timed_out = Arc::new(AtomicBool::new(false)); + let timed_out_flag = timed_out.clone(); + let (cancel_tx, cancel_rx) = std::sync::mpsc::channel::<()>(); + + std::thread::spawn(move || { + if cancel_rx + .recv_timeout(std::time::Duration::from_secs(timeout_secs)) + .is_err() + { + timed_out_flag.store(true, Ordering::SeqCst); + unsafe { + libc_kill(-(pid as i32), 9); + } + } + }); - let mut collected = String::new(); + // wait_with_output reads stdout and stderr in parallel, avoiding the + // deadlock that can occur when reading them sequentially with piped I/O. + let result = child + .wait_with_output() + .map_err(|e| format!("Wait failed: {}", e))?; - if let Some(stdout) = stdout { - let mut reader = BufReader::new(stdout).lines(); - while let Ok(Some(line)) = reader.next_line().await { - collected.push_str(&line); - collected.push('\n'); - } - } + // Process exited — cancel the timeout thread and disarm the guard. + let _ = cancel_tx.send(()); + guard.disarm(); - let mut stderr_output = String::new(); - if let Some(stderr) = stderr { - let mut reader = BufReader::new(stderr).lines(); - while let Ok(Some(line)) = reader.next_line().await { - stderr_output.push_str(&line); - stderr_output.push('\n'); - } + if timed_out.load(Ordering::SeqCst) { + return Err(format!( + "Command timed out after {} seconds", + timeout_secs + )); } - let status = match tokio::time::timeout( - std::time::Duration::from_secs(timeout_secs), - child.0.wait(), - ) - .await - { - Ok(Ok(status)) => status, - Ok(Err(e)) => return Err(format!("Wait failed: {}", e)), - Err(_) => { - // KillOnDrop will handle cleanup, but we can be explicit here too. 
- let _ = child.0.start_kill(); - return Err(format!( - "Command timed out after {} seconds", - timeout_secs - )); - } - }; + let exit_code = result.status.code().unwrap_or(-1); + let collected = String::from_utf8_lossy(&result.stdout); + let stderr_output = String::from_utf8_lossy(&result.stderr); - let exit_code = status.code().unwrap_or(-1); - let mut output = collected; + let mut output = collected.into_owned(); if !stderr_output.is_empty() { if !output.is_empty() { @@ -253,7 +261,7 @@ pub async fn execute_shell( Ok(ShellResult { output, exit_code, - success: status.success(), + success: result.status.success(), }) }