diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 2cf8ce0c..5ebd0dfa 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -79,6 +79,7 @@ def evaluation_test( aggregation_method: AggregationMethod = "mean", passed_threshold: EvaluationThreshold | float | EvaluationThresholdDict | None = None, num_runs: int = 1, + filtered_row_ids: Sequence[str] | None = None, max_dataset_rows: int | None = None, mcp_config_path: str | None = None, max_concurrent_rollouts: int = 8, @@ -146,6 +147,7 @@ def evaluation_test( Success rate must be above success, and if set, standard error must be below standard_error. Success rate +/- one standard_error is equivalent to 68% confidence interval. num_runs: Number of times to repeat the rollout and evaluations. + filtered_row_ids: List of row_ids to filter for the evaluation. If provided, only the rows with the given row_ids will be evaluated. max_dataset_rows: Limit dataset to the first N rows. mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel. @@ -286,6 +288,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo else: raise ValueError("No input dataset, input messages, or input rows provided") + if filtered_row_ids is not None: + data = [row for row in data if row.input_metadata.row_id in filtered_row_ids] + """ data_loaders handles preprocess_fn internally so we want to specially handle data_loaders here so we don't double diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index d6e563a4..7201e688 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -155,7 +155,13 @@ def _get_status() -> Dict[str, Any]: except Exception: # transient errors; continue polling pass + await asyncio.sleep(poll_interval) + else: + # Loop completed without breaking, which means we timed out + row.rollout_status = Status.rollout_error( + f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds" + ) # Update duration, regardless of termination row.execution_metadata.duration_seconds = time.perf_counter() - start_time diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 21a821e0..817d1c3f 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field -from eval_protocol.models import Message +from eval_protocol.models import Message, Status class RolloutMetadata(BaseModel): @@ -40,6 +40,12 @@ class StatusResponse(BaseModel): terminated: bool info: Optional[Dict[str, Any]] = None + status: Optional[Status] = None + """ + Optional status indicator for the rollout to be used by eval-protocol. This + is useful to distinguish between successful and failed rollouts. + """ + def create_langfuse_config_tags(init_request: InitRequest) -> List[str]: """Create Langfuse tags from InitRequest metadata."""