Skip to content

Commit ed22857

Browse files
authored
auto no prefix needed (#404)
* auto no prefix needed
* update
* update test
1 parent 795072e commit ed22857

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
EvaluationRow,
2121
EvaluationThreshold,
2222
EvaluationThresholdDict,
23-
EvaluateResult,
2423
Status,
2524
EPParameters,
2625
)
2726
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2827
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
29-
from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
28+
from eval_protocol.pytest.execution import execute_pytest_with_exception_handling
3029
from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts
3130
from eval_protocol.pytest.generate_parameter_combinations import (
3231
ParameterizedTestKwargs,
@@ -56,6 +55,7 @@
5655
AggregationMethod,
5756
add_cost_metrics,
5857
log_eval_status_and_rows,
58+
normalize_fireworks_model,
5959
parse_ep_completion_params,
6060
parse_ep_completion_params_overwrite,
6161
parse_ep_max_concurrent_rollouts,
@@ -205,6 +205,7 @@ def evaluation_test(
205205
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
206206
completion_params = parse_ep_completion_params(completion_params)
207207
completion_params = parse_ep_completion_params_overwrite(completion_params)
208+
completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
208209
original_completion_params = completion_params
209210
passed_threshold = parse_ep_passed_threshold(passed_threshold)
210211
data_loaders = parse_ep_dataloaders(data_loaders)
@@ -365,6 +366,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
365366
row.input_metadata.row_id = generate_id(seed=0, index=index)
366367

367368
completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
369+
completion_params = normalize_fireworks_model(completion_params)
368370
# Create eval metadata with test function info and current commit hash
369371
eval_metadata = EvalMetadata(
370372
name=test_func.__name__,

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,22 @@ def build_rollout_processor_config(
619619
server_script_path=None,
620620
kwargs=rollout_processor_kwargs,
621621
)
622+
623+
624+
def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
    """Add the ``fireworks_ai/`` routing prefix to bare Fireworks model names.

    Model identifiers of the form ``accounts/<org>/models/<model>`` must be
    prefixed with ``fireworks_ai/`` so LiteLLM routes them to Fireworks.
    Names that are already prefixed, non-string, empty, or not in the
    Fireworks ``accounts/...`` shape are returned unchanged.

    Args:
        completion_params: Completion parameters (or ``None``); only the
            ``"model"`` entry is inspected.

    Returns:
        ``None`` when given ``None``; otherwise the (possibly copied)
        completion params with a normalized ``"model"`` value. The caller's
        dict is never mutated.
    """
    if completion_params is None:
        return None

    model = completion_params.get("model")
    # Guard clauses: anything that doesn't need normalization passes through.
    if not isinstance(model, str) or not model:
        return completion_params
    if model.startswith("fireworks_ai/"):
        return completion_params
    if re.match(r"^accounts/[^/]+/models/.+", model) is None:
        return completion_params

    # Copy before mutating so the caller's dict stays untouched.
    normalized = completion_params.copy()
    normalized["model"] = f"fireworks_ai/{model}"
    return normalized

tests/pytest/test_pydantic_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111

1212
def agent_factory(config: RolloutProcessorConfig) -> Agent:
    """Build a pydantic-ai ``Agent`` backed by the configured Fireworks model.

    The LiteLLM routing prefix ``fireworks_ai/`` is stripped (when present),
    since the Fireworks provider expects the bare ``accounts/...`` model name.
    """
    raw_name = config.completion_params["model"]
    bare_name = raw_name.removeprefix("fireworks_ai/")
    return Agent(model=OpenAIChatModel(bare_name, provider="fireworks"))
1518

1619

tests/remote_server/test_remote_fireworks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def rows() -> List[EvaluationRow]:
105105

106106
@pytest.mark.parametrize(
107107
"completion_params",
108-
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
108+
[{"model": "accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
109109
)
110110
@evaluation_test(
111111
data_loaders=DynamicDataLoader(

0 commit comments

Comments
 (0)