Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions app/services/forge/agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,19 @@ def _check_abort(run: ForgeRun, repo_root: Path) -> None:
raise _AbortRequested()


def _sync_abort_flag(state: AgenticRunState, run: ForgeRun, repo_root: Path) -> None:
"""Merge abort flag from disk into in-memory state before save_state().

Race condition fix: The API can set abort_requested=True on disk while
the orchestrator is running. The in-memory state doesn't know about this
change, so save_state() would overwrite it with False. This function
reads the disk state and preserves the abort flag if it was set.
"""
disk_state = load_state(run, repo_root)
if disk_state and disk_state.abort_requested:
state.abort_requested = True


def _run_single_turn(
*,
client: anthropic.Anthropic,
Expand Down Expand Up @@ -347,6 +360,7 @@ def run_agentic_loop(

state.iteration = iteration
state.last_message = f"Iteration {iteration}/{config.max_iterations} — calling agent."
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)

response = _run_single_turn(
Expand Down Expand Up @@ -378,6 +392,7 @@ def run_agentic_loop(
if response.stop_reason == "refusal":
state.status = AgenticRunStatus.ERRORED
state.error = "Agent refused to continue (stop_reason=refusal)."
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state

Expand All @@ -393,6 +408,7 @@ def run_agentic_loop(
# If the model ended its turn without using tools, it's done.
if response.stop_reason == "end_turn":
state.last_message = "Agent ended its turn without further tool calls."
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
break

Expand All @@ -406,22 +422,26 @@ def run_agentic_loop(
if not tool_results:
# No tool_use blocks but stop_reason wasn't end_turn — unusual, stop.
state.last_message = "Agent stopped without using any tools."
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
break

messages.append({"role": "user", "content": tool_results})
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)

if give_up_payload is not None:
state.status = AgenticRunStatus.REJECTED
state.error = f"Agent gave up: {give_up_payload.get('reason', '(unknown)')}"
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state

# Try promotion if both gates passed in this iteration
if state.last_verify_passed and state.last_benchmark_passed:
if _try_promote(run, repo_root, state):
state.status = AgenticRunStatus.SUCCEEDED
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
update_run_status(run, ForgeRunStatus.PROMOTED, repo_root)
return state
Expand All @@ -432,6 +452,7 @@ def run_agentic_loop(
f"Cost cap reached (${state.cost_usd:.2f} >= ${state.cost_cap_usd:.2f}) "
f"after iteration {iteration}."
)
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state

Expand All @@ -442,22 +463,26 @@ def run_agentic_loop(
f"Exhausted {config.max_iterations} iterations without producing a "
f"verified, faster kernel."
)
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state

except _AbortRequested:
state.status = AgenticRunStatus.ABORTED
state.last_message = "Aborted by user."
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state
except anthropic.APIStatusError as e:
state.status = AgenticRunStatus.ERRORED
state.error = f"Anthropic API error ({e.status_code}): {e.message}"
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state
except Exception as e:
state.status = AgenticRunStatus.ERRORED
state.error = f"{type(e).__name__}: {e}"
_sync_abort_flag(state, run, repo_root)
save_state(state, run, repo_root)
return state

Expand Down
204 changes: 204 additions & 0 deletions tests/test_abort_race_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Tests for the abort race condition fix in agent_runner.py."""

from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from app.services.forge.agent_runner import (
AgenticRunConfig,
_sync_abort_flag,
run_agentic_loop,
)
from app.services.forge.agent_state import (
AgenticRunState,
AgenticRunStatus,
init_state,
load_state,
request_abort,
save_state,
)
from app.services.forge.models import (
ForgeRun,
ForgeRunStatus,
KernelLanguage,
KernelOp,
KernelTaskSpec,
)


pytestmark = pytest.mark.forge


def _make_task() -> KernelTaskSpec:
return KernelTaskSpec(
op=KernelOp.RMSNORM,
language=KernelLanguage.TRITON,
target_gpu="RTX 4070",
dtype="fp16",
shape={"batch": 16, "hidden_size": 4096},
)


def _make_run(tmp_path: Path) -> ForgeRun:
run_id = "20260507T000000_rmsnorm_rtx-4070"
artifact_dir = tmp_path / "forge_runs" / run_id
artifact_dir.mkdir(parents=True)
(artifact_dir / "skill_bundle.md").write_text("# Test skill bundle\n")

run = ForgeRun(
run_id=run_id,
status=ForgeRunStatus.CANDIDATE_READY,
task=_make_task(),
skill_ids=["inference.write-triton-rmsnorm-kernel"],
artifact_dir=str(Path("forge_runs") / run_id),
)
(artifact_dir / "run.json").write_text(run.model_dump_json(indent=2))
return run


def _make_config(**overrides) -> AgenticRunConfig:
defaults = {"max_iterations": 3, "cost_cap_usd": 10.0}
defaults.update(overrides)
return AgenticRunConfig(**defaults)


def _text_block(text: str):
block = SimpleNamespace(type="text", text=text)
block.model_dump = lambda: {"type": "text", "text": text}
return block


def _tool_use_block(name: str, tool_input: dict, block_id: str = "toolu_1"):
block = SimpleNamespace(type="tool_use", name=name, input=tool_input, id=block_id)
block.model_dump = lambda: {"type": "tool_use", "name": name, "input": tool_input, "id": block_id}
return block


def _mock_response(stop_reason: str, content: list, usage=None):
if usage is None:
usage = SimpleNamespace(
input_tokens=0,
output_tokens=0,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
)
return SimpleNamespace(stop_reason=stop_reason, content=content, usage=usage)


# ---- _sync_abort_flag tests ----------------------------------------------


def test_sync_abort_flag_preserves_abort_when_set_on_disk(tmp_path: Path):
run = _make_run(tmp_path)
state = init_state(run, tmp_path)
state.abort_requested = False
save_state(state, run, tmp_path)

request_abort(run, tmp_path)

_sync_abort_flag(state, run, tmp_path)

assert state.abort_requested is True


def test_sync_abort_flag_does_nothing_when_not_set_on_disk(tmp_path: Path):
run = _make_run(tmp_path)
state = init_state(run, tmp_path)
state.abort_requested = False
save_state(state, run, tmp_path)

_sync_abort_flag(state, run, tmp_path)

assert state.abort_requested is False


def test_sync_abort_flag_preserves_existing_abort_in_memory(tmp_path: Path):
run = _make_run(tmp_path)
state = init_state(run, tmp_path)
state.abort_requested = True
save_state(state, run, tmp_path)

_sync_abort_flag(state, run, tmp_path)

assert state.abort_requested is True


# ---- Race condition scenario tests ---------------------------------------


def test_abort_during_tool_execution_is_preserved(tmp_path: Path, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key")
run = _make_run(tmp_path)
config = _make_config(max_iterations=3)

call_count = [0]

def side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
request_abort(run, tmp_path)
return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})])
return _mock_response("end_turn", [_text_block("Done")])

with patch("anthropic.Anthropic") as MockAnthropic:
mock_client = MagicMock()
mock_client.messages.create.side_effect = side_effect
MockAnthropic.return_value = mock_client

state = run_agentic_loop(run, tmp_path, config)

assert state.status == AgenticRunStatus.ABORTED


def test_abort_after_first_iteration_stops_loop(tmp_path: Path, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key")
run = _make_run(tmp_path)
config = _make_config(max_iterations=5)

call_count = [0]

def side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})])
elif call_count[0] == 2:
request_abort(run, tmp_path)
return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})])
return _mock_response("end_turn", [_text_block("Done")])

with patch("anthropic.Anthropic") as MockAnthropic:
mock_client = MagicMock()
mock_client.messages.create.side_effect = side_effect
MockAnthropic.return_value = mock_client

state = run_agentic_loop(run, tmp_path, config)

assert state.status == AgenticRunStatus.ABORTED
assert state.iteration == 2


def test_abort_flag_survives_multiple_save_state_calls(tmp_path: Path, monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-key")
run = _make_run(tmp_path)
config = _make_config(max_iterations=3)

def side_effect(*args, **kwargs):
request_abort(run, tmp_path)
return _mock_response("tool_use", [_tool_use_block("list_candidate_files", {})])

with patch("anthropic.Anthropic") as MockAnthropic:
mock_client = MagicMock()
mock_client.messages.create.side_effect = side_effect
MockAnthropic.return_value = mock_client

state = run_agentic_loop(run, tmp_path, config)

assert state.status == AgenticRunStatus.ABORTED

disk_state = load_state(run, tmp_path)
assert disk_state.abort_requested is True
assert disk_state.status == AgenticRunStatus.ABORTED