-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy path remote_rollout_processor.py
More file actions
224 lines (181 loc) · 9.67 KB
/
remote_rollout_processor.py
File metadata and controls
224 lines (181 loc) · 9.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import asyncio
import time
from typing import Any, Dict, List, Optional, Callable
import requests
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
def _attach_metadata_to_model_base_url(model_base_url: Optional[str], metadata: RolloutMetadata) -> Optional[str]:
"""
Attach rollout metadata as path segments to the model_base_url.
Args:
model_base_url: The base URL for the model API
metadata: The rollout metadata containing IDs to attach
Returns:
The model_base_url with path segments attached, or None if model_base_url is None
"""
if model_base_url is None:
return None
# Parse the URL to extract components
from urllib.parse import urlparse, urlunparse
parsed = urlparse(model_base_url)
# Build the path with metadata segments
# Format: /rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}
metadata_path = f"/rollout_id/{metadata.rollout_id}/invocation_id/{metadata.invocation_id}/experiment_id/{metadata.experiment_id}/run_id/{metadata.run_id}/row_id/{metadata.row_id}"
# Append metadata path to existing path, ensuring proper path joining
base_path = parsed.path.rstrip("/")
new_path = f"{base_path}{metadata_path}"
# Rebuild the URL with the new path
new_parsed = parsed._replace(path=new_path)
return urlunparse(new_parsed)
class RemoteRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that triggers a remote HTTP server to perform the rollout.

    The processor automatically attaches rollout metadata (rollout_id, invocation_id,
    experiment_id, run_id, row_id) as path segments to the model_base_url when
    provided. This passes along rollout context to the remote server for use in
    LLM API calls.

    Example:
        If model_base_url is "https://api.openai.com/v1" and rollout_id is "abc123",
        the enhanced URL will be:
        "https://api.openai.com/v1/rollout_id/abc123/invocation_id/def456/experiment_id/ghi789/run_id/jkl012/row_id/mno345"

    See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
    """

    def __init__(
        self,
        *,
        remote_base_url: Optional[str] = None,
        model_base_url: Optional[str] = None,
        poll_interval: float = 1.0,
        timeout_seconds: float = 120.0,
        output_data_loader: Callable[[str], DynamicDataLoader],
    ):
        """
        Initialize the remote rollout processor.

        Args:
            remote_base_url: Base URL of the remote rollout server (required;
                validated at call time rather than here)
            model_base_url: Base URL for LLM API calls. Will be enhanced with rollout
                metadata as path segments to pass along rollout context to the remote server.
            poll_interval: Interval in seconds between status polls
            timeout_seconds: Maximum time to wait for rollout completion
            output_data_loader: Function to load rollout results by rollout_id
        """
        # Store configuration parameters
        self._remote_base_url = remote_base_url
        self._model_base_url = model_base_url
        self._poll_interval = poll_interval
        self._timeout_seconds = timeout_seconds
        self._output_data_loader = output_data_loader

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """
        Launch one asyncio task per row that drives a remote rollout to completion.

        Each task POSTs an init request to the remote server, polls its status
        endpoint until the rollout terminates or the timeout elapses, then loads
        the resulting row via ``output_data_loader``.

        Args:
            rows: Evaluation rows to process; each must carry invocation/experiment/
                rollout/run IDs in execution_metadata and a row_id in input_metadata.
            config: Rollout configuration; ``config.completion_params`` supplies a
                fallback "model" when the row does not specify one.

        Returns:
            A list of asyncio tasks, one per input row, each resolving to an
            EvaluationRow.

        Raises:
            ValueError: If remote_base_url is unset, a required ID is missing,
                no model can be resolved, or the data loader returns more than
                one row.
        """
        tasks: List[asyncio.Task[EvaluationRow]] = []
        if not self._remote_base_url:
            raise ValueError("remote_base_url is required for RemoteRolloutProcessor")

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()
            # All rollout identifiers must be present before contacting the remote server.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in RemoteRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in RemoteRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in RemoteRolloutProcessor")

            # Build request metadata and payload
            meta: RolloutMetadata = RolloutMetadata(
                invocation_id=row.execution_metadata.invocation_id,
                experiment_id=row.execution_metadata.experiment_id,
                rollout_id=row.execution_metadata.rollout_id,
                run_id=row.execution_metadata.run_id,
                row_id=row.input_metadata.row_id,
            )

            # Resolve the model name: the row's own completion_params take
            # precedence over the config-level completion_params.
            model: Optional[str] = None
            if row.input_metadata and row.input_metadata.completion_params:
                model = row.input_metadata.completion_params.get("model")
            if model is None and config.completion_params:
                model = config.completion_params.get("model")
            if model is None:
                raise ValueError(
                    "Model must be provided in row.input_metadata.completion_params or config.completion_params"
                )

            # Strip non-OpenAI fields from messages before sending to remote
            allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
            clean_messages = []
            for m in row.messages:
                md: Dict[str, Any]
                if hasattr(m, "model_dump"):
                    # Pydantic-style message object.
                    md = m.model_dump()  # type: ignore[assignment]
                elif isinstance(m, dict):
                    md = m  # type: ignore[assignment]
                else:
                    # Fallback to constructing a dict from Message-like object
                    md = {
                        "role": getattr(m, "role", None),
                        "content": getattr(m, "content", None),
                        "tool_calls": getattr(m, "tool_calls", None),
                        "tool_call_id": getattr(m, "tool_call_id", None),
                        "name": getattr(m, "name", None),
                    }
                # Drop disallowed keys and None values in one pass.
                clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})

            # NOTE(review): redundant with the check at the top of this coroutine,
            # but kept to preserve behavior byte-for-byte.
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in RemoteRolloutProcessor")

            # Attach rollout metadata to model_base_url as path segments.
            # This passes along rollout context to the remote server for use in LLM calls
            enhanced_model_base_url = _attach_metadata_to_model_base_url(self._model_base_url, meta)
            init_payload: InitRequest = InitRequest(
                model=model,
                messages=clean_messages,
                tools=row.tools,
                metadata=meta,
                model_base_url=enhanced_model_base_url,
            )

            # Fire-and-poll: blocking `requests` calls run in worker threads via
            # asyncio.to_thread so the event loop is never blocked.
            def _post_init() -> None:
                url = f"{self._remote_base_url}/init"
                r = requests.post(url, json=init_payload.model_dump(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_post_init)

            terminated = False
            deadline = time.time() + self._timeout_seconds

            def _get_status() -> Dict[str, Any]:
                url = f"{self._remote_base_url}/status"
                r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
                r.raise_for_status()
                return r.json()

            # Poll until the remote reports termination or the deadline passes.
            while time.time() < deadline:
                try:
                    status = await asyncio.to_thread(_get_status)
                    terminated = bool(status.get("terminated", False))
                    if terminated:
                        break
                except Exception:
                    # transient errors; continue polling
                    pass
                await asyncio.sleep(self._poll_interval)

            # Update duration, regardless of termination
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            # NOTE(review): also redundant (rollout_id checked above); preserved as-is.
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in RemoteRolloutProcessor")

            data_loader = self._output_data_loader(row.execution_metadata.rollout_id)

            def _load_data():
                return data_loader.load()

            results = await asyncio.to_thread(_load_data)
            # Comprehension-local `row` does not clobber the enclosing `row`
            # (Python 3 comprehensions have their own scope).
            output_rows: List[EvaluationRow] = [row for result in results for row in result.rows]
            if len(output_rows) == 0:  # Fallback to original row if no Langfuse data found
                row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No Langfuse data found for rollout")
                return row
            elif len(output_rows) == 1:  # Return the Langfuse row
                # Carry over the request-side params/eval metadata onto the loaded row.
                langfuse_row = output_rows[0]
                langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params
                langfuse_row.eval_metadata = row.eval_metadata
                return langfuse_row
            else:
                raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")

        for r in rows:
            tasks.append(asyncio.create_task(_process_row(r)))
        return tasks

    def cleanup(self) -> None:
        # No persistent resources are held; per-request HTTP calls are short-lived.
        return None