1111from .types import RolloutProcessorConfig
1212
1313
14+ def _attach_metadata_to_model_base_url (model_base_url : Optional [str ], metadata : RolloutMetadata ) -> Optional [str ]:
15+ """
16+ Attach rollout metadata as query parameters to the model_base_url.
17+
18+ Args:
19+ model_base_url: The base URL for the model API
20+ metadata: The rollout metadata containing IDs to attach
21+
22+ Returns:
23+ The model_base_url with query parameters attached, or None if model_base_url is None
24+ """
25+ if model_base_url is None :
26+ return None
27+
28+ # Parse existing query parameters
29+ from urllib .parse import urlparse , parse_qs , urlencode , urlunparse
30+
31+ parsed = urlparse (model_base_url )
32+ query_params = parse_qs (parsed .query )
33+
34+ # Add rollout metadata as query parameters
35+ query_params .update (
36+ {
37+ "rollout_id" : [metadata .rollout_id ],
38+ "invocation_id" : [metadata .invocation_id ],
39+ "experiment_id" : [metadata .experiment_id ],
40+ "run_id" : [metadata .run_id ],
41+ "row_id" : [metadata .row_id ],
42+ }
43+ )
44+
45+ # Rebuild the URL with new query parameters
46+ new_query = urlencode (query_params , doseq = True )
47+ new_parsed = parsed ._replace (query = new_query )
48+ return urlunparse (new_parsed )
49+
50+
1451class RemoteRolloutProcessor (RolloutProcessor ):
1552 """
1653 Rollout processor that triggers a remote HTTP server to perform the rollout.
1754
55+ The processor automatically attaches rollout metadata (rollout_id, invocation_id,
56+ experiment_id, run_id, row_id) as query parameters to the model_base_url when
57+ provided. This passes along rollout context to the remote server for use in
58+ LLM API calls.
59+
60+ Example:
61+ If model_base_url is "https://api.openai.com/v1" and rollout_id is "abc123",
62+ the enhanced URL will be:
63+ "https://api.openai.com/v1?rollout_id=abc123&invocation_id=def456&..."
64+
1865 See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
1966 """
2067
2168 def __init__ (
2269 self ,
2370 * ,
2471 remote_base_url : Optional [str ] = None ,
72+ model_base_url : Optional [str ] = None ,
2573 poll_interval : float = 1.0 ,
2674 timeout_seconds : float = 120.0 ,
2775 output_data_loader : Callable [[str ], DynamicDataLoader ],
2876 ):
29- # Prefer constructor-provided configuration. These can be overridden via
30- # config.kwargs at call time for backward compatibility.
77+ """
78+ Initialize the remote rollout processor.
79+
80+ Args:
81+ remote_base_url: Base URL of the remote rollout server (required)
82+ model_base_url: Base URL for LLM API calls. Will be enhanced with rollout
83+ metadata as query parameters to pass along rollout context to the remote server.
84+ poll_interval: Interval in seconds between status polls
85+ timeout_seconds: Maximum time to wait for rollout completion
86+ output_data_loader: Function to load rollout results by rollout_id
87+ """
88+ # Store configuration parameters
3189 self ._remote_base_url = remote_base_url
90+ self ._model_base_url = model_base_url
3291 self ._poll_interval = poll_interval
3392 self ._timeout_seconds = timeout_seconds
3493 self ._output_data_loader = output_data_loader
3594
3695 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
3796 tasks : List [asyncio .Task [EvaluationRow ]] = []
3897
39- # Start with constructor values
40- remote_base_url : Optional [str ] = self ._remote_base_url
41- poll_interval : float = self ._poll_interval
42- timeout_seconds : float = self ._timeout_seconds
43-
44- # Backward compatibility: allow overrides via config.kwargs
45- if config .kwargs :
46- if remote_base_url is None :
47- remote_base_url = config .kwargs .get ("remote_base_url" , remote_base_url )
48- poll_interval = float (config .kwargs .get ("poll_interval" , poll_interval ))
49- timeout_seconds = float (config .kwargs .get ("timeout_seconds" , timeout_seconds ))
50-
51- if not remote_base_url :
52- raise ValueError ("remote_base_url is required in RolloutProcessorConfig.kwargs for RemoteRolloutProcessor" )
98+ if not self ._remote_base_url :
99+ raise ValueError ("remote_base_url is required for RemoteRolloutProcessor" )
53100
54101 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
55102 start_time = time .perf_counter ()
@@ -107,27 +154,31 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
107154 if row .execution_metadata .rollout_id is None :
108155 raise ValueError ("Rollout ID is required in RemoteRolloutProcessor" )
109156
157+ # Attach rollout metadata to model_base_url as query parameters
158+ # This passes along rollout context to the remote server for use in LLM calls
159+ enhanced_model_base_url = _attach_metadata_to_model_base_url (self ._model_base_url , meta )
160+
110161 init_payload : InitRequest = InitRequest (
111162 model = model ,
112163 messages = clean_messages ,
113164 tools = row .tools ,
114165 metadata = meta ,
115- model_base_url = config . kwargs . get ( "model_base_url" , None ) ,
166+ model_base_url = enhanced_model_base_url ,
116167 )
117168
118169 # Fire-and-poll
119170 def _post_init () -> None :
120- url = f"{ remote_base_url } /init"
171+ url = f"{ self . _remote_base_url } /init"
121172 r = requests .post (url , json = init_payload .model_dump (), timeout = 30 )
122173 r .raise_for_status ()
123174
124175 await asyncio .to_thread (_post_init )
125176
126177 terminated = False
127- deadline = time .time () + timeout_seconds
178+ deadline = time .time () + self . _timeout_seconds
128179
129180 def _get_status () -> Dict [str , Any ]:
130- url = f"{ remote_base_url } /status"
181+ url = f"{ self . _remote_base_url } /status"
131182 r = requests .get (url , params = {"rollout_id" : row .execution_metadata .rollout_id }, timeout = 15 )
132183 r .raise_for_status ()
133184 return r .json ()
@@ -141,7 +192,7 @@ def _get_status() -> Dict[str, Any]:
141192 except Exception :
142193 # transient errors; continue polling
143194 pass
144- await asyncio .sleep (poll_interval )
195+ await asyncio .sleep (self . _poll_interval )
145196
146197 # Update duration, regardless of termination
147198 row .execution_metadata .duration_seconds = time .perf_counter () - start_time
0 commit comments