Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def evaluation_test(
aggregation_method: AggregationMethod = "mean",
passed_threshold: EvaluationThreshold | float | EvaluationThresholdDict | None = None,
num_runs: int = 1,
filtered_row_ids: Sequence[str] | None = None,
max_dataset_rows: int | None = None,
mcp_config_path: str | None = None,
max_concurrent_rollouts: int = 8,
Expand Down Expand Up @@ -146,6 +147,7 @@ def evaluation_test(
Success rate must be above success, and if set, standard error must be below standard_error.
Success rate +/- one standard_error is equivalent to 68% confidence interval.
num_runs: Number of times to repeat the rollout and evaluations.
filtered_row_ids: List of row_ids to filter for the evaluation. If provided, only the rows with the given row_ids will be evaluated.
max_dataset_rows: Limit dataset to the first N rows.
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
Expand Down Expand Up @@ -286,6 +288,9 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
else:
raise ValueError("No input dataset, input messages, or input rows provided")

if filtered_row_ids is not None:
data = [row for row in data if row.input_metadata.row_id in filtered_row_ids]

"""
data_loaders handles preprocess_fn internally so we want
to specially handle data_loaders here so we don't double
Expand Down
6 changes: 6 additions & 0 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ def _get_status() -> Dict[str, Any]:
except Exception:
# transient errors; continue polling
pass

await asyncio.sleep(poll_interval)
else:
# Loop completed without breaking, which means we timed out
row.rollout_status = Status.rollout_error(
f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds"
)

# Update duration, regardless of termination
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
Expand Down
8 changes: 7 additions & 1 deletion eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from eval_protocol.models import Message
from eval_protocol.models import Message, Status


class RolloutMetadata(BaseModel):
Expand Down Expand Up @@ -40,6 +40,12 @@ class StatusResponse(BaseModel):
terminated: bool
info: Optional[Dict[str, Any]] = None

status: Optional[Status] = None
"""
Optional status indicator for the rollout to be used by eval-protocol. This
is useful to distinguish between successful and failed rollouts.
"""


def create_langfuse_config_tags(init_request: InitRequest) -> List[str]:
"""Create Langfuse tags from InitRequest metadata."""
Expand Down
Loading