From 823d230ed69bf8a1480c06112eb1fe66efcbe40d Mon Sep 17 00:00:00 2001 From: mohit Date: Sun, 14 Jun 2026 16:03:29 +0530 Subject: [PATCH] fix(forge): resolve abort race condition in agentic loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The orchestrator's in-memory state overwrites the abort flag set by the API when save_state() is called after tool execution. This made the abort button silently fail — users couldn't cancel expensive runs. Added _sync_abort_flag() that reads the abort flag from disk and merges it into in-memory state before every save_state() call. This ensures the API's abort request is preserved through the orchestrator's state saves. Added 6 tests covering the fix: - _sync_abort_flag preserves/respects abort flag - Abort during tool execution is preserved - Abort after iteration stops loop correctly - Abort flag survives multiple save_state calls All 109 tests pass. --- app/services/forge/agent_runner.py | 25 ++++ tests/test_abort_race_condition.py | 204 +++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100644 tests/test_abort_race_condition.py diff --git a/app/services/forge/agent_runner.py b/app/services/forge/agent_runner.py index b6b8920..104bf21 100644 --- a/app/services/forge/agent_runner.py +++ b/app/services/forge/agent_runner.py @@ -207,6 +207,19 @@ def _check_abort(run: ForgeRun, repo_root: Path) -> None: raise _AbortRequested() +def _sync_abort_flag(state: AgenticRunState, run: ForgeRun, repo_root: Path) -> None: + """Merge abort flag from disk into in-memory state before save_state(). + + Race condition fix: The API can set abort_requested=True on disk while + the orchestrator is running. The in-memory state doesn't know about this + change, so save_state() would overwrite it with False. This function + reads the disk state and preserves the abort flag if it was set. + """ + disk_state = load_state(run, repo_root) + if disk_state and disk_state.abort_requested: + state.abort_requested = True + + def _run_single_turn( *, client: anthropic.Anthropic, @@ -347,6 +360,7 @@ def run_agentic_loop( state.iteration = iteration state.last_message = f"Iteration {iteration}/{config.max_iterations} — calling agent." + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) response = _run_single_turn( @@ -378,6 +392,7 @@ def run_agentic_loop( if response.stop_reason == "refusal": state.status = AgenticRunStatus.ERRORED state.error = "Agent refused to continue (stop_reason=refusal)." + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state @@ -393,6 +408,7 @@ def run_agentic_loop( # If the model ended its turn without using tools, it's done. if response.stop_reason == "end_turn": state.last_message = "Agent ended its turn without further tool calls." + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) break @@ -406,15 +422,18 @@ def run_agentic_loop( if not tool_results: # No tool_use blocks but stop_reason wasn't end_turn — unusual, stop. state.last_message = "Agent stopped without using any tools." + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) break messages.append({"role": "user", "content": tool_results}) + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) if give_up_payload is not None: state.status = AgenticRunStatus.REJECTED state.error = f"Agent gave up: {give_up_payload.get('reason', '(unknown)')}" + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state @@ -422,6 +441,7 @@ def run_agentic_loop( if state.last_verify_passed and state.last_benchmark_passed: if _try_promote(run, repo_root, state): state.status = AgenticRunStatus.SUCCEEDED + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) update_run_status(run, ForgeRunStatus.PROMOTED, repo_root) return state @@ -432,6 +452,7 @@ def run_agentic_loop( f"Cost cap reached (${state.cost_usd:.2f} >= ${state.cost_cap_usd:.2f}) " f"after iteration {iteration}." ) + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state @@ -442,22 +463,26 @@ def run_agentic_loop( f"Exhausted {config.max_iterations} iterations without producing a " f"verified, faster kernel." ) + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state except _AbortRequested: state.status = AgenticRunStatus.ABORTED state.last_message = "Aborted by user." + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state except anthropic.APIStatusError as e: state.status = AgenticRunStatus.ERRORED state.error = f"Anthropic API error ({e.status_code}): {e.message}" + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state except Exception as e: state.status = AgenticRunStatus.ERRORED state.error = f"{type(e).__name__}: {e}" + _sync_abort_flag(state, run, repo_root) save_state(state, run, repo_root) return state diff --git a/tests/test_abort_race_condition.py b/tests/test_abort_race_condition.py new file mode 100644 index 0000000..4f90a87 --- /dev/null +++ b/tests/test_abort_race_condition.py @@ -0,0 +1,204 @@ +"""Tests for the abort race condition fix in agent_runner.py.""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from app.services.forge.agent_runner import ( + AgenticRunConfig, + _sync_abort_flag, + run_agentic_loop, +) +from app.services.forge.agent_state import ( + AgenticRunState, + AgenticRunStatus, + init_state, + load_state, + request_abort, + save_state, +) +from app.services.forge.models import ( + ForgeRun, + ForgeRunStatus, + KernelLanguage, + KernelOp, + KernelTaskSpec, +) + + +pytestmark = pytest.mark.forge + + +def _make_task() -> KernelTaskSpec: + return KernelTaskSpec( + op=KernelOp.RMSNORM, + language=KernelLanguage.TRITON, + target_gpu="RTX 4070", + dtype="fp16", + shape={"batch": 16, "hidden_size": 4096}, + ) + + +def _make_run(tmp_path: Path) -> ForgeRun: + run_id = "20260507T000000_rmsnorm_rtx-4070" + artifact_dir = tmp_path / "forge_runs" / run_id + artifact_dir.mkdir(parents=True) + (artifact_dir / "skill_bundle.md").write_text("# Test skill bundle\n") + + run = ForgeRun( + run_id=run_id, + status=ForgeRunStatus.CANDIDATE_READY, + task=_make_task(), + skill_ids=["inference.write-triton-rmsnorm-kernel"], + artifact_dir=str(Path("forge_runs") / run_id), + ) + (artifact_dir / "run.json").write_text(run.model_dump_json(indent=2)) + return run + + +def _make_config(**overrides) -> AgenticRunConfig: + defaults = {"max_iterations": 3, "cost_cap_usd": 10.0} + defaults.update(overrides) + return AgenticRunConfig(**defaults) + + +def _text_block(text: str): + block = SimpleNamespace(type="text", text=text) + block.model_dump = lambda: {"type": "text", "text": text} + return block + + +def _tool_use_block(name: str, tool_input: dict, block_id: str = "toolu_1"): + block = SimpleNamespace(type="tool_use", name=name, input=tool_input, id=block_id) + block.model_dump = lambda: {"type": "tool_use", "name": name, "input": tool_input, "id": block_id} + return block + + +def _mock_response(stop_reason: str, content: list, usage=None): + if usage is None: + usage = SimpleNamespace( + input_tokens=0, + output_tokens=0, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ) + return SimpleNamespace(stop_reason=stop_reason, content=content, usage=usage) + + +# ---- _sync_abort_flag tests ---------------------------------------------- + + +def test_sync_abort_flag_preserves_abort_when_set_on_disk(tmp_path: Path): + run = _make_run(tmp_path) + state = init_state(run, tmp_path) + state.abort_requested = False + save_state(state, run, tmp_path) + + request_abort(run, tmp_path) + + _sync_abort_flag(state, run, tmp_path) + + assert state.abort_requested is True + + +def test_sync_abort_flag_does_nothing_when_not_set_on_disk(tmp_path: Path): + run = _make_run(tmp_path) + state = init_state(run, tmp_path) + state.abort_requested = False + save_state(state, run, tmp_path) + + _sync_abort_flag(state, run, tmp_path) + + assert state.abort_requested is False + + +def test_sync_abort_flag_preserves_existing_abort_in_memory(tmp_path: Path): + run = _make_run(tmp_path) + state = init_state(run, tmp_path) + state.abort_requested = True + save_state(state, run, tmp_path) + + _sync_abort_flag(state, run, tmp_path) + + assert state.abort_requested is True + + +# ---- Race condition scenario tests --------------------------------------- + + +def test_abort_during_tool_execution_is_preserved(tmp_path: Path, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key") + run = _make_run(tmp_path) + config = _make_config(max_iterations=3) + + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + request_abort(run, tmp_path) + return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})]) + return _mock_response("end_turn", [_text_block("Done")]) + + with patch("anthropic.Anthropic") as MockAnthropic: + mock_client = MagicMock() + mock_client.messages.create.side_effect = side_effect + MockAnthropic.return_value = mock_client + + state = run_agentic_loop(run, tmp_path, config) + + assert state.status == AgenticRunStatus.ABORTED + + +def test_abort_after_first_iteration_stops_loop(tmp_path: Path, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key") + run = _make_run(tmp_path) + config = _make_config(max_iterations=5) + + call_count = [0] + + def side_effect(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})]) + elif call_count[0] == 2: + request_abort(run, tmp_path) + return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})]) + return _mock_response("end_turn", [_text_block("Done")]) + + with patch("anthropic.Anthropic") as MockAnthropic: + mock_client = MagicMock() + mock_client.messages.create.side_effect = side_effect + MockAnthropic.return_value = mock_client + + state = run_agentic_loop(run, tmp_path, config) + + assert state.status == AgenticRunStatus.ABORTED + assert state.iteration == 2 + + +def test_abort_flag_survives_multiple_save_state_calls(tmp_path: Path, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key") + run = _make_run(tmp_path) + config = _make_config(max_iterations=3) + + def side_effect(*args, **kwargs): + request_abort(run, tmp_path) + return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})]) + + with patch("anthropic.Anthropic") as MockAnthropic: + mock_client = MagicMock() + mock_client.messages.create.side_effect = side_effect + MockAnthropic.return_value = mock_client + + state = run_agentic_loop(run, tmp_path, config) + + assert state.status == AgenticRunStatus.ABORTED + + disk_state = load_state(run, tmp_path) + assert disk_state.abort_requested is True + assert disk_state.status == AgenticRunStatus.ABORTED