Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@
add_cost_metrics,
log_eval_status_and_rows,
parse_ep_completion_params,
parse_ep_completion_params_overwrite,
parse_ep_max_concurrent_rollouts,
parse_ep_max_rows,
parse_ep_num_runs,
parse_ep_passed_threshold,
parse_ep_dataloaders,
rollout_processor_with_retry,
run_tasks_with_eval_progress,
run_tasks_with_run_progress,
Expand Down Expand Up @@ -189,10 +191,18 @@ def evaluation_test(
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
completion_params = parse_ep_completion_params_overwrite(completion_params)
original_completion_params = completion_params
passed_threshold = parse_ep_passed_threshold(passed_threshold)
data_loaders = parse_ep_dataloaders(data_loaders)
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)

# ignore other data input params when dataloader is provided
if data_loaders:
input_dataset = None
input_messages = None
input_rows = None

def decorator(
test_func: TestFunction,
) -> TestFunction:
Expand Down
27 changes: 27 additions & 0 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pathlib
import sys
from pytest import StashKey
import pytest


def pytest_addoption(parser) -> None:
Expand Down Expand Up @@ -56,6 +57,7 @@ def pytest_addoption(parser) -> None:
default=None,
help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."),
)
# deprecate this later
group.addoption(
"--ep-input-param",
action="append",
Expand Down Expand Up @@ -115,6 +117,22 @@ def pytest_addoption(parser) -> None:
"Default: false (experiment JSONs are saved and uploaded by default)."
),
)
group.addoption(
"--ep-jsonl-path",
default=None,
help=("Load input from a jsonl file that is already in EvaluationRow or openai CHAT format"),
)
group.addoption(
"--ep-completion-params",
default=[],
action="append",
help=("Overwrite completion params with json. Can be used multiple times. "),
)
group.addoption(
"--ep-remote-rollout-processor-base-url",
default=None,
help=("If set, use this base URL for remote rollout processing. Example: http://localhost:8000"),
)


def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -243,6 +261,15 @@ def pytest_configure(config) -> None:
if config.getoption("--ep-no-upload"):
os.environ["EP_NO_UPLOAD"] = "1"

if config.getoption("--ep-jsonl-path"):
os.environ["EP_JSONL_PATH"] = config.getoption("--ep-jsonl-path")

if config.getoption("--ep-completion-params"):
# redump to json to make sure they are legit
os.environ["EP_COMPLETION_PARAMS"] = json.dumps(
[json.loads(s) for s in config.getoption("--ep-completion-params") or []]
)

# Allow ad-hoc overrides of input params via CLI flags
try:
merged: dict = {}
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
import os


class RemoteRolloutProcessor(RolloutProcessor):
Expand All @@ -30,7 +31,8 @@
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
self._remote_base_url = remote_base_url
self._model_base_url = model_base_url
if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"):
self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL")
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader
Expand All @@ -40,7 +42,7 @@

# Start with constructor values
remote_base_url: Optional[str] = self._remote_base_url
model_base_url: Optional[str] = self._model_base_url

Check failure on line 45 in eval_protocol/pytest/remote_rollout_processor.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Cannot access attribute "_model_base_url" for class "RemoteRolloutProcessor*"   Attribute "_model_base_url" is unknown (reportAttributeAccessIssue)
poll_interval: float = self._poll_interval
timeout_seconds: float = self._timeout_seconds

Expand Down
41 changes: 41 additions & 0 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
EvaluationThresholdDict,
Status,
)
from eval_protocol.data_loader import DynamicDataLoader
from eval_protocol.data_loader.models import EvaluationDataLoader
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import (
RolloutProcessorConfig,
Expand Down Expand Up @@ -239,6 +241,45 @@ def parse_ep_completion_params(
return completion_params


def parse_ep_completion_params_overwrite(
    completion_params: Sequence[CompletionParams | None] | None,
) -> Sequence[CompletionParams | None]:
    """Allow the EP_COMPLETION_PARAMS env var to override completion params.

    The env var (written by the ``--ep-completion-params`` CLI flag) is
    expected to hold a JSON array of completion-param dicts. When present and
    valid, that list replaces the decorator-supplied params; otherwise the
    original params are returned (``[]`` if they were ``None``).

    Args:
        completion_params: Params declared on the ``@evaluation_test`` decorator.

    Returns:
        The override list from the environment, or the original params.
    """
    raw_override = os.getenv("EP_COMPLETION_PARAMS")
    if raw_override:
        try:
            parsed = json.loads(raw_override)
        except json.JSONDecodeError as e:
            # Fall back to the decorator params, but make the bad input
            # visible instead of silently ignoring it.
            print(f"❌ Ignoring invalid EP_COMPLETION_PARAMS (not valid JSON): {e}")
        else:
            if isinstance(parsed, list):
                return parsed
            # A non-list payload (e.g. a single dict) was previously dropped
            # without any signal; report it so the CLI mistake is diagnosable.
            print(f"❌ Ignoring EP_COMPLETION_PARAMS: expected a JSON list, got {type(parsed).__name__}")
    return completion_params or []


def _rows_from_jsonl(path: str) -> list[EvaluationRow]:
rows = []
try:
with open(path, "r", encoding="utf-8") as f:
for line in f:
rows.append(EvaluationRow(**json.loads(line)))
except Exception as e:
print(f"❌ Failed to load rows from JSONL at {path}: {e}")
return []

return rows


def parse_ep_dataloaders(
    dataloaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None,
) -> Sequence[EvaluationDataLoader] | EvaluationDataLoader | None:
    """Allow the EP_JSONL_PATH env var to replace the configured data loaders.

    When EP_JSONL_PATH is set (via ``--ep-jsonl-path``), return a
    ``DynamicDataLoader`` that reads rows from that JSONL file, overriding
    whatever loaders the decorator declared. Otherwise the given loaders are
    returned unchanged (``None`` if they were empty).

    Args:
        dataloaders: Loader(s) declared on the ``@evaluation_test`` decorator.

    Returns:
        The env-var-driven loader, or the original loaders, or ``None``.
    """
    # os.getenv never raises, so it stays outside the try block; only the
    # loader construction is a plausible failure point.
    jsonl_path = os.getenv("EP_JSONL_PATH")
    if jsonl_path:
        try:
            # Bind the path as a default argument so the generator does not
            # depend on this function's local scope after it returns.
            return DynamicDataLoader(generators=[lambda path=jsonl_path: _rows_from_jsonl(path)])
        except Exception as e:
            # Surface the failure instead of silently falling back to the
            # decorator-declared loaders.
            print(f"❌ Failed to build data loader from EP_JSONL_PATH={jsonl_path}: {e}")
    return dataloaders or None


def parse_ep_passed_threshold(
default_value: float | EvaluationThresholdDict | EvaluationThreshold | None,
) -> EvaluationThreshold | None:
Expand Down
40 changes: 40 additions & 0 deletions tests/pytest/test_pytest_env_overwrite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import atexit
import shutil
import tempfile
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
Expand All @@ -18,3 +21,40 @@ def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
assert row.messages[0].content == "What is the capital of France?"
assert row.execution_metadata.invocation_id == "test-invocation-123"
return row


# The env-var patch must wrap the *decoration* (module import time), not the
# test call: @evaluation_test reads EP_COMPLETION_PARAMS while the decorator
# runs (via parse_ep_completion_params_overwrite), so the override is captured
# here even though the patch has exited by the time pytest runs the test.
# NOTE(review): "gpt-40" looks like a typo for "gpt-4o", but the same string
# is used in both the env var and the assertion below, so the override
# round-trip is still exercised either way — confirm which was intended.
with mock.patch.dict(os.environ, {"EP_COMPLETION_PARAMS": '[{"model": "gpt-40"}]'}):

    @evaluation_test(
        input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is 5 * 6?")])]],
        completion_params=[{"model": "no-op"}],  # This should be overridden by the env var
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
    )
    def test_input_messages_in_env(row: EvaluationRow) -> EvaluationRow:
        """Run math evaluation on sample dataset using pytest interface."""
        assert row.messages[0].content == "What is 5 * 6?"
        # The decorator-supplied {"model": "no-op"} must have been replaced by
        # the env-var value at decoration time.
        assert row.input_metadata.completion_params["model"] == "gpt-40"
        return row


# Create a throwaway directory for the JSONL fixture and make sure it is
# removed when the interpreter exits (ignore_errors so cleanup never fails
# the test session).
_jsonl_tmpdir = tempfile.mkdtemp()
atexit.register(shutil.rmtree, _jsonl_tmpdir, ignore_errors=True)

# Write a single-row JSONL fixture in the shape _rows_from_jsonl expects:
# one JSON object per line, accepted by the EvaluationRow constructor.
input_path = os.path.join(_jsonl_tmpdir, "input.jsonl")
with open(input_path, "w") as f:
    f.write(
        '{"messages": [{"role": "user", "content": "What is 10 / 2?"}], "input_metadata": {"some_key": "some_value"}}\n'
    )
print(f"finish prepare input file {input_path}")
# As above, the patch must be active while @evaluation_test executes: the
# decorator reads EP_JSONL_PATH (via parse_ep_dataloaders) at decoration time
# and then ignores input_rows entirely.
with mock.patch.dict(os.environ, {"EP_JSONL_PATH": input_path}):

    @evaluation_test(
        input_rows=[[EvaluationRow(messages=[Message(role="user", content="This will be ignored")])]],
        completion_params=[{"model": "no-op"}],
        rollout_processor=NoOpRolloutProcessor(),
        mode="pointwise",
    )
    def test_input_override(row: EvaluationRow) -> EvaluationRow:
        """Verify EP_JSONL_PATH replaces the decorator-declared input rows."""
        # The row must come from the JSONL file, not from input_rows.
        assert row.messages[0].content == "What is 10 / 2?"
        return row
Loading