11# AUTO SERVER STARTUP: Server is automatically started and stopped by the test
22
3- import os
43import subprocess
54import socket
65import time
1312from eval_protocol .models import EvaluationRow , Message , EvaluateResult
1413from eval_protocol .pytest import evaluation_test
1514from eval_protocol .pytest .remote_rollout_processor import RemoteRolloutProcessor
16- from eval_protocol .adapters .fireworks_tracing import FireworksTracingAdapter
17- from eval_protocol .utils .evaluation_row_utils import filter_longest_conversation
15+ import eval_protocol .pytest .remote_rollout_processor as remote_rollout_processor_module
1816from eval_protocol .types .remote_rollout_processor import DataLoaderConfig
1917
18+
2019ROLLOUT_IDS = set ()
2120
2221
22+ @pytest .fixture (autouse = True )
23+ def check_rollout_coverage (monkeypatch ):
24+ """
25+ Ensure we attempted to fetch remote traces for each rollout.
26+
27+ This wraps the built-in default_fireworks_output_data_loader (without making it configurable)
28+ and tracks rollout_ids passed through its DataLoaderConfig.
29+ """
30+ global ROLLOUT_IDS
31+ ROLLOUT_IDS .clear ()
32+
33+ original_loader = remote_rollout_processor_module .default_fireworks_output_data_loader
34+
35+ def wrapped_loader (config : DataLoaderConfig ) -> DynamicDataLoader :
36+ ROLLOUT_IDS .add (config .rollout_id )
37+ return original_loader (config )
38+
39+ monkeypatch .setattr (remote_rollout_processor_module , "default_fireworks_output_data_loader" , wrapped_loader )
40+ yield
41+ assert len (ROLLOUT_IDS ) == 3 , f"Expected to see 3 rollout_ids, but only saw { ROLLOUT_IDS } "
42+
43+
2344def find_available_port () -> int :
2445 """Find an available port on localhost"""
2546 with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
@@ -68,31 +89,6 @@ def setup_remote_server():
6889 process .wait ()
6990
7091
71- @pytest .fixture (autouse = True )
72- def check_rollout_coverage ():
73- """Ensure we processed all expected rollout_ids"""
74- global ROLLOUT_IDS
75- ROLLOUT_IDS .clear ()
76- yield
77-
78- assert len (ROLLOUT_IDS ) == 3 , f"Expected to see 3 rollout_ids, but only saw { ROLLOUT_IDS } "
79-
80-
81- def fetch_fireworks_traces (config : DataLoaderConfig ) -> List [EvaluationRow ]:
82- global ROLLOUT_IDS # Track all rollout_ids we've seen
83- ROLLOUT_IDS .add (config .rollout_id )
84-
85- base_url = config .model_base_url or "https://tracing.fireworks.ai"
86- adapter = FireworksTracingAdapter (base_url = base_url )
87- return adapter .get_evaluation_rows (tags = [f"rollout_id:{ config .rollout_id } " ], max_retries = 7 )
88-
89-
90- def fireworks_output_data_loader (config : DataLoaderConfig ) -> DynamicDataLoader :
91- return DynamicDataLoader (
92- generators = [lambda : fetch_fireworks_traces (config )], preprocess_fn = filter_longest_conversation
93- )
94-
95-
9692def rows () -> List [EvaluationRow ]:
9793 """Generate local rows with rich input_metadata to verify it survives remote traces."""
9894 base_dataset_info = {
@@ -118,7 +114,6 @@ def rows() -> List[EvaluationRow]:
118114 rollout_processor = RemoteRolloutProcessor (
119115 remote_base_url = f"http://127.0.0.1:{ SERVER_PORT } " ,
120116 timeout_seconds = 180 ,
121- output_data_loader = fireworks_output_data_loader ,
122117 ),
123118)
124119async def test_remote_rollout_and_fetch_fireworks (row : EvaluationRow ) -> EvaluationRow :
@@ -133,9 +128,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
133128 assert row .messages [0 ].content == "What is the capital of France?" , "Row should have correct message content"
134129 assert len (row .messages ) > 1 , "Row should have a response. If this fails, we fellback to the original row."
135130
136- assert row .execution_metadata .rollout_id in ROLLOUT_IDS , (
137- f"Row rollout_id { row .execution_metadata .rollout_id } should be in tracked rollout_ids: { ROLLOUT_IDS } "
138- )
139131 assert row .input_metadata .completion_params ["model" ] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
140132 assert row .input_metadata .completion_params ["temperature" ] == 0.5 , "Row should have temperature at top level"
141133
0 commit comments