|
1 | | -from typing import Set |
| 1 | +from typing_extensions import override |
2 | 2 | from eval_protocol.models import EvaluationRow, Message |
3 | 3 | from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor |
4 | | -from eval_protocol.dataset_logger import DatasetLogger |
| 4 | +from eval_protocol.dataset_logger.dataset_logger import DatasetLogger |
5 | 5 |
|
6 | 6 |
|
7 | 7 | class TrackingLogger(DatasetLogger): |
8 | 8 | """Custom logger that ensures that the final row is in an error state.""" |
9 | 9 |
|
10 | 10 | def __init__(self, rollouts: dict[str, EvaluationRow]): |
11 | | - self.rollouts = rollouts |
| 11 | + self.rollouts: dict[str, EvaluationRow] = rollouts |
12 | 12 |
|
| 13 | + @override |
13 | 14 | def log(self, row: EvaluationRow): |
| 15 | + if row.execution_metadata.rollout_id is None: |
| 16 | + raise ValueError("Rollout ID is None") |
14 | 17 | self.rollouts[row.execution_metadata.rollout_id] = row |
15 | 18 |
|
16 | | - def read(self): |
| 19 | + @override |
| 20 | + def read(self, row_id: str | None = None) -> list[EvaluationRow]: |
17 | 21 | return [] |
18 | 22 |
|
19 | 23 |
|
@@ -56,11 +60,16 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: |
56 | 60 |
|
57 | 61 | # Manually invoke all parameter combinations within a single test |
58 | 62 | for params in completion_params_list: |
59 | | - await eval_fn(input_messages=input_messages, completion_params=params) |
| 63 | + await eval_fn(input_messages=input_messages[0], completion_params=params) # pyright: ignore[reportCallIssue] |
60 | 64 |
|
61 | 65 | # assert that the status of eval_metadata.status is "error" |
62 | 66 | assert len(rollouts) == 5 |
63 | | - assert all(row.eval_metadata.status.is_error() for row in rollouts.values()) |
| 67 | + for row in rollouts.values(): |
| 68 | + if row.eval_metadata is None: |
| 69 | + raise ValueError("Row has no eval_metadata") |
| 70 | + if row.eval_metadata.status is None: |
| 71 | + raise ValueError("Eval metadata has no status") |
| 72 | + assert row.eval_metadata.status.is_error() |
64 | 73 |
|
65 | 74 | # make sure the error message includes details of the error |
66 | 75 | assert all("HTTPStatusError" in row.rollout_status.message for row in rollouts.values()) |
|
0 commit comments