Skip to content

Commit d20f83e

Browse files
author
Dylan Huang
committed
fix
1 parent 099bd48 commit d20f83e

File tree

5 files changed

+23
-16
lines changed

5 files changed

+23
-16
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def evaluation_test(
6363
*,
6464
completion_params: Sequence[CompletionParams | None] | None = None,
6565
input_messages: Sequence[InputMessagesParam | None] | None = None,
66-
input_dataset: list[DatasetPathParam] | None = None,
66+
input_dataset: Sequence[DatasetPathParam] | None = None,
6767
input_rows: Sequence[list[EvaluationRow]] | None = None,
6868
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
6969
rollout_processor: RolloutProcessor | None = None,

eval_protocol/pytest/generate_parameter_combinations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ParameterizedTestKwargs(TypedDict):
4545

4646

4747
def generate_parameter_combinations(
48-
input_dataset: list[DatasetPathParam] | None,
48+
input_dataset: Sequence[DatasetPathParam] | None,
4949
completion_params: Sequence[CompletionParams | None],
5050
input_messages: Sequence[InputMessagesParam | None] | None,
5151
input_rows: Sequence[list[EvaluationRow] | None] | None,
@@ -73,7 +73,7 @@ def generate_parameter_combinations(
7373
datasets: Sequence[list[DatasetPathParam] | None] = [None]
7474
if input_dataset is not None:
7575
if combine_datasets:
76-
datasets = [input_dataset]
76+
datasets = [list(input_dataset)]
7777
else:
7878
# Fan out: one dataset path per parameterization
7979
datasets = [[p] for p in input_dataset]

eval_protocol/pytest/parameterize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class PytestParametrizeArgs(TypedDict):
1616

1717
def pytest_parametrize(
1818
combinations: list[CombinationTuple],
19-
input_dataset: list[DatasetPathParam] | None,
19+
input_dataset: Sequence[DatasetPathParam] | None,
2020
completion_params: Sequence[CompletionParams | None] | None,
2121
input_messages: Sequence[InputMessagesParam | None] | None,
2222
input_rows: Sequence[list[EvaluationRow]] | None,

tests/pytest/test_pydantic_agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import pytest
32

43
from eval_protocol.models import EvaluationRow, Message
@@ -12,7 +11,7 @@
1211

1312
@pytest.mark.asyncio
1413
@evaluation_test(
15-
input_messages=[Message(role="user", content="Hello, how are you?")],
14+
input_messages=[[Message(role="user", content="Hello, how are you?")]],
1615
completion_params=[
1716
{"model": "accounts/fireworks/models/gpt-oss-120b", "provider": "fireworks"},
1817
],

tests/pytest/test_pytest_assertion_error_no_new_rollouts.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import List, Set
21
import asyncio
2+
from typing import Any
3+
from typing_extensions import override
34

45
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
56
from eval_protocol.models import EvaluationRow
@@ -11,14 +12,17 @@
1112
class TrackingRolloutProcessor(RolloutProcessor):
1213
"""Custom rollout processor that tracks which rollout IDs are generated during rollout phase."""
1314

14-
def __init__(self, shared_rollout_ids: Set[str]):
15-
self.shared_rollout_ids = shared_rollout_ids
15+
def __init__(self, shared_rollout_ids: set[str]):
16+
self.shared_rollout_ids: set[str] = shared_rollout_ids
1617

17-
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
18+
@override
19+
def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) -> list[asyncio.Task[EvaluationRow]]:
1820
"""Process rows and track rollout IDs generated during rollout phase."""
1921

2022
async def process_row(row: EvaluationRow) -> EvaluationRow:
2123
# Track this rollout ID as being generated during rollout phase
24+
if row.execution_metadata.rollout_id is None:
25+
raise ValueError("Rollout ID is None")
2226
self.shared_rollout_ids.add(row.execution_metadata.rollout_id)
2327
return row
2428

@@ -30,13 +34,17 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
3034
class TrackingLogger(DatasetLogger):
3135
"""Custom logger that tracks all rollout IDs that are logged."""
3236

33-
def __init__(self, shared_rollout_ids: Set[str]):
34-
self.shared_rollout_ids = shared_rollout_ids
37+
def __init__(self, shared_rollout_ids: set[str]):
38+
self.shared_rollout_ids: set[str] = shared_rollout_ids
3539

40+
@override
3641
def log(self, row: EvaluationRow):
42+
if row.execution_metadata.rollout_id is None:
43+
raise ValueError("Rollout ID is None")
3744
self.shared_rollout_ids.add(row.execution_metadata.rollout_id)
3845

39-
def read(self):
46+
@override
47+
def read(self, row_id: str | None = None) -> list[EvaluationRow]:
4048
return []
4149

4250

@@ -48,7 +56,7 @@ async def test_assertion_error_no_new_rollouts():
4856
from eval_protocol.pytest.evaluation_test import evaluation_test
4957

5058
# Create shared set to track rollout IDs generated during rollout phase
51-
shared_rollout_ids: Set[str] = set()
59+
shared_rollout_ids: set[str] = set()
5260

5361
# Create custom processor and logger for tracking with shared set
5462
rollout_processor = TrackingRolloutProcessor(shared_rollout_ids)
@@ -57,7 +65,7 @@ async def test_assertion_error_no_new_rollouts():
5765
input_dataset: list[str] = [
5866
"tests/pytest/data/markdown_dataset.jsonl",
5967
]
60-
completion_params: list[dict] = [{"temperature": 0.0, "model": "dummy/local-model"}]
68+
completion_params: list[dict[str, Any]] = [{"temperature": 0.0, "model": "dummy/local-model"}] # pyright: ignore[reportExplicitAny]
6169

6270
@evaluation_test(
6371
input_dataset=input_dataset,
@@ -81,7 +89,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
8189
# This should fail due to threshold not being met
8290
for ds_path in input_dataset:
8391
for completion_param in completion_params:
84-
await eval_fn(dataset_path=ds_path, completion_params=completion_param)
92+
await eval_fn(dataset_path=[ds_path], completion_params=completion_param) # pyright: ignore[reportCallIssue]
8593
except AssertionError:
8694
# Expected - the threshold check should fail
8795
pass

0 commit comments

Comments
 (0)