Skip to content

Commit 52c0f20

Browse files
author
Dylan Huang
authored
Refactor Agent Rollout Error Handling: Propagate Status, Adopt AIP-193, and Enhance UI Feedback (#114)
* test reproduces error properly * propagate error from rollout status * migrate eval_metadata.status to AIP-193 * add RolloutStatusSection * log the rollout status at the end * show rollout status * vite build * Update EvaluationRow model to set default for pid and simplify TestFunction type definition * fix tests * fix test * fix types for input_rows * fix * include details of error in message * Enhance EvaluationWatcher to update rollout_status and refactor status checking logic * fix tests + add more * fix test
1 parent b813ce7 commit 52c0f20

21 files changed

+555
-126
lines changed

eval_protocol/mcp/mcp_multi_client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,7 @@ async def connect_to_servers(self):
6969
return
7070

7171
for server_name, server_config in self.config.mcpServers.items():
72-
try:
73-
await self._connect_to_server(server_name, server_config)
74-
except Exception as e:
75-
print(f"Failed to connect to server '{server_name}': {e}")
72+
await self._connect_to_server(server_name, server_config)
7673

7774
async def _connect_to_server(
7875
self, server_name: str, server_config: Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]

eval_protocol/models.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,29 @@ class Code(int, Enum):
107107
DATA_LOSS = 15
108108
UNAUTHENTICATED = 16
109109

110-
# Custom codes for rollout states (using higher numbers to avoid conflicts)
110+
# Custom codes for EP (using higher numbers to avoid conflicts)
111111
FINISHED = 100
112+
RUNNING = 101
112113

113114
@classmethod
114115
def rollout_running(cls) -> "Status":
115116
"""Create a status indicating the rollout is running."""
116-
return cls(code=cls.Code.OK, message="Rollout is running", details=[])
117+
return cls(code=cls.Code.RUNNING, message="Rollout is running", details=[])
118+
119+
@classmethod
120+
def eval_running(cls) -> "Status":
121+
"""Create a status indicating the evaluation is running."""
122+
return cls(code=cls.Code.RUNNING, message="Evaluation is running", details=[])
123+
124+
@classmethod
125+
def eval_finished(cls) -> "Status":
126+
"""Create a status indicating the evaluation finished."""
127+
return cls(code=cls.Code.FINISHED, message="Evaluation finished", details=[])
128+
129+
@classmethod
130+
def aborted(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
131+
"""Create a status indicating the evaluation was aborted."""
132+
return cls(code=cls.Code.ABORTED, message=message, details=details or [])
117133

118134
@classmethod
119135
def rollout_finished(
@@ -127,7 +143,12 @@ def rollout_finished(
127143
details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format())
128144
if extra_info:
129145
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
130-
return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details)
146+
return cls.finished("Rollout finished", details)
147+
148+
@classmethod
149+
def finished(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
150+
"""Create a status indicating the rollout finished."""
151+
return cls(code=cls.Code.FINISHED, message=message, details=details or [])
131152

132153
@classmethod
133154
def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status":
@@ -140,11 +161,11 @@ def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]]
140161
@classmethod
141162
def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
142163
"""Create a status indicating the rollout failed with an error."""
143-
return cls(code=cls.Code.INTERNAL, message=error_message, details=details)
164+
return cls(code=cls.Code.INTERNAL, message=error_message, details=details or [])
144165

145166
def is_running(self) -> bool:
146167
"""Check if the status indicates the rollout is running."""
147-
return self.code == self.Code.OK and self.message == "Rollout is running"
168+
return self.code == self.Code.RUNNING
148169

149170
def is_finished(self) -> bool:
150171
"""Check if the status indicates the rollout finished successfully."""
@@ -436,9 +457,7 @@ class EvalMetadata(BaseModel):
436457
default_factory=get_pep440_version,
437458
description="Version of the evaluation. Should be populated with a PEP 440 version string.",
438459
)
439-
status: Optional[Literal["running", "finished", "error", "stopped"]] = Field(
440-
None, description="Status of the evaluation"
441-
)
460+
status: Optional[Status] = Field(None, description="Status of the evaluation")
442461
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
443462
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
444463
passed_threshold: Optional[EvaluationThreshold] = Field(
@@ -527,7 +546,7 @@ class EvaluationRow(BaseModel):
527546
)
528547

529548
pid: Optional[int] = Field(
530-
None,
549+
default=None,
531550
description="The PID of the process that created the row. This is used by the evaluation watcher to detect stopped evaluations.",
532551
)
533552

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from eval_protocol.human_id import generate_id, num_combinations
2222
from eval_protocol.models import (
2323
CompletionParams,
24+
ErrorInfo,
2425
EvalMetadata,
2526
EvaluationRow,
2627
EvaluationThreshold,
2728
InputMetadata,
2829
Message,
30+
Status,
2931
)
3032
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
3133
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
@@ -36,7 +38,6 @@
3638
EvaluationInputParam,
3739
EvaluationTestMode,
3840
InputMessagesParam,
39-
InputRowsParam,
4041
ModelParam,
4142
RolloutProcessorConfig,
4243
RolloutProcessorInputParam,
@@ -58,6 +59,7 @@
5859
)
5960
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
6061
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
62+
from eval_protocol.types.types import TerminationReason
6163

6264
from ..common_utils import load_jsonl
6365

@@ -240,7 +242,7 @@ def evaluation_test( # noqa: C901
240242
completion_params: List[CompletionParams],
241243
input_messages: Optional[List[InputMessagesParam]] = None,
242244
input_dataset: Optional[List[DatasetPathParam]] = None,
243-
input_rows: Optional[List[InputRowsParam]] = None,
245+
input_rows: Optional[List[EvaluationRow]] = None,
244246
dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
245247
rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
246248
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
@@ -425,7 +427,7 @@ async def execute_with_params(
425427
if mode == "groupwise":
426428
combinations = generate_parameter_combinations(
427429
input_dataset,
428-
None,
430+
completion_params,
429431
input_messages,
430432
input_rows,
431433
evaluation_test_kwargs,
@@ -489,9 +491,7 @@ async def wrapper_body(**kwargs):
489491

490492
experiment_id = generate_id()
491493

492-
def _log_eval_error(
493-
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
494-
) -> None:
494+
def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None, passed: bool) -> None:
495495
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)
496496

497497
try:
@@ -562,7 +562,7 @@ def _log_eval_error(
562562
eval_metadata = EvalMetadata(
563563
name=test_func.__name__,
564564
description=test_func.__doc__,
565-
status="running",
565+
status=Status.eval_running(),
566566
num_runs=num_runs,
567567
aggregation_method=aggregation_method,
568568
passed_threshold=threshold,
@@ -732,7 +732,12 @@ async def _collect_result(config, lst):
732732

733733
for r in results:
734734
if r.eval_metadata is not None:
735-
r.eval_metadata.status = "finished"
735+
if r.rollout_status.is_error():
736+
r.eval_metadata.status = Status.error(
737+
r.rollout_status.message, r.rollout_status.details
738+
)
739+
else:
740+
r.eval_metadata.status = Status.eval_finished()
736741
active_logger.log(r)
737742

738743
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
@@ -770,14 +775,16 @@ async def _collect_result(config, lst):
770775

771776
except AssertionError:
772777
_log_eval_error(
773-
"finished",
778+
Status.eval_finished(),
774779
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
775780
passed=False,
776781
)
777782
raise
778-
except Exception:
783+
except Exception as e:
779784
_log_eval_error(
780-
"error", processed_rows_in_run if "processed_rows_in_run" in locals() else None, passed=False
785+
Status.error(str(e)),
786+
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
787+
passed=False,
781788
)
782789
raise
783790

eval_protocol/pytest/types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1616
DatasetPathParam = str
1717
InputMessagesParam = List[Message]
18-
InputRowsParam = List[EvaluationRow]
1918
EvaluationInputParam = Dict[str, Any]
2019
RolloutProcessorInputParam = Dict[str, Any]
2120

@@ -31,7 +30,7 @@
3130
"""
3231
Test function types
3332
"""
34-
TestFunction = Callable[..., Dataset]
33+
TestFunction = Callable
3534

3635
"""
3736
Rollout processor types

eval_protocol/pytest/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
DatasetPathParam,
1414
EvaluationInputParam,
1515
InputMessagesParam,
16-
InputRowsParam,
1716
RolloutProcessorConfig,
1817
)
1918
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
@@ -115,7 +114,7 @@ async def wrapper(**kwargs):
115114
def log_eval_status_and_rows(
116115
eval_metadata: Optional[EvalMetadata],
117116
rows: Optional[List[EvaluationRow]] | None,
118-
status: Literal["finished", "error"],
117+
status: Status,
119118
passed: bool,
120119
logger: DatasetLogger,
121120
) -> None:
@@ -185,7 +184,7 @@ def generate_parameter_combinations(
185184
input_dataset: Optional[List[DatasetPathParam]],
186185
completion_params: List[CompletionParams],
187186
input_messages: Optional[List[InputMessagesParam]],
188-
input_rows: Optional[List[InputRowsParam]],
187+
input_rows: Optional[List[EvaluationRow]],
189188
evaluation_test_kwargs: Optional[List[EvaluationInputParam]],
190189
max_dataset_rows: Optional[int],
191190
combine_datasets: bool,
@@ -341,12 +340,20 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
341340
else:
342341
# Non-retryable exception - fail immediately
343342
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
344-
row.rollout_status = Status.rollout_error(str(e))
343+
row.rollout_status = Status.rollout_error(repr(e))
345344
return row
346345

346+
async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow:
347+
"""Execute a single row task with backoff retry and logging."""
348+
result = await execute_row_with_backoff(task, row)
349+
# Log the row after execution completes (success or failure)
350+
config.logger.log(result)
351+
return result
352+
347353
# Process all tasks concurrently with backoff retry
348354
retry_tasks = [
349-
asyncio.create_task(execute_row_with_backoff(task, fresh_dataset[i])) for i, task in enumerate(base_tasks)
355+
asyncio.create_task(execute_row_with_backoff_and_log(task, fresh_dataset[i]))
356+
for i, task in enumerate(base_tasks)
350357
]
351358

352359
# Yield results as they complete

eval_protocol/utils/logs_server.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from eval_protocol.dataset_logger import default_logger
1616
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
1717
from eval_protocol.event_bus import event_bus
18+
from eval_protocol.models import Status
1819
from eval_protocol.utils.vite_server import ViteServer
1920

2021
if TYPE_CHECKING:
@@ -178,8 +179,17 @@ def _check_running_evaluations(self):
178179
for row in logs:
179180
if self._should_update_status(row):
180181
logger.info(f"Updating status to 'stopped' for row {row.input_metadata.row_id} (PID {row.pid})")
181-
if row.eval_metadata is not None:
182-
row.eval_metadata.status = "stopped"
182+
183+
# Update eval_metadata.status if it's running
184+
if row.eval_metadata and row.eval_metadata.status and row.eval_metadata.status.is_running():
185+
row.eval_metadata.status = Status.aborted(
186+
f"Evaluation aborted since process {row.pid} stopped"
187+
)
188+
189+
# Update rollout_status if it's running
190+
if row.rollout_status and row.rollout_status.is_running():
191+
row.rollout_status = Status.aborted(f"Rollout aborted since process {row.pid} stopped")
192+
183193
updated_rows.append(row)
184194

185195
# Log all updated rows
@@ -193,11 +203,18 @@ def _check_running_evaluations(self):
193203

194204
def _should_update_status(self, row: "EvaluationRow") -> bool:
195205
"""Check if a row's status should be updated to 'stopped'."""
196-
# Check if the row has running status and a PID
197-
if row.eval_metadata and row.eval_metadata.status == "running" and row.pid is not None:
206+
# Check if any status field should be updated
207+
return self._should_update_status_field(
208+
row.eval_metadata.status if row.eval_metadata else None, row.pid
209+
) or self._should_update_status_field(row.rollout_status, row.pid)
210+
211+
def _should_update_status_field(self, status: Optional["Status"], pid: Optional[int]) -> bool:
212+
"""Check if a specific status field should be updated to 'stopped'."""
213+
# Check if the status is running and there's a PID
214+
if status and status.is_running() and pid is not None:
198215
# Check if the process is still running
199216
try:
200-
process = psutil.Process(row.pid)
217+
process = psutil.Process(pid)
201218
# Check if process is still running
202219
if not process.is_running():
203220
return True
@@ -206,10 +223,10 @@ def _should_update_status(self, row: "EvaluationRow") -> bool:
206223
return True
207224
except psutil.AccessDenied:
208225
# Can't access process info, assume it's stopped
209-
logger.warning(f"Access denied to process {row.pid}, assuming stopped")
226+
logger.warning(f"Access denied to process {pid}, assuming stopped")
210227
return True
211228
except Exception as e:
212-
logger.error(f"Error checking process {row.pid}: {e}")
229+
logger.error(f"Error checking process {pid}: {e}")
213230
# On error, assume process is still running to be safe
214231
return False
215232

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"mcpServers": {
3+
"docs.fireworks.ai": {
4+
"url": "https://docs.fireworks.ai/mcp-non-existent"
5+
}
6+
}
7+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Set
2+
from eval_protocol.models import EvaluationRow, Message
3+
from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor
4+
from eval_protocol.dataset_logger import DatasetLogger
5+
6+
7+
class TrackingLogger(DatasetLogger):
    """In-memory logger that captures every logged row, keyed by rollout id.

    The test below keeps a reference to the same mapping and inspects the
    final state of each rollout after the evaluation completes.
    """

    def __init__(self, rollouts: dict[str, EvaluationRow]):
        # Shared mapping owned by the caller; we mutate it in place so the
        # test can observe rows without going through read().
        self.rollouts = rollouts

    def log(self, row: EvaluationRow):
        # Later logs for the same rollout overwrite earlier ones, so the
        # mapping always holds the most recent state of each rollout.
        key = row.execution_metadata.rollout_id
        self.rollouts[key] = row

    def read(self):
        # Write-only from the test's perspective: nothing is ever read back.
        return []
19+
20+
async def test_pytest_propagate_error():
21+
"""
22+
Properly propagate errors from rollout processing to eval_metadata.status.
23+
To test this, we use a broken MCP configuration that should fail during the
24+
rollout processing. Then the final eval_metadata.status should be an error.
25+
This way the UI can properly render an error state for the rollout and a
26+
developer can identify and investigate the error.
27+
"""
28+
from eval_protocol.pytest.evaluation_test import evaluation_test
29+
30+
input_messages = [
31+
[
32+
Message(
33+
role="system",
34+
content="You are a helpful assistant that can answer questions about Fireworks.",
35+
),
36+
]
37+
]
38+
completion_params_list = [
39+
{"model": "dummy/local-model"},
40+
]
41+
42+
rollouts: dict[str, EvaluationRow] = {}
43+
logger = TrackingLogger(rollouts)
44+
45+
@evaluation_test(
46+
input_messages=input_messages,
47+
completion_params=completion_params_list,
48+
rollout_processor=AgentRolloutProcessor(),
49+
mode="pointwise",
50+
num_runs=5,
51+
mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config_broken.json",
52+
logger=logger,
53+
)
54+
def eval_fn(row: EvaluationRow) -> EvaluationRow:
55+
return row
56+
57+
# Manually invoke all parameter combinations within a single test
58+
for params in completion_params_list:
59+
await eval_fn(input_messages=input_messages, completion_params=params)
60+
61+
# assert that the status of eval_metadata.status is "error"
62+
assert len(rollouts) == 5
63+
assert all(row.eval_metadata.status.is_error() for row in rollouts.values())
64+
65+
# make sure the error message includes details of the error
66+
assert all("HTTPStatusError" in row.rollout_status.message for row in rollouts.values())
67+
assert all("405 Method Not Allowed" in row.rollout_status.message for row in rollouts.values())
68+
assert all("https://docs.fireworks.ai/mcp-non-existent" in row.rollout_status.message for row in rollouts.values())

0 commit comments

Comments
 (0)