Commit 8acdc35

Author: Dylan Huang
Merge branch 'main' into dhuang/dxe-478-implement-evaluator-versions
2 parents 2076f0a + 8349922, commit 8acdc35

10 files changed: +103 -68 lines

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 5 additions & 1 deletion
@@ -22,6 +22,7 @@
 from openai.types import CompletionUsage
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 from pydantic import BaseModel
 from typing import Optional
 
@@ -251,8 +252,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
     """Process a single row with agent rollout."""
     start_time = time.perf_counter()
 
+    # Normalize Fireworks model names for LiteLLM routing
+    completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+    row.input_metadata.completion_params = completion_params
     agent = Agent(
-        model=row.input_metadata.completion_params["model"],
+        model=completion_params["model"],
         row=row,
         config_path=config.mcp_config_path,
         logger=config.logger,
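Note: the normalization referenced in the added comment rewrites bare Fireworks model IDs into the provider-prefixed form LiteLLM routes on (the rule is visible in the helper deleted from evaluation_test_utils.py below). A hypothetical before/after, assuming the new helper keeps that rule; the model name is illustrative only:

    from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm

    params = {"model": "accounts/fireworks/models/llama-v3p1-8b-instruct"}  # illustrative name
    params = normalize_fireworks_model_for_litellm(params) or {}
    # assumed result:
    # params["model"] == "fireworks_ai/accounts/fireworks/models/llama-v3p1-8b-instruct"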

eval_protocol/pytest/default_klavis_sandbox_rollout_processor.py

Lines changed: 27 additions & 21 deletions
@@ -11,6 +11,7 @@
 from eval_protocol.models import EvaluationRow
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 from eval_protocol.pytest.default_agent_rollout_processor import Agent
 from klavis import Klavis
 
@@ -30,15 +31,15 @@ def __init__(
         self.server_name = server_name
         self.initialize_data_factory = initialize_data_factory
         self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
-
+
     def _init_sandbox(self) -> CreateSandboxResponse:
         try:
             server_name_enum = SandboxMcpServer(self.server_name)
             return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
         except Exception as e:
             logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
             raise
-
+
     @staticmethod
     def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str | None = None) -> str:
         """Create a temporary MCP config file and return its path."""
 
@@ -47,26 +48,24 @@ def create_mcp_config(server_url: str, server_key: str = "main", auth_token: str
                 server_key: {
                     "url": server_url,
                     "transport": "streamable_http",
-                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {})
+                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
                 }
             }
         }
-
+
         # Create a temp file that persists for the session
        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
-        with os.fdopen(fd, 'w') as f:
+        with os.fdopen(fd, "w") as f:
            json.dump(config, f)
        return path
 
-    def __call__(
-        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
-    ) -> List[asyncio.Task[EvaluationRow]]:
+    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows with Klavis sandbox lifecycle management"""
        semaphore = config.semaphore
 
        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with complete sandbox lifecycle"""
-
+
            start_time = time.perf_counter()
            agent: Agent | None = None
            temp_config_path: str | None = None
 
@@ -88,25 +87,32 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                     if row.input_metadata is not None
                     else None
                 )
-
+
                 if init_data:
-                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
+                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")  # pyright: ignore[reportOptionalMemberAccess]
                     initialize_method = getattr(
-                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name.value}_sandbox"
+                        self.klavis_client.sandbox,
+                        f"initialize_{sandbox.server_name.value}_sandbox",  # pyright: ignore[reportOptionalMemberAccess]
                     )
-                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
+                    init_response = initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)  # pyright: ignore[reportOptionalMemberAccess]
                     logger.info(f"Initialization response: {init_response}")
-
+
                # Step 2: Create temporary MCP config with sandbox URL
                temp_config_path = self.create_mcp_config(
-                    server_url=sandbox.server_url, server_key=sandbox.server_name.value
+                    server_url=sandbox.server_url,  # pyright: ignore[reportOptionalMemberAccess]
+                    server_key=sandbox.server_name.value,  # pyright: ignore[reportOptionalMemberAccess]
                 )
                logger.info(f"MCP config created: {temp_config_path}")
 
                # Step 3: Run agent with sandbox MCP server
-                logger.info(f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox")
+                logger.info(
+                    f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
+                )
+                # Normalize Fireworks model names for LiteLLM routing
+                completion_params = normalize_fireworks_model_for_litellm(row.input_metadata.completion_params) or {}
+                row.input_metadata.completion_params = completion_params
                agent = Agent(
-                    model=row.input_metadata.completion_params["model"],
+                    model=completion_params["model"],
                    row=row,
                    config_path=temp_config_path,
                    logger=config.logger,
 
@@ -124,16 +130,16 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")
 
                # Step 4: Export sandbox data
-                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")
-                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
+                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name.value}_sandbox")  # pyright: ignore[reportOptionalMemberAccess]
+                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)  # pyright: ignore[reportOptionalMemberAccess]
                sandbox_data = dump_response.data
                logger.info(f"Sandbox data: {sandbox_data}")
 
                # Store sandbox data in row metadata for evaluation
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["sandbox_data"] = sandbox_data
-                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
+                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id  # pyright: ignore[reportOptionalMemberAccess]
                row.execution_metadata.extra["server_name"] = self.server_name
 
            except Exception as e:
 
@@ -149,7 +155,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
                    await agent.mcp_client.cleanup()
                if temp_config_path and os.path.exists(temp_config_path):
                    os.unlink(temp_config_path)
-
+
                # Release sandbox
                if sandbox and sandbox.sandbox_id:
                    try:
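Note: the outer wrapper key of the config dict is elided from the create_mcp_config hunk above, so only the per-server entry is grounded in this diff. A hypothetical rendering of the entry the method writes, with placeholder URL and token:

    # Hypothetical shape; the outer wrapper key is elided from the hunk above.
    inner_entry = {
        "main": {  # server_key defaults to "main"
            "url": "https://example.invalid/mcp",  # placeholder sandbox URL
            "transport": "streamable_http",
            "authorization": "Bearer <auth_token>",  # included only when auth_token is set
        }
    }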

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 11 additions & 7 deletions
@@ -14,6 +14,7 @@
 from eval_protocol.models import EvaluationRow
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig, ServerMode
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 
 class MCPServerManager:
 
@@ -280,17 +281,20 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
                 "Cannot retry without existing server/environments. Call with start_server=True first."
             )
 
-        model_id = str((config.completion_params.get("model") if config.completion_params else None) or "gpt-4o-mini")
-        temperature = config.completion_params.get("temperature", 0.0)
-        max_tokens = config.completion_params.get("max_tokens", 4096)
+        # Normalize Fireworks model names for LiteLLM routing
+        completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+        # Update all rows with normalized completion_params
+        for row in rows:
+            row.input_metadata.completion_params = completion_params
+        model_id = str(completion_params.get("model") or "gpt-4o-mini")
+        temperature = completion_params.get("temperature", 0.0)
+        max_tokens = completion_params.get("max_tokens", 4096)
 
         # Pass all other completion_params (e.g. stream=True) via kwargs
         other_params = {
-            k: v
-            for k, v in (config.completion_params or {}).items()
-            if k not in ["model", "temperature", "max_tokens", "extra_body"]
+            k: v for k, v in completion_params.items() if k not in ["model", "temperature", "max_tokens", "extra_body"]
         }
-        extra_body = config.completion_params.get("extra_body", {}) or {}
+        extra_body = completion_params.get("extra_body", {}) or {}
 
         self.policy = ep.LiteLLMPolicy(
             model_id=model_id,
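Note: the refactor above routes model, temperature, and max_tokens explicitly, forwards everything else as kwargs, and keeps extra_body separate. A minimal runnable sketch of the same split, with assumed inputs:

    # Minimal sketch with assumed inputs; mirrors the dict split above.
    completion_params = {"model": "gpt-4o-mini", "stream": True, "extra_body": {"reasoning_effort": "low"}}
    model_id = str(completion_params.get("model") or "gpt-4o-mini")
    temperature = completion_params.get("temperature", 0.0)
    max_tokens = completion_params.get("max_tokens", 4096)
    other_params = {k: v for k, v in completion_params.items() if k not in ["model", "temperature", "max_tokens", "extra_body"]}
    extra_body = completion_params.get("extra_body", {}) or {}
    assert other_params == {"stream": True} and extra_body == {"reasoning_effort": "low"}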

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 12 additions & 11 deletions
@@ -17,6 +17,7 @@
 from openai.types import CompletionUsage
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 logger = logging.getLogger(__name__)
 
@@ -63,7 +64,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
         async def process_row(row: EvaluationRow) -> EvaluationRow:
             """Process a single row asynchronously."""
             start_time = time.perf_counter()
-
+
             if len(row.messages) == 0:
                 raise ValueError("Messages is empty. Please provide a non-empty dataset")
 
@@ -77,7 +78,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             # Use the Message class method that excludes unsupported fields
             messages_payload = [message.dump_mdoel_for_chat_completion_request() for message in messages_for_request]
 
-            request_params = {"messages": messages_payload, **config.completion_params}
+            # Normalize Fireworks model names for LiteLLM routing
+            completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+            row.input_metadata.completion_params = completion_params
+            request_params = {"messages": messages_payload, **completion_params}
             # Ensure caching is disabled only for this request (review feedback)
             request_params["cache"] = {"no-cache": True}
 
@@ -87,18 +91,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             # Single-level reasoning effort: expect `reasoning_effort` only
             effort_val = None
 
-            if (
-                "reasoning_effort" in config.completion_params
-                and config.completion_params["reasoning_effort"] is not None
-            ):
-                effort_val = str(config.completion_params["reasoning_effort"])  # flat shape
+            if "reasoning_effort" in completion_params and completion_params["reasoning_effort"] is not None:
+                effort_val = str(completion_params["reasoning_effort"])  # flat shape
             elif (
-                isinstance(config.completion_params.get("extra_body"), dict)
-                and "reasoning_effort" in config.completion_params["extra_body"]
-                and config.completion_params["extra_body"]["reasoning_effort"] is not None
+                isinstance(completion_params.get("extra_body"), dict)
+                and "reasoning_effort" in completion_params["extra_body"]
+                and completion_params["extra_body"]["reasoning_effort"] is not None
             ):
                 # Accept if user passed it directly inside extra_body
-                effort_val = str(config.completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
+                effort_val = str(completion_params["extra_body"]["reasoning_effort"])  # already in extra_body
 
             if effort_val:
                 # Always under extra_body so LiteLLM forwards to provider-specific param set
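Note: per the branches above, reasoning effort is accepted in two shapes; illustrative values:

    # Illustrative only: the two shapes the branches above accept.
    flat = {"model": "gpt-4o-mini", "reasoning_effort": "low"}                    # flat shape
    nested = {"model": "gpt-4o-mini", "extra_body": {"reasoning_effort": "low"}}  # already in extra_body
    # In both cases effort_val ends up forwarded under extra_body.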

eval_protocol/pytest/evaluation_test.py

Lines changed: 0 additions & 3 deletions
@@ -55,7 +55,6 @@
     AggregationMethod,
     add_cost_metrics,
     log_eval_status_and_rows,
-    normalize_fireworks_model,
     parse_ep_completion_params,
     parse_ep_completion_params_overwrite,
     parse_ep_max_concurrent_rollouts,
 
@@ -205,7 +204,6 @@ def evaluation_test(
     max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
     completion_params = parse_ep_completion_params(completion_params)
     completion_params = parse_ep_completion_params_overwrite(completion_params)
-    completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
     original_completion_params = completion_params
     passed_threshold = parse_ep_passed_threshold(passed_threshold)
     data_loaders = parse_ep_dataloaders(data_loaders)
 
@@ -366,7 +364,6 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
             row.input_metadata.row_id = generate_id(seed=0, index=index)
 
         completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
-        completion_params = normalize_fireworks_model(completion_params)
         # Create eval metadata with test function info and current commit hash
         eval_metadata = EvalMetadata(
             name=test_func.__name__,

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 0 additions & 19 deletions
@@ -619,22 +619,3 @@ def build_rollout_processor_config(
         server_script_path=None,
         kwargs=rollout_processor_kwargs,
     )
-
-
-def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
-    """Fireworks model names like 'accounts/<org>/models/<model>' need the fireworks_ai/
-    prefix when routing through LiteLLM. This function adds the prefix if missing.
-    """
-    if completion_params is None:
-        return None
-
-    model = completion_params.get("model")
-    if (
-        model
-        and isinstance(model, str)
-        and not model.startswith("fireworks_ai/")
-        and re.match(r"^accounts/[^/]+/models/.+", model)
-    ):
-        completion_params = completion_params.copy()
-        completion_params["model"] = f"fireworks_ai/{model}"
-    return completion_params
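Note: the replacement helper, normalize_fireworks_model_for_litellm, lives in eval_protocol/pytest/utils.py, which this view does not render (the commit lists 10 files changed but shows nine, and the unshown file plausibly accounts for the remaining additions). A sketch of a plausible definition, assuming it keeps the deleted function's prefixing rule and None passthrough (the 'or {}' / 'or config.completion_params' guards at the call sites suggest it can return None):

    import re

    def normalize_fireworks_model_for_litellm(completion_params):
        # Hypothetical reconstruction; the actual definition is in the unshown utils.py.
        if completion_params is None:
            return None
        model = completion_params.get("model")
        if (
            model
            and isinstance(model, str)
            and not model.startswith("fireworks_ai/")
            and re.match(r"^accounts/[^/]+/models/.+", model)
        ):
            completion_params = completion_params.copy()
            completion_params["model"] = f"fireworks_ai/{model}"
        return completion_params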

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 7 additions & 0 deletions
@@ -11,6 +11,7 @@
 from .rollout_processor import RolloutProcessor
 from .types import RolloutProcessorConfig
 from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+from .utils import normalize_fireworks_model_for_litellm
 
 
 class GithubActionRolloutProcessor(RolloutProcessor):
 
@@ -80,6 +81,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
             if row.input_metadata.row_id is None:
                 raise ValueError("Row ID is required in GithubActionRolloutProcessor")
 
+            # Normalize Fireworks model names for LiteLLM routing
+            config.completion_params = (
+                normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+            )
+            row.input_metadata.completion_params = config.completion_params
+
             init_request = build_init_request(row, config, self.model_base_url)
 
             def _dispatch_workflow():

eval_protocol/pytest/openenv_rollout_processor.py

Lines changed: 10 additions & 6 deletions
@@ -24,6 +24,7 @@
 from eval_protocol.models import EvaluationRow, Message
 from eval_protocol.pytest.rollout_processor import RolloutProcessor
 from eval_protocol.pytest.types import RolloutProcessorConfig
+from eval_protocol.pytest.utils import normalize_fireworks_model_for_litellm
 
 logger = logging.getLogger(__name__)
 
@@ -177,15 +178,18 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             logger.debug("[OpenEnvRolloutProcessor] Environment client created successfully")
 
             try:
+                # Normalize Fireworks model names for LiteLLM routing
+                completion_params = normalize_fireworks_model_for_litellm(config.completion_params) or {}
+                row.input_metadata.completion_params = completion_params
                 # Get model config
-                raw_model = config.completion_params.get("model", "gpt-4o-mini")
+                raw_model = completion_params.get("model", "gpt-4o-mini")
                 model = raw_model
-                temperature = config.completion_params.get("temperature", 0.0)
-                max_tokens = config.completion_params.get("max_tokens", 100)
+                temperature = completion_params.get("temperature", 0.0)
+                max_tokens = completion_params.get("max_tokens", 100)
                 # Optional: direct routing or provider overrides (e.g., base_url, api_key, top_p, stop, etc.)
-                base_url = config.completion_params.get("base_url")
+                base_url = completion_params.get("base_url")
                 # Forward any extra completion params to LiteLLMPolicy (they will be sent per-request)
-                extra_params: Dict[str, Any] = dict(config.completion_params or {})
+                extra_params: Dict[str, Any] = dict(completion_params)
                 for _k in ("model", "temperature", "max_tokens", "base_url"):
                     try:
                         extra_params.pop(_k, None)
 
@@ -247,7 +251,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
             messages = list(row.messages)  # Copy initial messages
             # Inject system prompt if provided and not already present
             has_system = any(m.role == "system" for m in messages)
-            system_prompt = config.completion_params.get("system_prompt")
+            system_prompt = completion_params.get("system_prompt")
             if system_prompt and not has_system:
                 messages.insert(0, Message(role="system", content=system_prompt))
             usage = {
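Note: the copy-and-pop above strips the explicitly routed keys so only provider overrides are forwarded per-request. A minimal runnable sketch with assumed params:

    # Minimal sketch with assumed params; mirrors the copy-and-pop above.
    completion_params = {"model": "gpt-4o-mini", "temperature": 0.0, "top_p": 0.9}
    extra_params = dict(completion_params)
    for _k in ("model", "temperature", "max_tokens", "base_url"):
        extra_params.pop(_k, None)
    assert extra_params == {"top_p": 0.9}  # only overrides like top_p remain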

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 7 additions & 0 deletions
@@ -11,6 +11,7 @@
 from .rollout_processor import RolloutProcessor
 from .types import RolloutProcessorConfig
 from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
+from .utils import normalize_fireworks_model_for_litellm
 import logging
 
 import os
 
@@ -87,6 +88,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
             if row.input_metadata.row_id is None:
                 raise ValueError("Row ID is required in RemoteRolloutProcessor")
 
+            # Normalize Fireworks model names for LiteLLM routing
+            config.completion_params = (
+                normalize_fireworks_model_for_litellm(config.completion_params) or config.completion_params
+            )
+            row.input_metadata.completion_params = config.completion_params
+
             init_payload = build_init_request(row, config, model_base_url)
 
             # Fire-and-poll
