2323from eval_protocol .models import EvaluationRow , Message , Status
2424from eval_protocol .pytest import evaluation_test
2525from eval_protocol .pytest .remote_rollout_processor import RemoteRolloutProcessor
26+ from eval_protocol .adapters .fireworks_tracing import FireworksTracingAdapter
27+ from eval_protocol .utils .evaluation_row_utils import filter_longest_conversation
28+ from eval_protocol .types .remote_rollout_processor import DataLoaderConfig
2629
2730
2831def find_available_port () -> int :
@@ -75,6 +78,18 @@ def setup_remote_server():
7578 process .wait ()
7679
7780
81+ def fetch_fireworks_traces (config : DataLoaderConfig ) -> List [EvaluationRow ]:
82+ base_url = config .model_base_url or "https://tracing.fireworks.ai"
83+ adapter = FireworksTracingAdapter (base_url = base_url )
84+ return adapter .get_evaluation_rows (tags = [f"rollout_id:{ config .rollout_id } " ], max_retries = 7 )
85+
86+
87+ def fireworks_output_data_loader (config : DataLoaderConfig ) -> DynamicDataLoader :
88+ return DynamicDataLoader (
89+ generators = [lambda : fetch_fireworks_traces (config )], preprocess_fn = filter_longest_conversation
90+ )
91+
92+
7893def rows () -> List [EvaluationRow ]:
7994 row = EvaluationRow (messages = [Message (role = "user" , content = "What is the capital of France?" )])
8095 return [row ]
@@ -88,6 +103,7 @@ def rows() -> List[EvaluationRow]:
88103 rollout_processor = RemoteRolloutProcessor (
89104 remote_base_url = f"http://127.0.0.1:{ SERVER_PORT } " ,
90105 timeout_seconds = 120 ,
106+ output_data_loader = fireworks_output_data_loader ,
91107 ),
92108)
93109async def test_remote_rollout_and_fetch_fireworks_propagate_status (row : EvaluationRow ) -> EvaluationRow :
0 commit comments