66
77from eval_protocol .models import EvaluationRow , Status
88from eval_protocol .data_loader .dynamic_data_loader import DynamicDataLoader
9+ from eval_protocol .types .remote_rollout_processor import InitRequest , RolloutMetadata
910from .rollout_processor import RolloutProcessor
1011from .types import RolloutProcessorConfig
1112import os
@@ -15,31 +16,14 @@ class RemoteRolloutProcessor(RolloutProcessor):
1516 """
1617 Rollout processor that triggers a remote HTTP server to perform the rollout.
1718
18- Expected remote API:
19- - POST {remote_base_url}/init
20- Body: {
21- "rollout_id": str,
22- "model": str,
23- "messages": list[dict],
24- "tools": list[dict] | null,
25- "metadata": {
26- "invocation_id": str,
27- "experiment_id": str,
28- "rollout_id": str,
29- "run_id": str | null,
30- "row_id": str | null
31- },
32- }
33- Returns: {"ok": true}
34-
35- - GET {remote_base_url}/status?rollout_id=...
36- Returns: {"terminated": bool, "info": {...}?}
19+ See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
3720 """
3821
3922 def __init__ (
4023 self ,
4124 * ,
4225 remote_base_url : Optional [str ] = None ,
26+ model_base_url : Optional [str ] = None ,
4327 poll_interval : float = 1.0 ,
4428 timeout_seconds : float = 120.0 ,
4529 output_data_loader : Callable [[str ], DynamicDataLoader ],
@@ -58,6 +42,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
5842
5943 # Start with constructor values
6044 remote_base_url : Optional [str ] = self ._remote_base_url
45+ model_base_url : Optional [str ] = self ._model_base_url
6146 poll_interval : float = self ._poll_interval
6247 timeout_seconds : float = self ._timeout_seconds
6348
@@ -74,14 +59,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
7459 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
7560 start_time = time .perf_counter ()
7661
62+ if row .execution_metadata .invocation_id is None :
63+ raise ValueError ("Invocation ID is required in RemoteRolloutProcessor" )
64+ if row .execution_metadata .experiment_id is None :
65+ raise ValueError ("Experiment ID is required in RemoteRolloutProcessor" )
66+ if row .execution_metadata .rollout_id is None :
67+ raise ValueError ("Rollout ID is required in RemoteRolloutProcessor" )
68+ if row .execution_metadata .run_id is None :
69+ raise ValueError ("Run ID is required in RemoteRolloutProcessor" )
70+ if row .input_metadata .row_id is None :
71+ raise ValueError ("Row ID is required in RemoteRolloutProcessor" )
72+
7773 # Build request metadata and payload
78- meta : Dict [ str , Any ] = {
79- " invocation_id" : row .execution_metadata .invocation_id ,
80- " experiment_id" : row .execution_metadata .experiment_id ,
81- " rollout_id" : row .execution_metadata .rollout_id ,
82- " run_id" : row .execution_metadata .run_id ,
83- " row_id" : row .input_metadata .row_id ,
84- }
74+ meta : RolloutMetadata = RolloutMetadata (
75+ invocation_id = row .execution_metadata .invocation_id ,
76+ experiment_id = row .execution_metadata .experiment_id ,
77+ rollout_id = row .execution_metadata .rollout_id ,
78+ run_id = row .execution_metadata .run_id ,
79+ row_id = row .input_metadata .row_id ,
80+ )
8581
8682 model : Optional [str ] = None
8783 if row .input_metadata and row .input_metadata .completion_params :
@@ -113,19 +109,33 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
113109 }
114110 clean_messages .append ({k : v for k , v in md .items () if k in allowed_message_fields and v is not None })
115111
116- init_payload : Dict [str , Any ] = {
117- "rollout_id" : row .execution_metadata .rollout_id ,
118- "model" : model ,
119- "messages" : clean_messages ,
120- "tools" : row .tools ,
121- "metadata" : meta ,
122- }
112+ if row .execution_metadata .rollout_id is None :
113+ raise ValueError ("Rollout ID is required in RemoteRolloutProcessor" )
114+
115+ init_payload : InitRequest = InitRequest (
116+ model = model ,
117+ messages = clean_messages ,
118+ tools = row .tools ,
119+ metadata = meta ,
120+ model_base_url = model_base_url ,
121+ )
123122
124123 # Fire-and-poll
def _post_init() -> None:
    """Send this rollout's init payload to the remote server's ``/init`` endpoint.

    Posts ``init_payload`` (serialized with ``model_dump()``) as JSON to
    ``{remote_base_url}/init`` using a 30-second request timeout.

    Raises:
        TimeoutError: If the request times out. The message explains that
            ``/init`` must answer immediately instead of blocking on the
            rollout; the underlying ``requests`` timeout is attached as
            the cause.
        requests.exceptions.HTTPError: If the server responds with an
            error status (raised by ``raise_for_status()``).

    Any other ``requests`` exception propagates unchanged.
    """
    url = f"{remote_base_url}/init"
    try:
        response = requests.post(url, json=init_payload.model_dump(), timeout=30)
        response.raise_for_status()
    except requests.exceptions.Timeout as exc:
        # Chain explicitly so the traceback reports the requests timeout as
        # the direct cause rather than as an error raised during handling.
        raise TimeoutError(
            "The /init endpoint timed out after 30 seconds. "
            "CRITICAL: The /init endpoint must return immediately (within 30s) and NOT block on rollout execution. "
            "Your remote server should:\n"
            "1. Accept the /init request and return a 200 response immediately\n"
            "2. Process the actual rollout asynchronously in the background\n"
            "3. Use the /status endpoint to report progress\n"
            "For Python/Node.js: Start a separate process per rollout to avoid blocking the /init response."
        ) from exc
129139
130140 await asyncio .to_thread (_post_init )
131141
@@ -147,7 +157,13 @@ def _get_status() -> Dict[str, Any]:
147157 except Exception :
148158 # transient errors; continue polling
149159 pass
160+
150161 await asyncio .sleep (poll_interval )
162+ else :
163+ # Loop completed without breaking, which means we timed out
164+ row .rollout_status = Status .rollout_error (
165+ f"Rollout { row .execution_metadata .rollout_id } timed out after { timeout_seconds } seconds"
166+ )
151167
152168 # Update duration, regardless of termination
153169 row .execution_metadata .duration_seconds = time .perf_counter () - start_time
@@ -170,14 +186,28 @@ def _load_data():
170186 elif len (output_rows ) == 1 : # Return the Langfuse row
171187 langfuse_row = output_rows [0 ]
172188 langfuse_row .input_metadata .completion_params = row .input_metadata .completion_params
189+ # merge dataset_info dicts on input_metadata
190+ if langfuse_row .input_metadata .dataset_info and row .input_metadata .dataset_info :
191+ langfuse_row .input_metadata .dataset_info = {
192+ ** row .input_metadata .dataset_info ,
193+ ** langfuse_row .input_metadata .dataset_info ,
194+ }
195+ elif row .input_metadata .dataset_info :
196+ langfuse_row .input_metadata .dataset_info = row .input_metadata .dataset_info
173197 langfuse_row .eval_metadata = row .eval_metadata
198+ langfuse_row .ground_truth = row .ground_truth
174199 return langfuse_row
175200 else :
176201 raise ValueError ("RemoteRolloutProcessor's output_data_loader should return exactly one row." )
177202
178- for r in rows :
179- tasks .append (asyncio .create_task (_process_row (r )))
203+ semaphore = config .semaphore
204+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
    """Process one row while holding ``semaphore``.

    Enters ``semaphore`` as an async context manager, awaits
    ``_process_row(r)`` inside it, and returns the resulting row.
    """
    async with semaphore:
        return await _process_row(r)
180209
210+ tasks = [asyncio .create_task (_sem_wrapper (row )) for row in rows ]
181211 return tasks
182212
183213 def cleanup (self ) -> None :
0 commit comments