Skip to content

Commit 9046804

Browse files
committed
auto no prefix needed
1 parent cf4fc4e commit 9046804

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
EvaluationRow,
2121
EvaluationThreshold,
2222
EvaluationThresholdDict,
23-
EvaluateResult,
2423
Status,
2524
EPParameters,
2625
)
2726
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2827
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
29-
from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
28+
from eval_protocol.pytest.execution import execute_pytest_with_exception_handling
3029
from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts
3130
from eval_protocol.pytest.generate_parameter_combinations import (
3231
ParameterizedTestKwargs,
@@ -56,6 +55,7 @@
5655
AggregationMethod,
5756
add_cost_metrics,
5857
log_eval_status_and_rows,
58+
normalize_fireworks_model,
5959
parse_ep_completion_params,
6060
parse_ep_completion_params_overwrite,
6161
parse_ep_max_concurrent_rollouts,
@@ -205,6 +205,7 @@ def evaluation_test(
205205
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
206206
completion_params = parse_ep_completion_params(completion_params)
207207
completion_params = parse_ep_completion_params_overwrite(completion_params)
208+
completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
208209
original_completion_params = completion_params
209210
passed_threshold = parse_ep_passed_threshold(passed_threshold)
210211
data_loaders = parse_ep_dataloaders(data_loaders)
@@ -409,21 +410,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409410

410411
rollout_processor.setup()
411412

412-
use_priority_scheduler = (
413-
(
414-
os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
415-
and not isinstance(rollout_processor, MCPGymRolloutProcessor)
416-
)
417-
)
413+
use_priority_scheduler = os.environ.get(
414+
"EP_USE_PRIORITY_SCHEDULER", "0"
415+
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)
418416

419417
if use_priority_scheduler:
420418
microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
421419
output_dir = os.environ.get("EP_OUTPUT_DIR", None)
422420
if microbatch_output_size and output_dir:
423-
output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
421+
output_buffer = MicroBatchDataBuffer(
422+
num_runs=num_runs,
423+
batch_size=int(microbatch_output_size),
424+
output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"),
425+
)
424426
else:
425427
output_buffer = None
426-
428+
427429
try:
428430
priority_results = await execute_priority_rollouts(
429431
dataset=data,
@@ -441,12 +443,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441443
finally:
442444
if output_buffer:
443445
await output_buffer.close()
444-
446+
445447
for res in priority_results:
446448
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
447449
if run_idx < len(all_results):
448450
all_results[run_idx].append(res)
449-
451+
450452
processed_rows_in_run.append(res)
451453

452454
postprocess(
@@ -462,6 +464,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462464
)
463465

464466
else:
467+
465468
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
466469
nonlocal all_results
467470

@@ -506,9 +509,7 @@ async def _execute_pointwise_eval_with_semaphore(
506509
raise ValueError(
507510
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508511
)
509-
result.execution_metadata.eval_duration_seconds = (
510-
time.perf_counter() - start_time
511-
)
512+
result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time
512513
return result
513514

514515
async def _execute_groupwise_eval_with_semaphore(
@@ -519,7 +520,9 @@ async def _execute_groupwise_eval_with_semaphore(
519520
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
520521
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
521522
group_rollout_ids = [
522-
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
523+
r.execution_metadata.rollout_id
524+
for r in rows
525+
if r.execution_metadata.rollout_id
523526
]
524527
async with rollout_logging_context(
525528
primary_rollout_id or "",
@@ -596,7 +599,9 @@ async def _collect_result(config, lst):
596599
row_groups[row.input_metadata.row_id].append(row)
597600
tasks = []
598601
for _, rows in row_groups.items():
599-
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
602+
tasks.append(
603+
asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))
604+
)
600605
results = []
601606
for task in tasks:
602607
res = await task
@@ -692,9 +697,9 @@ async def _collect_result(config, lst):
692697
# For other processors, create all tasks at once and run in parallel
693698
# Concurrency is now controlled by the shared semaphore in each rollout processor
694699
await run_tasks_with_run_progress(execute_run, num_runs, config)
695-
700+
696701
experiment_duration_seconds = time.perf_counter() - experiment_start_time
697-
702+
698703
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699704
# rollout_id is used to differentiate the result from different completion_params
700705
if mode == "groupwise":
@@ -730,15 +735,12 @@ async def _collect_result(config, lst):
730735
experiment_duration_seconds,
731736
)
732737

733-
734-
735738
if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
736739
raise AssertionError(
737740
"Some EvaluationRow instances are missing evaluation_result. "
738741
"Your @evaluation_test function must set `row.evaluation_result`"
739742
)
740743

741-
742744
except AssertionError:
743745
_log_eval_error(
744746
Status.eval_finished(),

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,22 @@ def build_rollout_processor_config(
619619
server_script_path=None,
620620
kwargs=rollout_processor_kwargs,
621621
)
622+
623+
624+
def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
625+
"""Fireworks model names like 'accounts/<org>/models/<model>' need the fireworks_ai/
626+
prefix when routing through LiteLLM. This function adds the prefix if missing.
627+
"""
628+
if completion_params is None:
629+
return None
630+
631+
model = completion_params.get("model")
632+
if (
633+
model
634+
and isinstance(model, str)
635+
and not model.startswith("fireworks_ai/")
636+
and re.match(r"^accounts/[^/]+/models/.+", model)
637+
):
638+
completion_params = completion_params.copy()
639+
completion_params["model"] = f"fireworks_ai/{model}"
640+
return completion_params

0 commit comments

Comments (0)