21 changes: 11 additions & 10 deletions src/tools/exec.rs
@@ -923,10 +923,12 @@ mod tests {
assert!(bg_output.unwrap().contains("background"));
}

#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_handler_spawn_doesnt_block() {
// This test verifies that spawning handlers doesn't block - multiple tools
// can have their handlers running concurrently
// can have their handlers running concurrently.
// Uses multi_thread runtime because shell handlers use spawn_blocking,
// which needs worker threads to poll JoinHandle completion promptly.
let mut registry = ToolRegistry::empty();
registry.register(std::sync::Arc::new(ShellTool::new()));
let mut executor = ToolExecutor::new(registry);
@@ -937,15 +939,15 @@
agent_id: 0,
call_id: "slow1".to_string(),
name: "mcp_shell".to_string(),
params: serde_json::json!({ "command": "sleep 0.1 && echo slow1_done" }),
params: serde_json::json!({ "command": "sleep 0.5 && echo slow1_done" }),
decision: ToolDecision::Approve,
background: true,
},
ToolCall {
agent_id: 0,
call_id: "slow2".to_string(),
name: "mcp_shell".to_string(),
params: serde_json::json!({ "command": "sleep 0.1 && echo slow2_done" }),
params: serde_json::json!({ "command": "sleep 0.5 && echo slow2_done" }),
decision: ToolDecision::Approve,
background: true,
},
@@ -954,17 +956,16 @@
let start = std::time::Instant::now();
let events = collect_events(&mut executor).await;
let elapsed = start.elapsed();

// If running concurrently, should take ~0.1s. If sequential, ~0.2s
// Allow some margin for test flakiness
assert!(elapsed.as_millis() < 180, "Tools should run concurrently, took {:?}", elapsed);


// If running concurrently, should take ~0.5s. If sequential, ~1.0s+
assert!(elapsed.as_millis() < 800, "Tools should run concurrently, took {:?}", elapsed);

// Both should complete
let completed: HashSet<_> = events.iter().filter_map(|e| {
if let ToolEvent::BackgroundCompleted { call_id, .. } = e { Some(call_id.clone()) } else { None }
}).collect();
assert_eq!(completed.len(), 2);

// Both outputs should be present
assert!(executor.get_background_output("slow1").unwrap().contains("slow1_done"));
assert!(executor.get_background_output("slow2").unwrap().contains("slow2_done"));
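For context on the timing assertion above, here is a minimal, self-contained sketch of the same idea, assuming only tokio with the `rt-multi-thread` and `macros` features. It does not use `ToolExecutor` or `collect_events`; it just shows two `spawn_blocking` sleeps finishing in roughly the duration of one rather than the sum of both.

```rust
// Illustrative sketch, not part of the PR: two blocking sleeps driven
// concurrently on a multi-thread runtime complete in ~max, not ~sum.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn two_blocking_sleeps_run_concurrently() {
    use std::time::{Duration, Instant};

    let start = Instant::now();
    let (a, b) = tokio::join!(
        tokio::task::spawn_blocking(|| std::thread::sleep(Duration::from_millis(500))),
        tokio::task::spawn_blocking(|| std::thread::sleep(Duration::from_millis(500))),
    );
    a.unwrap();
    b.unwrap();

    // Concurrent: ~0.5s. Sequential would be ~1.0s or more.
    assert!(start.elapsed() < Duration::from_millis(800), "took {:?}", start.elapsed());
}
```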
16 changes: 11 additions & 5 deletions src/tools/handlers.rs
@@ -213,7 +213,10 @@ impl EffectHandler for WriteFile {
// Shell handler
// =============================================================================

/// Execute a shell command
/// Execute a shell command.
///
/// Runs on tokio's blocking thread pool via `spawn_blocking` so that shell
/// I/O doesn't occupy async worker threads.
pub struct Shell {
pub command: String,
pub working_dir: Option<String>,
@@ -223,11 +226,14 @@ pub struct Shell {
#[async_trait::async_trait]
impl EffectHandler for Shell {
async fn call(self: Box<Self>) -> Step {
match io::execute_shell(&self.command, self.working_dir.as_deref(), self.timeout_secs).await
match tokio::task::spawn_blocking(move || {
io::execute_shell(&self.command, self.working_dir.as_deref(), self.timeout_secs)
})
.await
{
Ok(result) if result.success => Step::Output(result.output),
Ok(result) => Step::Output(result.output), // Still output, but includes exit code
Err(e) => Step::Error(e),
Ok(Ok(result)) => Step::Output(result.output),
Ok(Err(e)) => Step::Error(e),
Err(e) => Step::Error(format!("Shell task failed: {}", e)),
}
}
}
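The `Shell` handler change above is an instance of a standard pattern: run a synchronous function on tokio's blocking pool and unpack the doubly nested `Result`, where the outer layer is the `JoinError` from the task itself. A minimal sketch under that assumption; `run_blocking` and its error type are placeholders, not names from this codebase.

```rust
use tokio::task;

// Placeholder for any synchronous, potentially slow operation.
fn run_blocking(cmd: String) -> Result<String, String> {
    Ok(format!("ran: {}", cmd))
}

async fn call_shell(cmd: String) -> Result<String, String> {
    match task::spawn_blocking(move || run_blocking(cmd)).await {
        // The task ran and the blocking function succeeded.
        Ok(Ok(output)) => Ok(output),
        // The task ran but the blocking function returned an error.
        Ok(Err(e)) => Err(e),
        // The task itself panicked or was cancelled (JoinError).
        Err(e) => Err(format!("task failed: {}", e)),
    }
}
```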
126 changes: 67 additions & 59 deletions src/tools/io.rs
@@ -6,31 +6,36 @@
use std::fs;
use std::os::unix::process::CommandExt;
use std::path::Path;
use std::process::Stdio;
use std::process::{Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;

/// Wrapper that kills the entire process group on drop.
///
/// `tokio::process::Child` does NOT kill on drop — it orphans the process.
/// We spawn bash with `setpgid(0, 0)` so it becomes a process group leader.
/// Guard that kills an entire process group on drop.
///
/// Spawned bash processes use `setpgid(0, 0)` to become process group leaders.
/// On drop, we send SIGKILL to the negative PID (the entire group), killing
/// bash and all its children (sleep, find, etc.).
struct KillOnDrop(tokio::process::Child);
/// bash and all its children. Call `disarm()` after the process exits normally
/// to prevent killing an already-exited group.
struct ProcessGroupGuard(Option<u32>);

impl Drop for KillOnDrop {
impl ProcessGroupGuard {
fn new(pid: u32) -> Self {
Self(Some(pid))
}

/// Disarm the guard so it won't kill the process group on drop.
fn disarm(&mut self) {
self.0 = None;
}
}

impl Drop for ProcessGroupGuard {
fn drop(&mut self) {
if let Some(pid) = self.0.id() {
// Kill the entire process group (negative PID = group kill).
// This is sync and safe to call in Drop.
if let Some(pid) = self.0 {
unsafe {
let pgid = -(pid as i32);
libc_kill(pgid, 9); // SIGKILL = 9
}
} else {
// Process already exited or no PID available, try direct kill.
let _ = self.0.start_kill();
}
}
}
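Because the diff interleaves the removed `KillOnDrop` lines with the new guard, here is the new guard consolidated in one place as a sketch reconstructed from the added lines above; `libc::kill` from the `libc` crate stands in for the file's own `libc_kill` wrapper.

```rust
// Sketch reconstructed from the added lines; assumes the `libc` crate.
struct ProcessGroupGuard(Option<u32>);

impl ProcessGroupGuard {
    fn new(pid: u32) -> Self {
        Self(Some(pid))
    }

    /// Disarm after the child exits normally so drop becomes a no-op.
    fn disarm(&mut self) {
        self.0 = None;
    }
}

impl Drop for ProcessGroupGuard {
    fn drop(&mut self) {
        if let Some(pid) = self.0 {
            // A negative PID signals the whole process group led by `pid`.
            unsafe {
                libc::kill(-(pid as i32), libc::SIGKILL);
            }
        }
    }
}
```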
@@ -143,8 +148,12 @@ pub fn read_file(
Ok(output)
}

/// Execute a shell command
pub async fn execute_shell(
/// Execute a shell command (blocking).
///
/// Runs entirely on the calling thread with no async runtime involvement.
/// Uses `std::process::Command` and synchronous I/O so this can be called
/// from `spawn_blocking` to keep shell work off the tokio worker threads.
pub fn execute_shell(
command: &str,
working_dir: Option<&str>,
timeout_secs: u64,
@@ -153,7 +162,7 @@
cmd.arg("-c").arg(command);
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
// Spawn in its own process group so KillOnDrop can kill all children.
// Spawn in its own process group so the guard can kill all children.
unsafe {
cmd.pre_exec(|| {
if setpgid(0, 0) != 0 {
@@ -174,50 +183,49 @@
cmd.current_dir(dir);
}

let mut child = KillOnDrop(cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?);

let stdout = child.0.stdout.take();
let stderr = child.0.stderr.take();
let child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let pid = child.id();
let mut guard = ProcessGroupGuard::new(pid);

// Spawn a thread that kills the process group after the timeout.
let timed_out = Arc::new(AtomicBool::new(false));
let timed_out_flag = timed_out.clone();
let (cancel_tx, cancel_rx) = std::sync::mpsc::channel::<()>();

std::thread::spawn(move || {
if cancel_rx
.recv_timeout(std::time::Duration::from_secs(timeout_secs))
.is_err()
{
timed_out_flag.store(true, Ordering::SeqCst);
unsafe {
libc_kill(-(pid as i32), 9);
}
}
});

let mut collected = String::new();
// wait_with_output reads stdout and stderr in parallel, avoiding the
// deadlock that can occur when reading them sequentially with piped I/O.
let result = child
.wait_with_output()
.map_err(|e| format!("Wait failed: {}", e))?;

if let Some(stdout) = stdout {
let mut reader = BufReader::new(stdout).lines();
while let Ok(Some(line)) = reader.next_line().await {
collected.push_str(&line);
collected.push('\n');
}
}
// Process exited — cancel the timeout thread and disarm the guard.
let _ = cancel_tx.send(());
guard.disarm();

let mut stderr_output = String::new();
if let Some(stderr) = stderr {
let mut reader = BufReader::new(stderr).lines();
while let Ok(Some(line)) = reader.next_line().await {
stderr_output.push_str(&line);
stderr_output.push('\n');
}
if timed_out.load(Ordering::SeqCst) {
return Err(format!(
"Command timed out after {} seconds",
timeout_secs
));
}

let status = match tokio::time::timeout(
std::time::Duration::from_secs(timeout_secs),
child.0.wait(),
)
.await
{
Ok(Ok(status)) => status,
Ok(Err(e)) => return Err(format!("Wait failed: {}", e)),
Err(_) => {
// KillOnDrop will handle cleanup, but we can be explicit here too.
let _ = child.0.start_kill();
return Err(format!(
"Command timed out after {} seconds",
timeout_secs
));
}
};
let exit_code = result.status.code().unwrap_or(-1);
let collected = String::from_utf8_lossy(&result.stdout);
let stderr_output = String::from_utf8_lossy(&result.stderr);

let exit_code = status.code().unwrap_or(-1);
let mut output = collected;
let mut output = collected.into_owned();

if !stderr_output.is_empty() {
if !output.is_empty() {
Expand Down Expand Up @@ -253,7 +261,7 @@ pub async fn execute_shell(
Ok(ShellResult {
output,
exit_code,
success: status.success(),
success: result.status.success(),
})
}

Expand Down