Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
parse_ep_max_rows,
parse_ep_max_concurrent_rollouts,
parse_ep_num_runs,
parse_ep_completion_params,
rollout_processor_with_retry,
sanitize_filename,
)
Expand Down Expand Up @@ -334,10 +335,14 @@ def evaluation_test( # noqa: C901

active_logger: DatasetLogger = logger if logger else default_logger

# Apply override from pytest flags if present
# Optional global overrides via environment for ad-hoc experimentation
# EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
num_runs = parse_ep_num_runs(num_runs)
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
completion_params = parse_ep_completion_params(completion_params)
original_completion_params = completion_params

def decorator(
test_func: TestFunction,
Expand Down Expand Up @@ -420,9 +425,6 @@ async def execute_with_params(
else:
return test_func(**kwargs)

# preserve the original completion_params list for groupwise mode
original_completion_params_list = completion_params

# Calculate all possible combinations of parameters
if mode == "groupwise":
combinations = generate_parameter_combinations(
Expand Down Expand Up @@ -544,20 +546,6 @@ def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None,
"No model provided. Please provide a model in the completion parameters object."
)

# Optional global overrides via environment for ad-hoc experimentation
# EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
try:
import json as _json

_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
if _env_override:
override_obj = _json.loads(_env_override)
if isinstance(override_obj, dict):
completion_params = deep_update_dict(dict(completion_params), override_obj)
except Exception:
pass

# Create eval metadata with test function info and current commit hash
eval_metadata = EvalMetadata(
name=test_func.__name__,
Expand Down Expand Up @@ -661,7 +649,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
tasks: List[asyncio.Task[List[EvaluationRow]]] = []
# completion_groups = []
for idx, cp in enumerate(original_completion_params_list):
for idx, cp in enumerate(original_completion_params):
config = RolloutProcessorConfig(
completion_params=cp,
mcp_config_path=mcp_config_path or "",
Expand Down Expand Up @@ -744,7 +732,7 @@ async def _collect_result(config, lst):
# rollout_id is used to differentiate the result from different completion_params
if mode == "groupwise":
results_by_group = [
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list))
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params))
]
for i_run, result in enumerate(all_results):
for r in result:
Expand All @@ -757,7 +745,7 @@ async def _collect_result(config, lst):
threshold,
active_logger,
mode,
original_completion_params_list[rollout_id],
original_completion_params[rollout_id],
test_func.__name__,
num_runs,
)
Expand Down
19 changes: 19 additions & 0 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int:
return int(raw) if raw is not None else default_value


def parse_ep_completion_params(completion_params: List[CompletionParams]) -> List[CompletionParams]:
    """Apply EP_INPUT_PARAMS_JSON overrides to ``completion_params``.

    Reads the ``EP_INPUT_PARAMS_JSON`` environment variable (set by plugin.py)
    and, when it holds a JSON object (e.g.
    ``'{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}'``),
    deep-merges that object into each completion-params item.

    Args:
        completion_params: The list of completion-parameter dicts configured
            on the evaluation test.

    Returns:
        A new list with the override merged into a top-level copy of each
        item, or ``completion_params`` unchanged when the variable is unset,
        not valid JSON, or not a JSON object.
    """
    import json  # local import mirrors the original's function-scope style

    raw_override = os.getenv("EP_INPUT_PARAMS_JSON")
    if not raw_override:
        return completion_params

    try:
        override_obj = json.loads(raw_override)
    except ValueError:
        # Previously any parse error was silently swallowed, so a typo in the
        # env var meant the override quietly never applied. Stay best-effort
        # (never crash the test run) but surface the problem.
        import warnings

        warnings.warn(
            "EP_INPUT_PARAMS_JSON is not valid JSON; ignoring override.",
            stacklevel=2,
        )
        return completion_params

    if not isinstance(override_obj, dict):
        return completion_params

    # dict(cp) copies only the top level; deep_update_dict updates nested
    # dicts in place, so nested structures shared with the caller's originals
    # may still be mutated — NOTE(review): matches prior behavior, confirm
    # callers don't rely on pristine nested dicts.
    return [deep_update_dict(dict(cp), override_obj) for cp in completion_params]


def deep_update_dict(base: dict, override: dict) -> dict:
"""Recursively update nested dictionaries in-place and return base."""
for key, value in override.items():
Expand Down
Loading