Skip to content

Commit 5e7a5fa

Browse files
author
Dylan Huang
committed
Merge branch 'main' into dhuang/dxe-478-implement-evaluator-versions
2 parents 532e071 + 35db8e2 commit 5e7a5fa

19 files changed

+409
-139
lines changed

eval_protocol/evaluation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434

3535
@staticmethod
3636
def _parse_ignore_file(ignore_path: str) -> List[str]:
37-
"""Parse .gitignore or .dockerignore and return patterns."""
37+
"""Parse .gitignore and return patterns."""
3838
patterns = []
3939
if not os.path.exists(ignore_path):
4040
return patterns
@@ -129,8 +129,7 @@ def _create_tar_gz_with_ignores(output_path: str, source_dir: str) -> int:
129129

130130
source_path = Path(source_dir)
131131
gitignore_patterns = Evaluator._parse_ignore_file(str(source_path / ".gitignore"))
132-
dockerignore_patterns = Evaluator._parse_ignore_file(str(source_path / ".dockerignore"))
133-
all_ignore_patterns = gitignore_patterns + dockerignore_patterns
132+
all_ignore_patterns = gitignore_patterns
134133

135134
logger.info(f"Creating tar.gz with {len(all_ignore_patterns)} ignore patterns")
136135

eval_protocol/pytest/evaluation_test.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
EvaluationRow,
2121
EvaluationThreshold,
2222
EvaluationThresholdDict,
23-
EvaluateResult,
2423
Status,
2524
EPParameters,
2625
)
2726
from eval_protocol.pytest.dual_mode_wrapper import create_dual_mode_wrapper
2827
from eval_protocol.pytest.evaluation_test_postprocess import postprocess
29-
from eval_protocol.pytest.execution import execute_pytest, execute_pytest_with_exception_handling
28+
from eval_protocol.pytest.execution import execute_pytest_with_exception_handling
3029
from eval_protocol.pytest.priority_scheduler import execute_priority_rollouts
3130
from eval_protocol.pytest.generate_parameter_combinations import (
3231
ParameterizedTestKwargs,
@@ -56,6 +55,7 @@
5655
AggregationMethod,
5756
add_cost_metrics,
5857
log_eval_status_and_rows,
58+
normalize_fireworks_model,
5959
parse_ep_completion_params,
6060
parse_ep_completion_params_overwrite,
6161
parse_ep_max_concurrent_rollouts,
@@ -93,8 +93,8 @@ def evaluation_test(
9393
filtered_row_ids: Sequence[str] | None = None,
9494
max_dataset_rows: int | None = None,
9595
mcp_config_path: str | None = None,
96-
max_concurrent_rollouts: int = 8,
97-
max_concurrent_evaluations: int = 64,
96+
max_concurrent_rollouts: int = 96,
97+
max_concurrent_evaluations: int = 96,
9898
server_script_path: str | None = None,
9999
steps: int = 30,
100100
mode: EvaluationTestMode = "pointwise",
@@ -205,6 +205,7 @@ def evaluation_test(
205205
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
206206
completion_params = parse_ep_completion_params(completion_params)
207207
completion_params = parse_ep_completion_params_overwrite(completion_params)
208+
completion_params = [normalize_fireworks_model(cp) for cp in completion_params]
208209
original_completion_params = completion_params
209210
passed_threshold = parse_ep_passed_threshold(passed_threshold)
210211
data_loaders = parse_ep_dataloaders(data_loaders)
@@ -365,6 +366,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
365366
row.input_metadata.row_id = generate_id(seed=0, index=index)
366367

367368
completion_params = kwargs["completion_params"] if "completion_params" in kwargs else None
369+
completion_params = normalize_fireworks_model(completion_params)
368370
# Create eval metadata with test function info and current commit hash
369371
eval_metadata = EvalMetadata(
370372
name=test_func.__name__,
@@ -409,21 +411,22 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
409411

410412
rollout_processor.setup()
411413

412-
use_priority_scheduler = (
413-
(
414-
os.environ.get("EP_USE_PRIORITY_SCHEDULER", "0") == "1"
415-
and not isinstance(rollout_processor, MCPGymRolloutProcessor)
416-
)
417-
)
414+
use_priority_scheduler = os.environ.get(
415+
"EP_USE_PRIORITY_SCHEDULER", "0"
416+
) == "1" and not isinstance(rollout_processor, MCPGymRolloutProcessor)
418417

419418
if use_priority_scheduler:
420419
microbatch_output_size = os.environ.get("EP_MICRO_BATCH_OUTPUT_SIZE", None)
421420
output_dir = os.environ.get("EP_OUTPUT_DIR", None)
422421
if microbatch_output_size and output_dir:
423-
output_buffer = MicroBatchDataBuffer(num_runs=num_runs, batch_size=int(microbatch_output_size), output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"))
422+
output_buffer = MicroBatchDataBuffer(
423+
num_runs=num_runs,
424+
batch_size=int(microbatch_output_size),
425+
output_path_template=os.path.join(output_dir, "buffer_{index}.jsonl"),
426+
)
424427
else:
425428
output_buffer = None
426-
429+
427430
try:
428431
priority_results = await execute_priority_rollouts(
429432
dataset=data,
@@ -441,12 +444,12 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
441444
finally:
442445
if output_buffer:
443446
await output_buffer.close()
444-
447+
445448
for res in priority_results:
446449
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
447450
if run_idx < len(all_results):
448451
all_results[run_idx].append(res)
449-
452+
450453
processed_rows_in_run.append(res)
451454

452455
postprocess(
@@ -462,6 +465,7 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
462465
)
463466

464467
else:
468+
465469
async def execute_run(run_idx: int, config: RolloutProcessorConfig):
466470
nonlocal all_results
467471

@@ -506,9 +510,7 @@ async def _execute_pointwise_eval_with_semaphore(
506510
raise ValueError(
507511
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
508512
)
509-
result.execution_metadata.eval_duration_seconds = (
510-
time.perf_counter() - start_time
511-
)
513+
result.execution_metadata.eval_duration_seconds = time.perf_counter() - start_time
512514
return result
513515

514516
async def _execute_groupwise_eval_with_semaphore(
@@ -519,7 +521,9 @@ async def _execute_groupwise_eval_with_semaphore(
519521
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
520522
primary_rollout_id = rows[0].execution_metadata.rollout_id if rows else None
521523
group_rollout_ids = [
522-
r.execution_metadata.rollout_id for r in rows if r.execution_metadata.rollout_id
524+
r.execution_metadata.rollout_id
525+
for r in rows
526+
if r.execution_metadata.rollout_id
523527
]
524528
async with rollout_logging_context(
525529
primary_rollout_id or "",
@@ -596,7 +600,9 @@ async def _collect_result(config, lst):
596600
row_groups[row.input_metadata.row_id].append(row)
597601
tasks = []
598602
for _, rows in row_groups.items():
599-
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
603+
tasks.append(
604+
asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))
605+
)
600606
results = []
601607
for task in tasks:
602608
res = await task
@@ -692,9 +698,9 @@ async def _collect_result(config, lst):
692698
# For other processors, create all tasks at once and run in parallel
693699
# Concurrency is now controlled by the shared semaphore in each rollout processor
694700
await run_tasks_with_run_progress(execute_run, num_runs, config)
695-
701+
696702
experiment_duration_seconds = time.perf_counter() - experiment_start_time
697-
703+
698704
# for groupwise mode, the result contains eval output from multiple completion_params, we need to differentiate them
699705
# rollout_id is used to differentiate the result from different completion_params
700706
if mode == "groupwise":
@@ -730,15 +736,12 @@ async def _collect_result(config, lst):
730736
experiment_duration_seconds,
731737
)
732738

733-
734-
735739
if not all(r.evaluation_result is not None for run_results in all_results for r in run_results):
736740
raise AssertionError(
737741
"Some EvaluationRow instances are missing evaluation_result. "
738742
"Your @evaluation_test function must set `row.evaluation_result`"
739743
)
740744

741-
742745
except AssertionError:
743746
_log_eval_error(
744747
Status.eval_finished(),

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
371371
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
372372
retry_tasks = rollout_processor([row], retry_config)
373373
result = await retry_tasks[0]
374-
374+
375375
# Apply post-processing quality checks if configured
376376
# This must be inside the retry function so ResponseQualityError can trigger retries
377377
if config.post_processor is not None:
@@ -380,7 +380,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
380380
except ResponseQualityError as quality_error:
381381
# Re-raise ResponseQualityError to trigger retry logic
382382
raise quality_error
383-
383+
384384
return result
385385

386386
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
@@ -464,6 +464,7 @@ async def execute_row_with_backoff_and_log(
464464
yield result
465465

466466
finally:
467+
await rollout_processor.acleanup()
467468
rollout_processor.cleanup()
468469

469470

@@ -618,3 +619,22 @@ def build_rollout_processor_config(
618619
server_script_path=None,
619620
kwargs=rollout_processor_kwargs,
620621
)
622+
623+
624+
def normalize_fireworks_model(completion_params: CompletionParams | None) -> CompletionParams | None:
    """Return completion params whose Fireworks model name carries the LiteLLM prefix.

    Fireworks model names of the form ``accounts/<org>/models/<model>`` need the
    ``fireworks_ai/`` prefix when routing through LiteLLM. When the ``"model"``
    entry matches that form and is not already prefixed, a copy of the params
    with the prefixed model name is returned; the input is left unmodified.

    Args:
        completion_params: Completion parameters, or ``None``.

    Returns:
        ``None`` when given ``None``; the same object when no change is
        needed; otherwise a copy with the prefixed ``"model"`` value.
    """
    if completion_params is None:
        return None

    model_name = completion_params.get("model")

    # Missing, empty, or non-string model values are passed through untouched.
    if not model_name or not isinstance(model_name, str):
        return completion_params

    # Already routed through the Fireworks provider prefix.
    if model_name.startswith("fireworks_ai/"):
        return completion_params

    # Only bare Fireworks resource names get the prefix.
    if not re.match(r"^accounts/[^/]+/models/.+", model_name):
        return completion_params

    # Copy before editing so the caller's params are not mutated.
    prefixed_params = completion_params.copy()
    prefixed_params["model"] = f"fireworks_ai/{model_name}"
    return prefixed_params

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import List, Optional
44

5-
import requests
5+
import aiohttp
66

77
from eval_protocol.models import EvaluationRow, Status
8-
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9-
from eval_protocol.types.remote_rollout_processor import (
10-
DataLoaderConfig,
11-
)
128
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
139
from eval_protocol.exceptions import exception_for_status_code
1410

@@ -51,6 +47,12 @@ def __init__(
5147
self._poll_interval = poll_interval
5248
self._timeout_seconds = timeout_seconds
5349
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
50+
self._session: Optional[aiohttp.ClientSession] = None
51+
52+
def _get_or_create_session(self) -> aiohttp.ClientSession:
    """Return the processor's aiohttp session, creating one when needed.

    A new ``aiohttp.ClientSession`` is created and stored on the instance when
    no session exists yet or the stored one has been closed; otherwise the
    stored session is returned as-is.
    """
    current = self._session
    needs_new_session = current is None or current.closed
    if needs_new_session:
        current = aiohttp.ClientSession()
        self._session = current
    return current
5456

5557
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
5658
tasks: List[asyncio.Task[EvaluationRow]] = []
@@ -88,48 +90,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8890
init_payload = build_init_request(row, config, model_base_url)
8991

9092
# Fire-and-poll
91-
def _post_init() -> None:
92-
url = f"{remote_base_url}/init"
93-
try:
94-
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
95-
r.raise_for_status()
96-
except requests.exceptions.Timeout:
97-
raise TimeoutError(
98-
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
99-
)
100-
101-
await asyncio.to_thread(_post_init)
93+
init_url = f"{remote_base_url}/init"
94+
95+
timeout_init = aiohttp.ClientTimeout(total=300)
96+
97+
try:
98+
session = self._get_or_create_session()
99+
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
100+
if resp.status >= 400:
101+
body = await resp.text()
102+
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
103+
resp.raise_for_status()
104+
await resp.read() # Drain the response body and release the connection back to the pool
105+
except asyncio.TimeoutError:
106+
raise TimeoutError(
107+
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
108+
)
102109

103-
terminated = False
104110
deadline = time.time() + timeout_seconds
105111

106-
def _get_status() -> Dict[str, Any]:
107-
url = f"{remote_base_url}/status"
108-
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
109-
r.raise_for_status()
110-
return r.json()
111-
112-
continue_polling_status = True
113112
while time.time() < deadline:
114-
try:
115-
if continue_polling_status:
116-
status = await asyncio.to_thread(_get_status)
117-
terminated = bool(status.get("terminated", False))
118-
if terminated:
119-
break
120-
except requests.exceptions.HTTPError as e:
121-
if e.response is not None and e.response.status_code == 404:
122-
# 404 means server doesn't implement /status endpoint, stop polling
123-
logger.debug(
124-
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
125-
)
126-
continue_polling_status = False
127-
else:
128-
raise
129-
except Exception:
130-
# For all other exceptions, raise them
131-
raise
132-
133113
# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
134114
completed_logs = await asyncio.to_thread(
135115
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
@@ -142,9 +122,20 @@ def _get_status() -> Dict[str, Any]:
142122
status_logs.append(log)
143123

144124
if status_logs:
125+
if len(status_logs) > 1:
126+
logger.warning(
127+
"Found %s status logs for rollout %s; expected at most 1. Using the first one: %s",
128+
len(status_logs),
129+
row.execution_metadata.rollout_id,
130+
status_logs[0],
131+
)
145132
# Use the first log with status information
146133
status_log = status_logs[0]
147134
status_dict = status_log.get("status")
135+
raw_extras = status_log.get("extras") or {}
136+
status_extras = {
137+
k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp")
138+
}
148139

149140
logger.info(
150141
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
@@ -169,6 +160,11 @@ def _get_status() -> Dict[str, Any]:
169160
details=status_details,
170161
)
171162

163+
if row.execution_metadata.extra:
164+
row.execution_metadata.extra.update(status_extras)
165+
else:
166+
row.execution_metadata.extra = status_extras
167+
172168
logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
173169
break
174170

@@ -200,5 +196,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
200196
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
201197
return tasks
202198

199+
async def acleanup(self) -> None:
    """Close the processor's HTTP session and wait for the close to finish.

    Does nothing when no session has been created or it is already closed.
    Preferred over the synchronous cleanup whenever the caller can await.
    """
    session = self._session
    if not session or session.closed:
        return
    await session.close()
203+
203204
def cleanup(self) -> None:
    """Best-effort synchronous cleanup of the processor's HTTP session.

    When an event loop is running, closing the session is scheduled on it as a
    task; the close has not necessarily completed when this method returns.
    Without a running loop the session cannot be closed from here, so a
    warning is logged instead. Prefer ``await processor.acleanup()`` whenever
    the caller can await.
    """
    session = self._session
    if session is None or session.closed:
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop - can't safely close the session.
        # The session will be garbage collected eventually, but warn about it.
        logger.warning(
            "RemoteRolloutProcessor.cleanup() called outside of async context. "
            "Session may not be properly closed. Use `await processor.acleanup()` when possible."
        )
        return

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage collected before it finishes. Hold a strong reference
    # on the instance so the scheduled close actually runs.
    # NOTE(review): assumes the class allows new instance attributes (no __slots__) - confirm.
    self._close_task = loop.create_task(session.close())

eval_protocol/pytest/rollout_processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
1919
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
2020
pass
2121

22+
async def acleanup(self) -> None:
    """Release resources asynchronously; this base implementation does nothing.

    Preferred when the caller is able to await.
    """
25+
2226
def cleanup(self) -> None:
    """Release any held resources; this base implementation does nothing.

    Subclasses that need cleanup should override this method.
    """

0 commit comments

Comments
 (0)