Skip to content

Commit b6eac99

Browse files
committed
update
1 parent dda36ea commit b6eac99

9 files changed

+25
-8
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
254254

255255
# Normalize Fireworks model names for LiteLLM routing
256256
completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
257+
row.input_metadata.completion_params = completion_params
257258
agent = Agent(
258259
model=completion_params["model"],
259260
row=row,

eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
110110
)
111111
# Normalize Fireworks model names for LiteLLM routing
112112
completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
113+
row.input_metadata.completion_params = completion_params
113114
agent = Agent(
114115
model=completion_params["model"],
115116
row=row,

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
283283

284284
# Normalize Fireworks model names for LiteLLM routing
285285
completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
286+
# Update all rows with normalized completion_params
287+
for row in rows:
288+
row.input_metadata.completion_params = completion_params
286289
model_id = str(completion_params.get("model") or "gpt-4o-mini")
287290
temperature = completion_params.get("temperature", 0.0)
288291
max_tokens = completion_params.get("max_tokens", 4096)

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
8080

8181
# Normalize Fireworks model names for LiteLLM routing
8282
completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
83+
row.input_metadata.completion_params = completion_params
8384
request_params = {"messages": messages_payload, **completion_params}
8485
# Ensure caching is disabled only for this request (review feedback)
8586
request_params["cache"] = {"no-cache": True}

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .rollout_processor import RolloutProcessor
1212
from .types import RolloutProcessorConfig
1313
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
14+
from .utils import normalize_fireworks_model_for_litellm
1415

1516

1617
class GithubActionRolloutProcessor(RolloutProcessor):
@@ -80,8 +81,16 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8081
if row.input_metadata.row_id is None:
8182
raise ValueError("Row ID is required in GithubActionRolloutProcessor")
8283

84+
# Normalize Fireworks model names for LiteLLM routing
85+
config.completion_params = (
86+
normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
87+
)
88+
8389
init_request = build_init_request(row, config, self.model_base_url)
8490

91+
# Update row with normalized completion_params for downstream access
92+
row.input_metadata.completion_params = init_request.completion_params
93+
8594
def _dispatch_workflow():
8695
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
8796

eval_protocol/pytest/openenv_rollout_processor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
180180
try:
181181
# Normalize Fireworks model names for LiteLLM routing
182182
completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
183+
row.input_metadata.completion_params = completion_params
183184
# Get model config
184185
raw_model = completion_params.get("model", "gpt-4o-mini")
185186
model = raw_model

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .rollout_processor import RolloutProcessor
1212
from .types import RolloutProcessorConfig
1313
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
14+
from .utils import normalize_fireworks_model_for_litellm
1415
import logging
1516

1617
import os
@@ -87,8 +88,16 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8788
if row.input_metadata.row_id is None:
8889
raise ValueError("Row ID is required in RemoteRolloutProcessor")
8990

91+
# Normalize Fireworks model names for LiteLLM routing
92+
config.completion_params = (
93+
normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
94+
)
95+
9096
init_payload = build_init_request(row, config, model_base_url)
9197

98+
# Update row with normalized completion_params for downstream access
99+
row.input_metadata.completion_params = init_payload.completion_params
100+
92101
# Fire-and-poll
93102
init_url = f"{remote_base_url}/init"
94103

eval_protocol/pytest/tracing_utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
1313
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig, RolloutMetadata, InitRequest
1414
from eval_protocol.pytest.types import RolloutProcessorConfig
15-
from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
1615

1716

1817
def default_fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
@@ -96,12 +95,6 @@ def build_init_request(
9695
if isinstance(row_cp, dict):
9796
completion_params_dict.update(row_cp)
9897

99-
# Normalize Fireworks model names for LiteLLM routing
100-
completion_params_dict = normalize_fireworks_model_for_litellm(completion_params_dict) or {}
101-
102-
# Update row's completion_params with normalized value
103-
row.input_metadata.completion_params = completion_params_dict
104-
10598
# Validate model is present
10699
if not completion_params_dict.get("model"):
107100
raise ValueError("Model must be provided in completion_params")

eval_protocol/pytest/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utility functions for model name handling."""
22

33
import re
4-
from typing import Optional
54

65
from eval_protocol.models import CompletionParams
76

0 commit comments

Comments
 (0)