Skip to content

Commit 7943cdf

Browse files
committed
test fix
1 parent 65abb77 commit 7943cdf

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,14 @@ def evaluation_test( # noqa: C901
335335

336336
active_logger: DatasetLogger = logger if logger else default_logger
337337

338-
# Apply override from pytest flags if present
338+
# Optional global overrides via environment for ad-hoc experimentation
339+
# EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
340+
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
339341
num_runs = parse_ep_num_runs(num_runs)
340342
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
341343
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
342344
completion_params = parse_ep_completion_params(completion_params)
345+
original_completion_params = completion_params
343346

344347
def decorator(
345348
test_func: TestFunction,
@@ -646,7 +649,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
646649
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
647650
tasks: List[asyncio.Task[List[EvaluationRow]]] = []
648651
# completion_groups = []
649-
for idx, cp in enumerate(completion_params):
652+
for idx, cp in enumerate(original_completion_params):
650653
config = RolloutProcessorConfig(
651654
completion_params=cp,
652655
mcp_config_path=mcp_config_path or "",
@@ -728,7 +731,9 @@ async def _collect_result(config, lst):
728731
# for groupwise mode, the result contains eval otuput from multiple completion_params, we need to differentiate them
729732
# rollout_id is used to differentiate the result from different completion_params
730733
if mode == "groupwise":
731-
results_by_group = [[[] for _ in range(num_runs)] for _ in range(len(completion_params))]
734+
results_by_group = [
735+
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params))
736+
]
732737
for i_run, result in enumerate(all_results):
733738
for r in result:
734739
completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1])
@@ -740,7 +745,7 @@ async def _collect_result(config, lst):
740745
threshold,
741746
active_logger,
742747
mode,
743-
completion_params[rollout_id],
748+
original_completion_params[rollout_id],
744749
test_func.__name__,
745750
num_runs,
746751
)

0 commit comments

Comments
 (0)