Skip to content

Commit d330ee0

Browse files
author
Dylan Huang
committed
fix test_pytest_propagate_error
1 parent 67b40d7 commit d330ee0

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

tests/pytest/test_pytest_propagate_error.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,23 @@
1-
from typing import Set
1+
from typing_extensions import override
22
from eval_protocol.models import EvaluationRow, Message
33
from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor
4-
from eval_protocol.dataset_logger import DatasetLogger
4+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
55

66

77
class TrackingLogger(DatasetLogger):
88
"""Custom logger that ensures that the final row is in an error state."""
99

1010
def __init__(self, rollouts: dict[str, EvaluationRow]):
11-
self.rollouts = rollouts
11+
self.rollouts: dict[str, EvaluationRow] = rollouts
1212

13+
@override
1314
def log(self, row: EvaluationRow):
15+
if row.execution_metadata.rollout_id is None:
16+
raise ValueError("Rollout ID is None")
1417
self.rollouts[row.execution_metadata.rollout_id] = row
1518

16-
def read(self):
19+
@override
20+
def read(self, row_id: str | None = None) -> list[EvaluationRow]:
1721
return []
1822

1923

@@ -56,11 +60,16 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
5660

5761
# Manually invoke all parameter combinations within a single test
5862
for params in completion_params_list:
59-
await eval_fn(input_messages=input_messages, completion_params=params)
63+
await eval_fn(input_messages=input_messages[0], completion_params=params) # pyright: ignore[reportCallIssue]
6064

6165
# assert that the status of eval_metadata.status is "error"
6266
assert len(rollouts) == 5
63-
assert all(row.eval_metadata.status.is_error() for row in rollouts.values())
67+
for row in rollouts.values():
68+
if row.eval_metadata is None:
69+
raise ValueError("Row has no eval_metadata")
70+
if row.eval_metadata.status is None:
71+
raise ValueError("Eval metadata has no status")
72+
assert row.eval_metadata.status.is_error()
6473

6574
# make sure the error message includes details of the error
6675
assert all("HTTPStatusError" in row.rollout_status.message for row in rollouts.values())

0 commit comments

Comments
 (0)