Skip to content

Commit 3318f22

Browse files
authored
Checkpointing + Error Retry for Rollout Processor (#80)
* Finished Error Handling
* Address comments
* Changing the rollout processors
* cleaning up mcp gym
* remove import
* Update
* failing test
* fixing flaky test
* update comments
1 parent d81b1f4 commit 3318f22

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+831
-489
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ jobs:
9292
--ignore=tests/pytest/test_frozen_lake.py \
9393
--ignore=tests/pytest/test_lunar_lander.py \
9494
--ignore=tests/pytest/test_tau_bench_airline.py \
95+
--ignore=tests/pytest/test_apps_coding.py \
9596
--ignore=tests/test_tau_bench_airline_smoke.py \
9697
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10
9798

eval_protocol/benchmarks/suites/aime25.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from eval_protocol.benchmarks.registry import export_benchmark
44
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
55
from eval_protocol.pytest.default_single_turn_rollout_process import (
6-
default_single_turn_rollout_processor,
6+
SingleTurnRolloutProcessor,
77
)
88
from eval_protocol.pytest.evaluation_test import evaluation_test
99

@@ -72,7 +72,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
7272
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
7373
}
7474
],
75-
rollout_processor=default_single_turn_rollout_processor,
75+
rollout_processor=SingleTurnRolloutProcessor(),
7676
aggregation_method="mean",
7777
passed_threshold=None,
7878
num_runs=8,

eval_protocol/benchmarks/suites/gpqa.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import csv
23
import io
34
import re
@@ -8,9 +9,11 @@
89
from eval_protocol.benchmarks.registry import export_benchmark
910
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
1011
from eval_protocol.pytest.default_single_turn_rollout_process import (
11-
default_single_turn_rollout_processor,
12+
SingleTurnRolloutProcessor,
1213
)
1314
from eval_protocol.pytest.evaluation_test import evaluation_test
15+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
16+
from eval_protocol.pytest.types import RolloutProcessorConfig
1417

1518
SYSTEM_PROMPT = (
1619
"You are a helpful assistant. Read the question and options carefully. "
@@ -60,19 +63,31 @@ def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
6063
return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))]
6164

6265

63-
async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
64-
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor."""
65-
processed: List[EvaluationRow] = []
66-
for r in rows:
67-
gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
68-
if gt_tokens:
69-
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
70-
r.ground_truth = gt_val
71-
r.messages = [
72-
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
66+
class GPQAStripGTRolloutProcessor(RolloutProcessor):
67+
"""Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to SingleTurnRolloutProcessor."""
68+
69+
def __init__(self):
70+
super().__init__()
71+
self.single_turn_processor = SingleTurnRolloutProcessor()
72+
73+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
74+
"""Preprocess rows and delegate to SingleTurnRolloutProcessor."""
75+
processed: List[EvaluationRow] = []
76+
77+
for r in rows:
78+
gt_tokens = [
79+
m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")
7380
]
74-
processed.append(r)
75-
return await default_single_turn_rollout_processor(processed, config)
81+
if gt_tokens:
82+
gt_val = gt_tokens[-1].split(":", 1)[1].strip()
83+
r.ground_truth = gt_val
84+
r.messages = [
85+
m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))
86+
]
87+
processed.append(r)
88+
89+
# Delegate to SingleTurnRolloutProcessor
90+
return self.single_turn_processor(processed, config)
7691

7792

7893
@export_benchmark("gpqa")
@@ -81,7 +96,7 @@ async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) ->
8196
completion_params=[
8297
{"extra_body": {"reasoning_effort": "low"}, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}
8398
],
84-
rollout_processor=gpqa_strip_gt_rollout_processor,
99+
rollout_processor=GPQAStripGTRolloutProcessor(),
85100
aggregation_method="mean",
86101
passed_threshold=None,
87102
num_runs=8,

eval_protocol/benchmarks/suites/livebench_data_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark
66
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult
77
from eval_protocol.pytest.default_single_turn_rollout_process import (
8-
default_single_turn_rollout_processor,
8+
SingleTurnRolloutProcessor,
99
)
1010
from eval_protocol.pytest.evaluation_test import evaluation_test
1111

@@ -375,7 +375,7 @@ def _extract_gt(row: EvaluationRow) -> Dict[str, Any]:
375375
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
376376
input_messages=[[m for m in r.messages] for r in _CTA_ROWS],
377377
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
378-
rollout_processor=default_single_turn_rollout_processor,
378+
rollout_processor=SingleTurnRolloutProcessor(),
379379
aggregation_method="mean",
380380
passed_threshold=None,
381381
num_runs=4,
@@ -418,7 +418,7 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow:
418418
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
419419
input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS],
420420
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
421-
rollout_processor=default_single_turn_rollout_processor,
421+
rollout_processor=SingleTurnRolloutProcessor(),
422422
aggregation_method="mean",
423423
passed_threshold=None,
424424
num_runs=4,
@@ -462,7 +462,7 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow:
462462
completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}],
463463
input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS],
464464
rollout_processor_kwargs=[{"extra_body": {"reasoning_effort": "low"}}],
465-
rollout_processor=default_single_turn_rollout_processor,
465+
rollout_processor=SingleTurnRolloutProcessor(),
466466
aggregation_method="mean",
467467
passed_threshold=None,
468468
num_runs=4,

eval_protocol/benchmarks/suites/tau_bench_retail.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from eval_protocol.benchmarks.registry import export_benchmark
1414
from eval_protocol.models import EvaluateResult, EvaluationRow, InputMetadata, Message
1515
from eval_protocol.pytest import evaluation_test
16-
from eval_protocol.pytest.default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
16+
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
1717
from vendor.tau2.data_model.message import (
1818
AssistantMessage,
1919
SystemMessage,
@@ -73,7 +73,7 @@ def tau_bench_retail_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evalu
7373
"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b",
7474
}
7575
],
76-
rollout_processor=default_mcp_gym_rollout_processor,
76+
rollout_processor=MCPGymRolloutProcessor(),
7777
rollout_processor_kwargs={"domain": "retail"},
7878
num_runs=8,
7979
mode="pointwise",

eval_protocol/mcp/execution/manager.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ class ExecutionManager:
3535
Manage rollout for MCP environments.
3636
"""
3737

38-
async def execute_rollouts(
38+
def execute_rollouts(
3939
self,
4040
envs: "GeneralMCPVectorEnv",
4141
policy: Union["LLMBasePolicy", Callable],
4242
steps: int = 512,
4343
openai_format_log_file: Optional[str] = None,
4444
max_concurrent_rollouts: int = 8,
4545
evaluation_rows: Optional[List[EvaluationRow]] = None,
46-
) -> AsyncIterator[EvaluationRow]:
46+
) -> List[asyncio.Task[EvaluationRow]]:
4747
"""
4848
Execute general rollouts using tool calling interface with automatic record/playback.
4949
@@ -66,7 +66,7 @@ async def execute_rollouts(
6666
- Set and file exists: Playback mode (uses recorded data)
6767
6868
Returns:
69-
AsyncIterator of EvaluationRow objects with unified evaluation data format
69+
List of asyncio.Task objects for external handling
7070
"""
7171
start_time = time.time()
7272

@@ -138,7 +138,7 @@ async def _execute_with_semaphore(idx):
138138
if trajectory.terminated:
139139
if trajectory.termination_reason == TerminationReason.ERROR:
140140
evaluation_row.rollout_status.status = "error"
141-
evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get(
141+
evaluation_row.rollout_status.termination_reason = trajectory.control_plane_summary.get(
142142
"error_message", None
143143
)
144144
else:
@@ -151,18 +151,7 @@ async def _execute_with_semaphore(idx):
151151

152152
# Create all tasks
153153
tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)]
154-
155-
# Yield results as they complete (note that they're not necessarily in original order)
156-
try:
157-
for task in asyncio.as_completed(tasks):
158-
try:
159-
yield await task
160-
except Exception:
161-
logger.exception("Error processing rollout")
162-
finally:
163-
for t in tasks:
164-
t.cancel()
165-
await asyncio.gather(*tasks, return_exceptions=True)
154+
return tasks
166155

167156
async def _execute_rollout(
168157
self,

eval_protocol/mcp_env.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def make(
236236
return mcp_envs
237237

238238

239-
async def rollout(
239+
def rollout(
240240
envs: GeneralMCPVectorEnv,
241241
policy: Union[FireworksPolicy, LLMBasePolicy, Callable],
242242
*,
@@ -246,7 +246,7 @@ async def rollout(
246246
steps: int = 512,
247247
openai_format_log_file: Optional[str] = None,
248248
max_concurrent_rollouts: int = 8,
249-
) -> AsyncIterator[EvaluationRow]:
249+
) -> List[asyncio.Task[EvaluationRow]]:
250250
"""
251251
Execute general rollouts using tool calling interface with automatic record/playback.
252252
@@ -274,14 +274,14 @@ async def rollout(
274274
- Set and file exists: Playback mode (uses recorded data)
275275
276276
Returns:
277-
List of EvaluationRow objects
277+
List of asyncio.Task objects for external handling
278278
279279
Example:
280280
# Live mode
281-
evaluation_rows = await ep.rollout(envs, policy)
281+
tasks = ep.rollout(envs, policy)
282282
283283
# Create environments automatically
284-
trajectories = await ep.rollout(
284+
tasks = ep.rollout(
285285
"http://localhost:8000/mcp/",
286286
policy,
287287
evaluation_rows=my_evaluation_rows,
@@ -290,26 +290,26 @@ async def rollout(
290290
291291
# Recording mode
292292
os.environ["EP_PLAYBACK_FILE"] = "record.jsonl"
293-
evaluation_rows = await ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
293+
tasks = ep.rollout(envs, policy, openai_format_log_file="sft_data.jsonl")
294294
295295
# Playback mode (after recording file exists)
296-
evaluation_rows = await ep.rollout(envs, policy)
296+
tasks = ep.rollout(envs, policy)
297297
"""
298298
# Automatically create environments if a base URL is provided
299299
if isinstance(envs, str):
300300
if evaluation_rows is None and dataset is None:
301301
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")
302302

303303
auto_model_id = model_id or getattr(policy, "model_id", "unknown")
304-
envs = await make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
304+
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)
305305

306306
# Use the new ExecutionManager for execution
307307
execution_manager = ExecutionManager()
308308

309-
async for evaluation_row in execution_manager.execute_rollouts(
309+
tasks = execution_manager.execute_rollouts(
310310
envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows
311-
):
312-
yield evaluation_row
311+
)
312+
return tasks
313313

314314

315315
async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
@@ -336,7 +336,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
336336
policy = FireworksPolicy("test-model")
337337

338338
# Run short rollout
339-
evaluation_rows = await rollout(envs, policy=policy, steps=10)
339+
evaluation_rows = rollout(envs, policy=policy, steps=10)
340340

341341
if evaluation_rows and len(evaluation_rows[0].messages) > 1:
342342
results["successful"] += 1

eval_protocol/pytest/__init__.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from .default_agent_rollout_processor import default_agent_rollout_processor
1+
from .default_agent_rollout_processor import AgentRolloutProcessor
22
from .default_dataset_adapter import default_dataset_adapter
3-
from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor
4-
from .default_no_op_rollout_process import default_no_op_rollout_processor
5-
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
3+
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
4+
from .default_no_op_rollout_processor import NoOpRolloutProcessor
5+
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
66
from .evaluation_test import evaluation_test
7-
from .types import RolloutProcessor, RolloutProcessorConfig
7+
from .rollout_processor import RolloutProcessor
8+
from .types import RolloutProcessorConfig
89

910
__all__ = [
10-
"default_agent_rollout_processor",
11-
"default_mcp_gym_rollout_processor",
12-
"default_no_op_rollout_processor",
13-
"default_single_turn_rollout_processor",
14-
"default_dataset_adapter",
11+
"AgentRolloutProcessor",
12+
"MCPGymRolloutProcessor",
1513
"RolloutProcessor",
14+
"SingleTurnRolloutProcessor",
15+
"NoOpRolloutProcessor",
16+
"default_dataset_adapter",
1617
"RolloutProcessorConfig",
1718
"evaluation_test",
1819
]

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 30 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
1414
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
1515
from eval_protocol.models import EvaluationRow, Message
16+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
1617
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
1718

1819
logger = logging.getLogger(__name__)
@@ -122,46 +123,36 @@ def _format_tool_message_content(
122123
return [ChatCompletionContentPartTextParam(text=c.text, type="text") for c in content]
123124

124125

125-
async def default_agent_rollout_processor(
126-
rows: List[EvaluationRow], config: RolloutProcessorConfig
127-
) -> AsyncIterator[EvaluationRow]:
128-
"""Process agent rollouts with bounded concurrency and yield as they complete."""
126+
class AgentRolloutProcessor(RolloutProcessor):
127+
"""Agent rollout processor for tool-calling agents."""
129128

130-
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
131-
semaphore = asyncio.Semaphore(max_concurrent)
129+
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
130+
"""Create agent rollout tasks and return them for external handling."""
132131

133-
async def process_row(row: EvaluationRow) -> EvaluationRow:
134-
"""Process a single row with agent rollout."""
135-
agent = Agent(
136-
model=config.completion_params["model"], row=row, config_path=config.mcp_config_path, logger=config.logger
137-
)
138-
try:
139-
await agent.setup()
140-
await agent.call_agent()
141-
return agent.evaluation_row
142-
finally:
143-
if agent.mcp_client:
144-
await agent.mcp_client.cleanup()
145-
146-
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
147-
async with semaphore:
148-
try:
149-
return await process_row(r)
150-
except Exception as e:
151-
logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}")
152-
return r
153-
154-
# Create all tasks
155-
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
132+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
133+
semaphore = asyncio.Semaphore(max_concurrent)
156134

157-
# Yield results as they complete (note that they're not necessarily in original order)
158-
try:
159-
for task in asyncio.as_completed(tasks):
135+
async def process_row(row: EvaluationRow) -> EvaluationRow:
136+
"""Process a single row with agent rollout."""
137+
agent = Agent(
138+
model=config.completion_params["model"],
139+
row=row,
140+
config_path=config.mcp_config_path,
141+
logger=config.logger,
142+
)
160143
try:
161-
yield await task
162-
except Exception:
163-
logger.exception("Error processing row")
164-
finally:
165-
for t in tasks:
166-
t.cancel()
167-
await asyncio.gather(*tasks, return_exceptions=True)
144+
await agent.setup()
145+
await agent.call_agent()
146+
return agent.evaluation_row
147+
finally:
148+
if agent.mcp_client:
149+
await agent.mcp_client.cleanup()
150+
151+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
152+
async with semaphore:
153+
result = await process_row(r)
154+
return result
155+
156+
# Create and return tasks for external handling
157+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
158+
return tasks

0 commit comments

Comments
 (0)