diff --git a/src/agent/runloop/unified/state.rs b/src/agent/runloop/unified/state.rs index b9ce1879e..aee90dfd5 100644 --- a/src/agent/runloop/unified/state.rs +++ b/src/agent/runloop/unified/state.rs @@ -52,11 +52,11 @@ impl CtrlCState { } pub(crate) fn is_cancel_requested(&self) -> bool { - self.cancel_requested.load(Ordering::Relaxed) + self.cancel_requested.load(Ordering::Acquire) } pub(crate) fn is_exit_requested(&self) -> bool { - self.exit_requested.load(Ordering::Relaxed) + self.exit_requested.load(Ordering::Acquire) } pub(crate) fn disarm_exit(&self) { diff --git a/src/agent/runloop/unified/turn/run_loop.rs b/src/agent/runloop/unified/turn/run_loop.rs index 56fdc3efb..866af0848 100644 --- a/src/agent/runloop/unified/turn/run_loop.rs +++ b/src/agent/runloop/unified/turn/run_loop.rs @@ -1289,12 +1289,10 @@ pub(crate) async fn run_single_agent_loop_unified( let _updated_snapshot = { let mut guard = tools.write().await; guard.retain(|tool| { - !tool - .function + tool.function .as_ref() - .unwrap() - .name - .starts_with("mcp_") + .map(|f| !f.name.starts_with("mcp_")) + .unwrap_or(true) }); guard.extend(new_definitions); guard.clone() @@ -1368,12 +1366,10 @@ pub(crate) async fn run_single_agent_loop_unified( let _updated_snapshot = { let mut guard = tools.write().await; guard.retain(|tool| { - !tool - .function + tool.function .as_ref() - .unwrap() - .name - .starts_with("mcp_") + .map(|f| !f.name.starts_with("mcp_")) + .unwrap_or(true) }); guard.extend(new_definitions); guard.clone() @@ -2448,15 +2444,26 @@ pub(crate) async fn run_single_agent_loop_unified( // This prevents the loop from breaking after tool execution let _ = final_text.take(); for call in &tool_calls { - let name = call - .function - .as_ref() - .expect("Tool call must have function") - .name - .as_str(); - let args_val = call - .parsed_arguments() - .unwrap_or_else(|_| serde_json::json!({})); + let Some(function) = call.function.as_ref() else { + tracing::warn!("Malformed tool call: missing function definition"); + working_history.push(uni::Message::system( + "Skipped malformed tool call: missing function definition".to_string(), + )); + continue; + }; + let name = function.name.as_str(); + let args_val = match call.parsed_arguments() { + Ok(args) => args, + Err(err) => { + tracing::warn!("Failed to parse args for '{}': {}", name, err); + let error_msg = format!( + "Tool '{}' received invalid arguments: {}", + name, err + ); + working_history.push(uni::Message::system(error_msg)); + continue; + } + }; // Normalize args for loop detection: strip pagination params and normalize paths let normalized_args = if let Some(obj) = args_val.as_object() { diff --git a/vtcode-core/src/llm/providers/anthropic.rs b/vtcode-core/src/llm/providers/anthropic.rs index 5a8044d69..e076934e8 100644 --- a/vtcode-core/src/llm/providers/anthropic.rs +++ b/vtcode-core/src/llm/providers/anthropic.rs @@ -1002,6 +1002,14 @@ impl LLMProvider for AnthropicProvider { } async fn generate(&self, request: LLMRequest) -> Result { + // Validate API key before making request + if self.api_key.trim().is_empty() { + return Err(LLMError::Authentication { + message: "Anthropic API key is not configured. Set ANTHROPIC_API_KEY environment variable.".to_string(), + metadata: None, + }); + } + let anthropic_request = self.convert_to_anthropic_format(&request)?; let url = format!("{}/messages", self.base_url); diff --git a/vtcode-core/src/llm/providers/gemini.rs b/vtcode-core/src/llm/providers/gemini.rs index dd3fd9d4f..db86c5e0d 100644 --- a/vtcode-core/src/llm/providers/gemini.rs +++ b/vtcode-core/src/llm/providers/gemini.rs @@ -227,16 +227,25 @@ impl LLMProvider for GeminiProvider { } async fn generate(&self, request: LLMRequest) -> Result { + // Validate API key before making request + if self.api_key.trim().is_empty() { + return Err(LLMError::Authentication { + message: "Gemini API key is not configured. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.".to_string(), + metadata: None, + }); + } + let gemini_request = self.convert_to_gemini_request(&request)?; let url = format!( - "{}/models/{}:generateContent?key={}", - self.base_url, request.model, self.api_key + "{}/models/{}:generateContent", + self.base_url, request.model ); let response = self .http_client .post(&url) + .header("x-goog-api-key", self.api_key.as_ref()) .json(&gemini_request) .send() .await @@ -257,16 +266,25 @@ impl LLMProvider for GeminiProvider { } async fn stream(&self, request: LLMRequest) -> Result { + // Validate API key before making request + if self.api_key.trim().is_empty() { + return Err(LLMError::Authentication { + message: "Gemini API key is not configured. Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable.".to_string(), + metadata: None, + }); + } + let gemini_request = self.convert_to_gemini_request(&request)?; let url = format!( - "{}/models/{}:streamGenerateContent?key={}", - self.base_url, request.model, self.api_key + "{}/models/{}:streamGenerateContent", + self.base_url, request.model ); let response = self .http_client .post(&url) + .header("x-goog-api-key", self.api_key.as_ref()) .json(&gemini_request) .send() .await diff --git a/vtcode-tools/src/acp_tool.rs b/vtcode-tools/src/acp_tool.rs index ce154a5d3..06be55e33 100644 --- a/vtcode-tools/src/acp_tool.rs +++ b/vtcode-tools/src/acp_tool.rs @@ -19,6 +19,14 @@ mod shared { const ERR_ARGS_OBJECT: &str = "Arguments must be an object"; const ERR_CLIENT_UNINITIALIZED: &str = "ACP client not initialized"; + /// Maximum allowed length for agent IDs to prevent DoS via oversized strings. + const MAX_AGENT_ID_LEN: usize = 256; + /// Maximum allowed length for action names. + const MAX_ACTION_LEN: usize = 128; + /// Maximum JSON depth for call_args to prevent stack overflow. + const MAX_JSON_DEPTH: usize = 32; + /// Maximum size for call_args payload in bytes. + const MAX_ARGS_SIZE: usize = 1024 * 1024; // 1MB pub fn extract_args_object(args: &Value) -> anyhow::Result<&serde_json::Map> { args.as_object() @@ -54,6 +62,78 @@ mod shared { } Ok(()) } + + /// Validate agent ID format: alphanumeric, hyphens, underscores only, length limit. + pub fn validate_agent_id(agent_id: &str) -> anyhow::Result<()> { + if agent_id.is_empty() { + return Err(anyhow::anyhow!("agent_id cannot be empty")); + } + if agent_id.len() > MAX_AGENT_ID_LEN { + return Err(anyhow::anyhow!( + "agent_id exceeds maximum length of {} characters", + MAX_AGENT_ID_LEN + )); + } + if !agent_id + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.') + { + return Err(anyhow::anyhow!( + "agent_id contains invalid characters (allowed: alphanumeric, hyphen, underscore, dot)" + )); + } + Ok(()) + } + + /// Validate action name format. + pub fn validate_action(action: &str) -> anyhow::Result<()> { + if action.is_empty() { + return Err(anyhow::anyhow!("action cannot be empty")); + } + if action.len() > MAX_ACTION_LEN { + return Err(anyhow::anyhow!( + "action exceeds maximum length of {} characters", + MAX_ACTION_LEN + )); + } + if !action + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.') + { + return Err(anyhow::anyhow!( + "action contains invalid characters" + )); + } + Ok(()) + } + + /// Validate call_args size and depth. + pub fn validate_call_args(args: &Value) -> anyhow::Result<()> { + let serialized = serde_json::to_string(args) + .map_err(|e| anyhow::anyhow!("Failed to serialize args: {}", e))?; + if serialized.len() > MAX_ARGS_SIZE { + return Err(anyhow::anyhow!( + "call_args exceeds maximum size of {} bytes", + MAX_ARGS_SIZE + )); + } + if json_depth(args) > MAX_JSON_DEPTH { + return Err(anyhow::anyhow!( + "call_args exceeds maximum nesting depth of {}", + MAX_JSON_DEPTH + )); + } + Ok(()) + } + + /// Calculate JSON nesting depth. + fn json_depth(value: &Value) -> usize { + match value { + Value::Array(arr) => 1 + arr.iter().map(json_depth).max().unwrap_or(0), + Value::Object(obj) => 1 + obj.values().map(json_depth).max().unwrap_or(0), + _ => 0, + } + } } /// ACP Inter-Agent Communication Tool @@ -106,6 +186,23 @@ impl Tool for AcpTool { let obj = shared::extract_args_object(args)?; shared::validate_field_exists(obj, "action")?; shared::validate_field_exists(obj, "remote_agent_id")?; + // Validate formats + if let Some(action) = obj.get("action").and_then(|v| v.as_str()) { + shared::validate_action(action)?; + } + if let Some(agent_id) = obj.get("remote_agent_id").and_then(|v| v.as_str()) { + shared::validate_agent_id(agent_id)?; + } + // Validate method if provided + if let Some(method) = obj.get("method").and_then(|v| v.as_str()) { + if method != "sync" && method != "async" { + return Err(anyhow::anyhow!("Invalid method '{}': must be 'sync' or 'async'", method)); + } + } + // Validate call_args if provided + if let Some(call_args) = obj.get("args") { + shared::validate_call_args(call_args)?; + } Ok(()) } @@ -116,7 +213,12 @@ impl Tool for AcpTool { let remote_agent_id = shared::get_required_field(obj, "remote_agent_id", None)?; let method = obj.get("method").and_then(|v| v.as_str()).unwrap_or("sync"); + // Validate inputs before use + shared::validate_action(action)?; + shared::validate_agent_id(remote_agent_id)?; + let call_args = obj.get("args").cloned().unwrap_or(json!({})); + shared::validate_call_args(&call_args)?; let client = self.client.read().await; let client = shared::check_client_initialized(&*client)?; diff --git a/vtcode-tools/src/executor.rs b/vtcode-tools/src/executor.rs index 0a4e6814c..838a2ebe2 100644 --- a/vtcode-tools/src/executor.rs +++ b/vtcode-tools/src/executor.rs @@ -138,7 +138,23 @@ impl CachedToolExecutor { // Execute tool (caller provides actual execution) // This is where your tool registry would call the actual tool - let result = self.execute_tool_internal(tool_name, &*owned_args).await?; + let result = match self.execute_tool_internal(tool_name, &*owned_args).await { + Ok(r) => r, + Err(e) => { + // Invoke error handlers before propagating + if let Err(hook_err) = self.middleware.on_error(&req, &e).await { + eprintln!("[vtcode-tools] Middleware on_error hook failed: {}", hook_err); + } + // Update failed stats + { + let mut stats = self.stats.write().await; + stats.failed_calls += 1; + } + // Record failure in pattern detector + self.record_pattern(tool_name, false, start.elapsed().as_millis() as u64).await; + return Err(e); + } + }; let duration_ms = start.elapsed().as_millis() as u64; diff --git a/vtcode-tools/src/patterns.rs b/vtcode-tools/src/patterns.rs index ae24ae338..caa917491 100644 --- a/vtcode-tools/src/patterns.rs +++ b/vtcode-tools/src/patterns.rs @@ -25,25 +25,40 @@ pub struct DetectedPattern { pub confidence: f64, } +/// Maximum events to retain before eviction (prevents unbounded memory growth). +const MAX_EVENTS_CAPACITY: usize = 1000; + /// Pattern detector using sequence analysis. pub struct PatternDetector { events: Vec, patterns: HashMap, sequence_length: usize, + max_events: usize, } impl PatternDetector { /// Create new detector with sliding window size. pub fn new(sequence_length: usize) -> Self { + Self::with_capacity(sequence_length, MAX_EVENTS_CAPACITY) + } + + /// Create new detector with custom event capacity limit. + pub fn with_capacity(sequence_length: usize, max_events: usize) -> Self { Self { - events: Vec::with_capacity(64), + events: Vec::with_capacity(64.min(max_events)), patterns: HashMap::with_capacity(16), sequence_length, + max_events: max_events.max(sequence_length * 2), } } - /// Add an event to the detector. + /// Add an event to the detector with automatic eviction. pub fn record_event(&mut self, event: ToolEvent) { + // Evict oldest events if at capacity (sliding window) + if self.events.len() >= self.max_events { + let drain_count = self.max_events / 4; // Remove 25% of oldest + self.events.drain(0..drain_count); + } self.events.push(event); self.analyze(); }