Skip to content

Commit da7fc9d

Browse files
authored
take out output dataloader (#396)
* take out output dataloader * update * test * fix * fix test
1 parent cee95a9 commit da7fc9d

File tree

8 files changed

+151
-171
lines changed

8 files changed

+151
-171
lines changed

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import asyncio
22
import os
33
import time
4-
from typing import Any, Callable, Dict, List, Optional
4+
from typing import Any, Dict, List, Optional
55
import json
66
import requests
77
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
10-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1110

1211
from .rollout_processor import RolloutProcessor
1312
from .types import RolloutProcessorConfig
@@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor):
2120
Expected GitHub Actions workflow:
2221
- Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key
2322
- Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
24-
- Traces are fetched later via output_data_loader using rollout_id tags
23+
- Traces are fetched later via Fireworks tracing proxy using rollout_id tags
2524
2625
NOTE: GHA has a rate limit of 5000 requests per hour.
2726
"""
@@ -38,7 +37,6 @@ def __init__(
3837
timeout_seconds: float = 1800.0,
3938
max_find_workflow_retries: int = 5,
4039
github_token: Optional[str] = None,
41-
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
4240
):
4341
self.owner = owner
4442
self.repo = repo
@@ -52,7 +50,6 @@ def __init__(
5250
self.timeout_seconds = timeout_seconds
5351
self.max_find_workflow_retries = max_find_workflow_retries
5452
self.github_token = github_token
55-
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5653

5754
def _headers(self) -> Dict[str, str]:
5855
headers = {"Accept": "application/vnd.github+json"}
@@ -200,7 +197,7 @@ def _get_run() -> Dict[str, Any]:
200197
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
201198

202199
def _update_with_trace() -> None:
203-
return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)
200+
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, self.model_base_url)
204201

205202
await asyncio.to_thread(_update_with_trace)
206203

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional, Callable
3+
from typing import Any, Dict, List, Optional
44

55
import requests
66

@@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor):
2626
"""
2727
Rollout processor that triggers a remote HTTP server to perform the rollout.
2828
29-
By default, fetches traces from the Fireworks tracing proxy using rollout_id tags.
30-
You can provide a custom output_data_loader for different tracing backends.
29+
Fetches traces from the Fireworks tracing proxy using rollout_id tags.
3130
3231
See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
3332
"""
@@ -39,7 +38,6 @@ def __init__(
3938
model_base_url: str = "https://tracing.fireworks.ai",
4039
poll_interval: float = 1.0,
4140
timeout_seconds: float = 120.0,
42-
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
4341
):
4442
# Prefer constructor-provided configuration. These can be overridden via
4543
# config.kwargs at call time for backward compatibility.
@@ -52,7 +50,6 @@ def __init__(
5250
self._model_base_url = _ep_model_base_url
5351
self._poll_interval = poll_interval
5452
self._timeout_seconds = timeout_seconds
55-
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5653
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
5754

5855
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
@@ -188,7 +185,7 @@ def _get_status() -> Dict[str, Any]:
188185
row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time
189186

190187
def _update_with_trace() -> None:
191-
return update_row_with_remote_trace(row, self._output_data_loader, model_base_url)
188+
return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url)
192189

193190
await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place
194191
return row

tests/github_actions/test_github_actions_rollout.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,33 @@
1212
from eval_protocol.models import EvaluationRow, InputMetadata
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
15+
import eval_protocol.pytest.github_action_rollout_processor as github_action_rollout_processor_module
1516
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
16-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
17-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
17+
1818

1919
ROLLOUT_IDS = set()
2020

2121

2222
@pytest.fixture(autouse=True)
23-
def check_rollout_coverage():
24-
"""Ensure we processed all expected rollout_ids"""
23+
def check_rollout_coverage(monkeypatch):
24+
"""
25+
Ensure we attempted to fetch remote traces for each rollout.
26+
27+
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
28+
and tracks rollout_ids passed through its DataLoaderConfig.
29+
"""
2530
global ROLLOUT_IDS
2631
ROLLOUT_IDS.clear()
27-
yield
28-
29-
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
3032

33+
original_loader = github_action_rollout_processor_module.default_fireworks_output_data_loader
3134

32-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
33-
global ROLLOUT_IDS # Track all rollout_ids we've seen
34-
ROLLOUT_IDS.add(config.rollout_id)
35+
def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
36+
ROLLOUT_IDS.add(config.rollout_id)
37+
return original_loader(config)
3538

36-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
37-
adapter = FireworksTracingAdapter(base_url=base_url)
38-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
39-
40-
41-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
42-
return DynamicDataLoader(
43-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
44-
)
39+
monkeypatch.setattr(github_action_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
40+
yield
41+
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
4542

4643

4744
def rows() -> List[EvaluationRow]:
@@ -68,14 +65,11 @@ def rows() -> List[EvaluationRow]:
6865
ref=os.getenv("GITHUB_REF", "main"),
6966
poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval
7067
timeout_seconds=300,
71-
output_data_loader=fireworks_output_data_loader,
7268
),
7369
)
7470
async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow:
7571
"""Test GitHub Actions rollout with worker-controlled dataset."""
76-
# Track rollout IDs for coverage check
77-
global ROLLOUT_IDS
78-
ROLLOUT_IDS.add(row.execution_metadata.rollout_id)
72+
assert row.execution_metadata.rollout_id is not None
7973

8074
# This dataset is built into github_actions/rollout_worker.py
8175
if row.messages[0].content == "What is the capital of France?":

tests/remote_server/remote_server.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,23 @@ def _worker():
3737
if not model:
3838
raise ValueError("model is required in completion_params")
3939

40+
# Convert Eval Protocol Message objects into OpenAI-compatible dicts,
41+
# excluding any None fields (Fireworks rejects extra keys even when null).
42+
messages_payload = []
43+
for m in req.messages:
44+
if hasattr(m, "dump_mdoel_for_chat_completion_request"):
45+
md = m.dump_mdoel_for_chat_completion_request() # type: ignore[attr-defined]
46+
elif hasattr(m, "model_dump"):
47+
md = m.model_dump(exclude_none=True) # type: ignore[call-arg]
48+
elif isinstance(m, dict):
49+
md = {k: v for k, v in m.items() if v is not None}
50+
else:
51+
md = {"role": getattr(m, "role", None), "content": getattr(m, "content", None)}
52+
md = {k: v for k, v in md.items() if v is not None}
53+
messages_payload.append(md)
54+
4055
# Spread all completion_params (model, temperature, max_tokens, etc.)
41-
completion_kwargs = {"messages": req.messages, **req.completion_params}
56+
completion_kwargs = {"messages": messages_payload, **req.completion_params}
4257

4358
if req.tools:
4459
completion_kwargs["tools"] = req.tools

tests/remote_server/test_remote_fireworks.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
22

3-
import os
43
import subprocess
54
import socket
65
import time
@@ -13,13 +12,35 @@
1312
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
1413
from eval_protocol.pytest import evaluation_test
1514
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
16-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
17-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
15+
import eval_protocol.pytest.remote_rollout_processor as remote_rollout_processor_module
1816
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1917

18+
2019
ROLLOUT_IDS = set()
2120

2221

22+
@pytest.fixture(autouse=True)
23+
def check_rollout_coverage(monkeypatch):
24+
"""
25+
Ensure we attempted to fetch remote traces for each rollout.
26+
27+
This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
28+
and tracks rollout_ids passed through its DataLoaderConfig.
29+
"""
30+
global ROLLOUT_IDS
31+
ROLLOUT_IDS.clear()
32+
33+
original_loader = remote_rollout_processor_module.default_fireworks_output_data_loader
34+
35+
def wrapped_loader(config: DataLoaderConfig) -> DynamicDataLoader:
36+
ROLLOUT_IDS.add(config.rollout_id)
37+
return original_loader(config)
38+
39+
monkeypatch.setattr(remote_rollout_processor_module, "default_fireworks_output_data_loader", wrapped_loader)
40+
yield
41+
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
42+
43+
2344
def find_available_port() -> int:
2445
"""Find an available port on localhost"""
2546
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -68,31 +89,6 @@ def setup_remote_server():
6889
process.wait()
6990

7091

71-
@pytest.fixture(autouse=True)
72-
def check_rollout_coverage():
73-
"""Ensure we processed all expected rollout_ids"""
74-
global ROLLOUT_IDS
75-
ROLLOUT_IDS.clear()
76-
yield
77-
78-
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"
79-
80-
81-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
82-
global ROLLOUT_IDS # Track all rollout_ids we've seen
83-
ROLLOUT_IDS.add(config.rollout_id)
84-
85-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
86-
adapter = FireworksTracingAdapter(base_url=base_url)
87-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)
88-
89-
90-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
91-
return DynamicDataLoader(
92-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
93-
)
94-
95-
9692
def rows() -> List[EvaluationRow]:
9793
"""Generate local rows with rich input_metadata to verify it survives remote traces."""
9894
base_dataset_info = {
@@ -118,7 +114,6 @@ def rows() -> List[EvaluationRow]:
118114
rollout_processor=RemoteRolloutProcessor(
119115
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
120116
timeout_seconds=180,
121-
output_data_loader=fireworks_output_data_loader,
122117
),
123118
)
124119
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
@@ -133,9 +128,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
133128
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
134129
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."
135130

136-
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
137-
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
138-
)
139131
assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
140132
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"
141133

tests/remote_server/test_remote_fireworks_propagate_status.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult
1313
from eval_protocol.pytest import evaluation_test
1414
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
15-
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
16-
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
17-
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
1815

1916

2017
def find_available_port() -> int:
@@ -67,18 +64,6 @@ def setup_remote_server():
6764
process.wait()
6865

6966

70-
def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
71-
base_url = config.model_base_url or "https://tracing.fireworks.ai"
72-
adapter = FireworksTracingAdapter(base_url=base_url)
73-
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)
74-
75-
76-
def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
77-
return DynamicDataLoader(
78-
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
79-
)
80-
81-
8267
def rows() -> List[EvaluationRow]:
8368
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
8469
return [row]
@@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]:
9277
rollout_processor=RemoteRolloutProcessor(
9378
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
9479
timeout_seconds=120,
95-
output_data_loader=fireworks_output_data_loader,
9680
),
9781
)
9882
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:

0 commit comments

Comments
 (0)