Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions eval_protocol/mcp/mcp_multi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ async def connect_to_servers(self):
return

for server_name, server_config in self.config.mcpServers.items():
try:
await self._connect_to_server(server_name, server_config)
except Exception as e:
print(f"Failed to connect to server '{server_name}': {e}")
await self._connect_to_server(server_name, server_config)

async def _connect_to_server(
self, server_name: str, server_config: Union[MCPConfigurationServerStdio, MCPConfigurationServerUrl]
Expand Down
37 changes: 28 additions & 9 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,29 @@ class Code(int, Enum):
DATA_LOSS = 15
UNAUTHENTICATED = 16

# Custom codes for rollout states (using higher numbers to avoid conflicts)
# Custom codes for EP (using higher numbers to avoid conflicts)
FINISHED = 100
RUNNING = 101

@classmethod
def rollout_running(cls) -> "Status":
"""Create a status indicating the rollout is running."""
return cls(code=cls.Code.OK, message="Rollout is running", details=[])
return cls(code=cls.Code.RUNNING, message="Rollout is running", details=[])

@classmethod
def eval_running(cls) -> "Status":
    """Build a Status marking an evaluation that is still in progress."""
    running_status = cls(code=cls.Code.RUNNING, message="Evaluation is running", details=[])
    return running_status

@classmethod
def eval_finished(cls) -> "Status":
    """Build a Status marking an evaluation that has completed."""
    return cls(
        code=cls.Code.FINISHED,
        message="Evaluation finished",
        details=[],
    )

@classmethod
def aborted(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
    """Build an ABORTED Status carrying the given message and optional detail payloads."""
    if details is None:
        details = []
    return cls(code=cls.Code.ABORTED, message=message, details=details)

@classmethod
def rollout_finished(
Expand All @@ -127,7 +143,12 @@ def rollout_finished(
details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format())
if extra_info:
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details)
return cls.finished("Rollout finished", details)

@classmethod
def finished(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
    """Create a generic FINISHED status with the given message and optional details.

    Used by both rollout- and evaluation-level "finished" constructors; the
    caller supplies the human-readable message (e.g. "Rollout finished").
    """
    return cls(code=cls.Code.FINISHED, message=message, details=details or [])

@classmethod
def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status":
Expand All @@ -140,11 +161,11 @@ def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]]
@classmethod
def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
"""Create a status indicating the rollout failed with an error."""
return cls(code=cls.Code.INTERNAL, message=error_message, details=details)
return cls(code=cls.Code.INTERNAL, message=error_message, details=details or [])

def is_running(self) -> bool:
"""Check if the status indicates the rollout is running."""
return self.code == self.Code.OK and self.message == "Rollout is running"
return self.code == self.Code.RUNNING

def is_finished(self) -> bool:
"""Check if the status indicates the rollout finished successfully."""
Expand Down Expand Up @@ -436,9 +457,7 @@ class EvalMetadata(BaseModel):
default_factory=get_pep440_version,
description="Version of the evaluation. Should be populated with a PEP 440 version string.",
)
status: Optional[Literal["running", "finished", "error", "stopped"]] = Field(
None, description="Status of the evaluation"
)
status: Optional[Status] = Field(None, description="Status of the evaluation")
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
passed_threshold: Optional[EvaluationThreshold] = Field(
Expand Down Expand Up @@ -527,7 +546,7 @@ class EvaluationRow(BaseModel):
)

pid: Optional[int] = Field(
None,
default=None,
description="The PID of the process that created the row. This is used by the evaluation watcher to detect stopped evaluations.",
)

Expand Down
29 changes: 18 additions & 11 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from eval_protocol.human_id import generate_id, num_combinations
from eval_protocol.models import (
CompletionParams,
ErrorInfo,
EvalMetadata,
EvaluationRow,
EvaluationThreshold,
InputMetadata,
Message,
Status,
)
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
Expand All @@ -36,7 +38,6 @@
EvaluationInputParam,
EvaluationTestMode,
InputMessagesParam,
InputRowsParam,
ModelParam,
RolloutProcessorConfig,
RolloutProcessorInputParam,
Expand All @@ -58,6 +59,7 @@
)
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
from eval_protocol.types.types import TerminationReason

from ..common_utils import load_jsonl

Expand Down Expand Up @@ -240,7 +242,7 @@ def evaluation_test( # noqa: C901
completion_params: List[CompletionParams],
input_messages: Optional[List[InputMessagesParam]] = None,
input_dataset: Optional[List[DatasetPathParam]] = None,
input_rows: Optional[List[InputRowsParam]] = None,
input_rows: Optional[List[EvaluationRow]] = None,
dataset_adapter: Callable[[List[Dict[str, Any]]], Dataset] = default_dataset_adapter,
rollout_processor: RolloutProcessor = NoOpRolloutProcessor(),
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
Expand Down Expand Up @@ -425,7 +427,7 @@ async def execute_with_params(
if mode == "groupwise":
combinations = generate_parameter_combinations(
input_dataset,
None,
completion_params,
input_messages,
input_rows,
evaluation_test_kwargs,
Expand Down Expand Up @@ -489,9 +491,7 @@ async def wrapper_body(**kwargs):

experiment_id = generate_id()

def _log_eval_error(
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
) -> None:
def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None, passed: bool) -> None:
log_eval_status_and_rows(eval_metadata, rows, status, passed, active_logger)

try:
Expand Down Expand Up @@ -562,7 +562,7 @@ def _log_eval_error(
eval_metadata = EvalMetadata(
name=test_func.__name__,
description=test_func.__doc__,
status="running",
status=Status.eval_running(),
num_runs=num_runs,
aggregation_method=aggregation_method,
passed_threshold=threshold,
Expand Down Expand Up @@ -732,7 +732,12 @@ async def _collect_result(config, lst):

for r in results:
if r.eval_metadata is not None:
r.eval_metadata.status = "finished"
if r.rollout_status.is_error():
r.eval_metadata.status = Status.error(
r.rollout_status.message, r.rollout_status.details
)
else:
r.eval_metadata.status = Status.eval_finished()
active_logger.log(r)

# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
Expand Down Expand Up @@ -770,14 +775,16 @@ async def _collect_result(config, lst):

except AssertionError:
_log_eval_error(
"finished",
Status.eval_finished(),
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
passed=False,
)
raise
except Exception:
except Exception as e:
_log_eval_error(
"error", processed_rows_in_run if "processed_rows_in_run" in locals() else None, passed=False
Status.error(str(e)),
processed_rows_in_run if "processed_rows_in_run" in locals() else None,
passed=False,
)
raise

Expand Down
3 changes: 1 addition & 2 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
DatasetPathParam = str
InputMessagesParam = List[Message]
InputRowsParam = List[EvaluationRow]
EvaluationInputParam = Dict[str, Any]
RolloutProcessorInputParam = Dict[str, Any]

Expand All @@ -31,7 +30,7 @@
"""
Test function types
"""
TestFunction = Callable[..., Dataset]
TestFunction = Callable

"""
Rollout processor types
Expand Down
17 changes: 12 additions & 5 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
DatasetPathParam,
EvaluationInputParam,
InputMessagesParam,
InputRowsParam,
RolloutProcessorConfig,
)
from eval_protocol.pytest.exception_config import ExceptionHandlerConfig, get_default_exception_handler_config
Expand Down Expand Up @@ -115,7 +114,7 @@ async def wrapper(**kwargs):
def log_eval_status_and_rows(
eval_metadata: Optional[EvalMetadata],
rows: Optional[List[EvaluationRow]] | None,
status: Literal["finished", "error"],
status: Status,
passed: bool,
logger: DatasetLogger,
) -> None:
Expand Down Expand Up @@ -185,7 +184,7 @@ def generate_parameter_combinations(
input_dataset: Optional[List[DatasetPathParam]],
completion_params: List[CompletionParams],
input_messages: Optional[List[InputMessagesParam]],
input_rows: Optional[List[InputRowsParam]],
input_rows: Optional[List[EvaluationRow]],
evaluation_test_kwargs: Optional[List[EvaluationInputParam]],
max_dataset_rows: Optional[int],
combine_datasets: bool,
Expand Down Expand Up @@ -341,12 +340,20 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
else:
# Non-retryable exception - fail immediately
logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}")
row.rollout_status = Status.rollout_error(str(e))
row.rollout_status = Status.rollout_error(repr(e))
return row

async def execute_row_with_backoff_and_log(task: asyncio.Task, row: EvaluationRow) -> EvaluationRow:
    """Run the row task with backoff retry, then log the row's final state."""
    finished_row = await execute_row_with_backoff(task, row)
    # Persist the outcome whether the rollout succeeded or ended in error.
    config.logger.log(finished_row)
    return finished_row

# Process all tasks concurrently with backoff retry
retry_tasks = [
asyncio.create_task(execute_row_with_backoff(task, fresh_dataset[i])) for i, task in enumerate(base_tasks)
asyncio.create_task(execute_row_with_backoff_and_log(task, fresh_dataset[i]))
for i, task in enumerate(base_tasks)
]

# Yield results as they complete
Expand Down
31 changes: 24 additions & 7 deletions eval_protocol/utils/logs_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from eval_protocol.dataset_logger import default_logger
from eval_protocol.dataset_logger.dataset_logger import LOG_EVENT_TYPE
from eval_protocol.event_bus import event_bus
from eval_protocol.models import Status
from eval_protocol.utils.vite_server import ViteServer

if TYPE_CHECKING:
Expand Down Expand Up @@ -178,8 +179,17 @@ def _check_running_evaluations(self):
for row in logs:
if self._should_update_status(row):
logger.info(f"Updating status to 'stopped' for row {row.input_metadata.row_id} (PID {row.pid})")
if row.eval_metadata is not None:
row.eval_metadata.status = "stopped"

# Update eval_metadata.status if it's running
if row.eval_metadata and row.eval_metadata.status and row.eval_metadata.status.is_running():
row.eval_metadata.status = Status.aborted(
f"Evaluation aborted since process {row.pid} stopped"
)

# Update rollout_status if it's running
if row.rollout_status and row.rollout_status.is_running():
row.rollout_status = Status.aborted(f"Rollout aborted since process {row.pid} stopped")

updated_rows.append(row)

# Log all updated rows
Expand All @@ -193,11 +203,18 @@ def _check_running_evaluations(self):

def _should_update_status(self, row: "EvaluationRow") -> bool:
"""Check if a row's status should be updated to 'stopped'."""
# Check if the row has running status and a PID
if row.eval_metadata and row.eval_metadata.status == "running" and row.pid is not None:
# Check if any status field should be updated
return self._should_update_status_field(
row.eval_metadata.status if row.eval_metadata else None, row.pid
) or self._should_update_status_field(row.rollout_status, row.pid)

def _should_update_status_field(self, status: Optional["Status"], pid: Optional[int]) -> bool:
"""Check if a specific status field should be updated to 'stopped'."""
# Check if the status is running and there's a PID
if status and status.is_running() and pid is not None:
# Check if the process is still running
try:
process = psutil.Process(row.pid)
process = psutil.Process(pid)
# Check if process is still running
if not process.is_running():
return True
Expand All @@ -206,10 +223,10 @@ def _should_update_status(self, row: "EvaluationRow") -> bool:
return True
except psutil.AccessDenied:
# Can't access process info, assume it's stopped
logger.warning(f"Access denied to process {row.pid}, assuming stopped")
logger.warning(f"Access denied to process {pid}, assuming stopped")
return True
except Exception as e:
logger.error(f"Error checking process {row.pid}: {e}")
logger.error(f"Error checking process {pid}: {e}")
# On error, assume process is still running to be safe
return False

Expand Down
7 changes: 7 additions & 0 deletions tests/pytest/mcp_configurations/docs_mcp_config_broken.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"mcpServers": {
"docs.fireworks.ai": {
"url": "https://docs.fireworks.ai/mcp-non-existent"
}
}
}
68 changes: 68 additions & 0 deletions tests/pytest/test_pytest_propagate_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Set
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor
from eval_protocol.dataset_logger import DatasetLogger


class TrackingLogger(DatasetLogger):
    """In-memory logger that records the latest logged state of each rollout."""

    def __init__(self, rollouts: dict[str, EvaluationRow]):
        # Shared mapping the owning test inspects after the evaluation completes.
        self.rollouts = rollouts

    def log(self, row: EvaluationRow):
        rollout_id = row.execution_metadata.rollout_id
        self.rollouts[rollout_id] = row

    def read(self):
        # Nothing is persisted, so there is never anything to read back.
        return []


async def test_pytest_propagate_error():
    """
    Rollout-processing failures must surface in eval_metadata.status.

    A deliberately broken MCP configuration makes every rollout fail during
    rollout processing, so each logged row should end up in an error state
    whose message carries the concrete failure details (exception type, HTTP
    status, failing URL) so a developer can investigate it from the UI.
    """
    from eval_protocol.pytest.evaluation_test import evaluation_test

    system_only_messages = [
        [
            Message(
                role="system",
                content="You are a helpful assistant that can answer questions about Fireworks.",
            ),
        ]
    ]
    params_list = [
        {"model": "dummy/local-model"},
    ]

    tracked_rows: dict[str, EvaluationRow] = {}
    tracking_logger = TrackingLogger(tracked_rows)

    @evaluation_test(
        input_messages=system_only_messages,
        completion_params=params_list,
        rollout_processor=AgentRolloutProcessor(),
        mode="pointwise",
        num_runs=5,
        mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config_broken.json",
        logger=tracking_logger,
    )
    def identity_eval(row: EvaluationRow) -> EvaluationRow:
        return row

    # Drive every completion-params combination by hand inside this one test.
    for completion_params in params_list:
        await identity_eval(input_messages=system_only_messages, completion_params=completion_params)

    # One row per run, and every one of them must be in an error state.
    assert len(tracked_rows) == 5
    for logged_row in tracked_rows.values():
        assert logged_row.eval_metadata.status.is_error()
        # The rollout status message must expose the concrete failure details.
        assert "HTTPStatusError" in logged_row.rollout_status.message
        assert "405 Method Not Allowed" in logged_row.rollout_status.message
        assert "https://docs.fireworks.ai/mcp-non-existent" in logged_row.rollout_status.message
Loading
Loading