Skip to content

Commit afb347d

Browse files
committed
support overwrite
1 parent b42b3e1 commit afb347d

File tree

4 files changed

+85
-0
lines changed

4 files changed

+85
-0
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@
5252
add_cost_metrics,
5353
log_eval_status_and_rows,
5454
parse_ep_completion_params,
55+
parse_ep_completion_params_overwrite,
5556
parse_ep_max_concurrent_rollouts,
5657
parse_ep_max_rows,
5758
parse_ep_num_runs,
5859
parse_ep_passed_threshold,
60+
parse_ep_dataloaders,
5961
rollout_processor_with_retry,
6062
run_tasks_with_eval_progress,
6163
run_tasks_with_run_progress,
@@ -187,10 +189,18 @@ def evaluation_test(
187189
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
188190
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
189191
completion_params = parse_ep_completion_params(completion_params)
192+
completion_params = parse_ep_completion_params_overwrite(completion_params)
190193
original_completion_params = completion_params
191194
passed_threshold = parse_ep_passed_threshold(passed_threshold)
195+
data_loaders = parse_ep_dataloaders(data_loaders)
192196
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)
193197

198+
# The data loader may have been overridden above; to avoid conflicts, explicitly unset the other data input params.
199+
if data_loaders:
200+
input_dataset = None
201+
input_messages = None
202+
input_rows = None
203+
194204
def decorator(
195205
test_func: TestFunction,
196206
) -> TestFunction:

eval_protocol/pytest/plugin.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pathlib
2020
import sys
2121
from pytest import StashKey
22+
import pytest
2223

2324

2425
def pytest_addoption(parser) -> None:
@@ -56,6 +57,7 @@ def pytest_addoption(parser) -> None:
5657
default=None,
5758
help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."),
5859
)
60+
# deprecate this later
5961
group.addoption(
6062
"--ep-input-param",
6163
action="append",
@@ -115,6 +117,27 @@ def pytest_addoption(parser) -> None:
115117
"Default: false (experiment JSONs are saved and uploaded by default)."
116118
),
117119
)
120+
group.addoption(
121+
"--ep-jsonl-path",
122+
default=None,
123+
help=("Load input from a jsonl file that is already in EvaluationRow or openai CHAT format")
124+
)
125+
group.addoption(
126+
"--ep-completion-params",
127+
default=[],
128+
action="append",
129+
help=(
130+
"Overwrite completion params with json. Can be used multiple times. "
131+
),
132+
)
133+
group.addoption(
134+
"--ep-remote-rollout-processor-base-url",
135+
default=None,
136+
help=(
137+
"If set, use this base URL for remote rollout processing. "
138+
"Example: http://localhost:8000"
139+
),
140+
)
118141

119142

120143
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
@@ -243,6 +266,18 @@ def pytest_configure(config) -> None:
243266
if config.getoption("--ep-no-upload"):
244267
os.environ["EP_NO_UPLOAD"] = "1"
245268

269+
if config.getoption("--ep-jsonl-path"):
270+
os.environ["EP_JSONL_PATH"] = config.getoption("--ep-jsonl-path")
271+
272+
if config.getoption("--ep-completion-params"):
273+
# Round-trip each value through JSON so that malformed input fails here, at configure time.
274+
os.environ["EP_COMPLETION_PARAMS"] = json.dumps([
275+
json.loads(s) for s in config.getoption("--ep-completion-params") or []
276+
])
277+
278+
if config.getoption("--ep-remote-rollout-processor-base-url"):
279+
os.environ["EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"] = config.getoption("--ep-remote-rollout-processor-base-url")
280+
246281
# Allow ad-hoc overrides of input params via CLI flags
247282
try:
248283
merged: dict = {}

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
99
from .rollout_processor import RolloutProcessor
1010
from .types import RolloutProcessorConfig
11+
import os
1112

1213

1314
class RemoteRolloutProcessor(RolloutProcessor):
@@ -46,6 +47,8 @@ def __init__(
4647
# Prefer constructor-provided configuration. These can be overridden via
4748
# config.kwargs at call time for backward compatibility.
4849
self._remote_base_url = remote_base_url
50+
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
51+
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
4952
self._poll_interval = poll_interval
5053
self._timeout_seconds = timeout_seconds
5154
self._output_data_loader = output_data_loader

eval_protocol/pytest/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
EvaluationThresholdDict,
2020
Status,
2121
)
22+
from eval_protocol.data_loader import DynamicDataLoader
23+
from eval_protocol.data_loader.models import EvaluationDataLoader
2224
from eval_protocol.pytest.rollout_processor import RolloutProcessor
2325
from eval_protocol.pytest.types import (
2426
RolloutProcessorConfig,
@@ -238,6 +240,41 @@ def parse_ep_completion_params(
238240
pass
239241
return completion_params
240242

243+
def parse_ep_completion_params_overwrite(completion_params: Sequence[CompletionParams | None] | None) -> Sequence[CompletionParams | None]:
244+
new_completion_params = os.getenv("EP_COMPLETION_PARAMS")
245+
if new_completion_params:
246+
try:
247+
new_completion_params_list = json.loads(new_completion_params)
248+
if isinstance(new_completion_params_list, list):
249+
return new_completion_params_list
250+
except Exception:
251+
pass
252+
return completion_params or []
253+
254+
def _rows_from_jsonl(path: str) -> list[EvaluationRow]:
    """Load evaluation rows from a JSONL file, one JSON object per line.

    Each non-blank line is decoded as JSON and its keys are passed as keyword
    arguments to ``EvaluationRow``. Blank or whitespace-only lines are
    skipped, so a stray empty line does not invalidate the whole file.

    Loading is best-effort: if the file cannot be read, or any line fails to
    decode or to construct a row, the error is printed and an empty list is
    returned instead of raising.

    Args:
        path: Filesystem path of the JSONL file (read as UTF-8).

    Returns:
        The rows in file order, or ``[]`` on any failure.
    """
    rows = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                # A blank line is not a record; json.loads would raise on it
                # and discard every row already read.
                if not line.strip():
                    continue
                rows.append(EvaluationRow(**json.loads(line)))
    except Exception as e:
        # Deliberately broad: row construction may raise project-specific
        # validation errors, and the caller expects a list either way.
        print(f"❌ Failed to load rows from JSONL at {path}: {e}")
        return []

    return rows
265+
266+
def parse_ep_dataloaders(
267+
dataloaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None,
268+
) -> Sequence[EvaluationDataLoader] | EvaluationDataLoader | None:
269+
try:
270+
load_from_jsonl_path = os.getenv("EP_JSONL_PATH")
271+
if load_from_jsonl_path:
272+
return DynamicDataLoader(
273+
generators=[lambda path=load_from_jsonl_path: _rows_from_jsonl(path)])
274+
except Exception:
275+
pass
276+
return dataloaders or None
277+
241278

242279
def parse_ep_passed_threshold(
243280
default_value: float | EvaluationThresholdDict | EvaluationThreshold | None,

0 commit comments

Comments
 (0)