diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 2cacc09f..2dd4dbc0 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -54,6 +54,7 @@ parse_ep_max_rows, parse_ep_max_concurrent_rollouts, parse_ep_num_runs, + parse_ep_completion_params, rollout_processor_with_retry, sanitize_filename, ) @@ -334,10 +335,14 @@ def evaluation_test( # noqa: C901 active_logger: DatasetLogger = logger if logger else default_logger - # Apply override from pytest flags if present + # Optional global overrides via environment for ad-hoc experimentation + # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged + # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}'). num_runs = parse_ep_num_runs(num_runs) max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts) max_dataset_rows = parse_ep_max_rows(max_dataset_rows) + completion_params = parse_ep_completion_params(completion_params) + original_completion_params = completion_params def decorator( test_func: TestFunction, @@ -420,9 +425,6 @@ async def execute_with_params( else: return test_func(**kwargs) - # preserve the original completion_params list for groupwise mode - original_completion_params_list = completion_params - # Calculate all possible combinations of parameters if mode == "groupwise": combinations = generate_parameter_combinations( @@ -544,20 +546,6 @@ def _log_eval_error(status: Status, rows: Optional[List[EvaluationRow]] | None, "No model provided. Please provide a model in the completion parameters object." ) - # Optional global overrides via environment for ad-hoc experimentation - # EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged - # into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}'). - try: - import json as _json - - _env_override = os.getenv("EP_INPUT_PARAMS_JSON") - if _env_override: - override_obj = _json.loads(_env_override) - if isinstance(override_obj, dict): - completion_params = deep_update_dict(dict(completion_params), override_obj) - except Exception: - pass - # Create eval metadata with test function info and current commit hash eval_metadata = EvalMetadata( name=test_func.__name__, @@ -661,7 +649,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs): row_groups = defaultdict(list) # key: row_id, value: list of rollout_result tasks: List[asyncio.Task[List[EvaluationRow]]] = [] # completion_groups = [] - for idx, cp in enumerate(original_completion_params_list): + for idx, cp in enumerate(original_completion_params): config = RolloutProcessorConfig( completion_params=cp, mcp_config_path=mcp_config_path or "", @@ -744,7 +732,7 @@ async def _collect_result(config, lst): # rollout_id is used to differentiate the result from different completion_params if mode == "groupwise": results_by_group = [ - [[] for _ in range(num_runs)] for _ in range(len(original_completion_params_list)) + [[] for _ in range(num_runs)] for _ in range(len(original_completion_params)) ] for i_run, result in enumerate(all_results): for r in result: @@ -757,7 +745,7 @@ async def _collect_result(config, lst): threshold, active_logger, mode, - original_completion_params_list[rollout_id], + original_completion_params[rollout_id], test_func.__name__, num_runs, ) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index c24fbdc9..bad097a7 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -170,6 +170,25 @@ def parse_ep_max_concurrent_rollouts(default_value: int) -> int: return int(raw) if raw is not None else default_value +def parse_ep_completion_params(completion_params: List[CompletionParams]) -> List[CompletionParams]: + """Apply EP_INPUT_PARAMS_JSON overrides to completion_params. + + Reads the environment variable set by plugin.py and applies deep merge to each completion param. + """ + try: + import json as _json + + _env_override = os.getenv("EP_INPUT_PARAMS_JSON") + if _env_override: + override_obj = _json.loads(_env_override) + if isinstance(override_obj, dict): + # Apply override to each completion_params item + return [deep_update_dict(dict(cp), override_obj) for cp in completion_params] + except Exception: + pass + return completion_params + + def deep_update_dict(base: dict, override: dict) -> dict: """Recursively update nested dictionaries in-place and return base.""" for key, value in override.items():