Skip to content

Commit 62867fe

Browse files
author
Dylan Huang
committed
Enhance RemoteRolloutProcessor to attach rollout metadata to model_base_url
- Added a new function to attach rollout metadata as query parameters to the model_base_url. - Updated the RemoteRolloutProcessor to utilize the enhanced model_base_url when making API calls. - Modified InitRequest documentation to reflect the automatic enhancement of model_base_url with rollout metadata. - Updated tests to include the new model_base_url functionality.
1 parent 626a125 commit 62867fe

File tree

7 files changed

+105
-29
lines changed

7 files changed

+105
-29
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,45 +11,92 @@
1111
from .types import RolloutProcessorConfig
1212

1313

def _attach_metadata_to_model_base_url(model_base_url: Optional[str], metadata: "RolloutMetadata") -> Optional[str]:
    """
    Attach rollout metadata as query parameters to the model_base_url.

    Any query parameters already present on the URL are preserved; the
    metadata IDs are merged on top of them.

    Args:
        model_base_url: The base URL for the model API
        metadata: The rollout metadata containing IDs to attach

    Returns:
        The model_base_url with query parameters attached, or None if
        model_base_url is None
    """
    if model_base_url is None:
        return None

    # Local import keeps the helper self-contained (matches original).
    from urllib.parse import urlparse, parse_qs, urlencode, urlunparse

    # Parse existing query parameters so we merge rather than clobber them.
    parsed = urlparse(model_base_url)
    query_params = parse_qs(parsed.query)

    # Only attach IDs that are actually set: these fields are Optional, and
    # urlencode would otherwise emit the literal string "None" as the value
    # (e.g. "rollout_id=None"), which downstream servers would treat as a
    # real ID.
    metadata_params = {
        "rollout_id": metadata.rollout_id,
        "invocation_id": metadata.invocation_id,
        "experiment_id": metadata.experiment_id,
        "run_id": metadata.run_id,
        "row_id": metadata.row_id,
    }
    query_params.update({key: [value] for key, value in metadata_params.items() if value is not None})

    # Rebuild the URL with the merged query string; doseq=True expands the
    # list-valued entries produced by parse_qs.
    new_query = urlencode(query_params, doseq=True)
    return urlunparse(parsed._replace(query=new_query))
49+
50+
1451
class RemoteRolloutProcessor(RolloutProcessor):
1552
"""
1653
Rollout processor that triggers a remote HTTP server to perform the rollout.
1754
55+
The processor automatically attaches rollout metadata (rollout_id, invocation_id,
56+
experiment_id, run_id, row_id) as query parameters to the model_base_url when
57+
provided. This passes along rollout context to the remote server for use in
58+
LLM API calls.
59+
60+
Example:
61+
If model_base_url is "https://api.openai.com/v1" and rollout_id is "abc123",
62+
the enhanced URL will be:
63+
"https://api.openai.com/v1?rollout_id=abc123&invocation_id=def456&..."
64+
1865
See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
1966
"""
2067

2168
def __init__(
2269
self,
2370
*,
2471
remote_base_url: Optional[str] = None,
72+
model_base_url: Optional[str] = None,
2573
poll_interval: float = 1.0,
2674
timeout_seconds: float = 120.0,
2775
output_data_loader: Callable[[str], DynamicDataLoader],
2876
):
29-
# Prefer constructor-provided configuration. These can be overridden via
30-
# config.kwargs at call time for backward compatibility.
77+
"""
78+
Initialize the remote rollout processor.
79+
80+
Args:
81+
remote_base_url: Base URL of the remote rollout server (required)
82+
model_base_url: Base URL for LLM API calls. Will be enhanced with rollout
83+
metadata as query parameters to pass along rollout context to the remote server.
84+
poll_interval: Interval in seconds between status polls
85+
timeout_seconds: Maximum time to wait for rollout completion
86+
output_data_loader: Function to load rollout results by rollout_id
87+
"""
88+
# Store configuration parameters
3189
self._remote_base_url = remote_base_url
90+
self._model_base_url = model_base_url
3291
self._poll_interval = poll_interval
3392
self._timeout_seconds = timeout_seconds
3493
self._output_data_loader = output_data_loader
3594

3695
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
3796
tasks: List[asyncio.Task[EvaluationRow]] = []
3897

39-
# Start with constructor values
40-
remote_base_url: Optional[str] = self._remote_base_url
41-
poll_interval: float = self._poll_interval
42-
timeout_seconds: float = self._timeout_seconds
43-
44-
# Backward compatibility: allow overrides via config.kwargs
45-
if config.kwargs:
46-
if remote_base_url is None:
47-
remote_base_url = config.kwargs.get("remote_base_url", remote_base_url)
48-
poll_interval = float(config.kwargs.get("poll_interval", poll_interval))
49-
timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds))
50-
51-
if not remote_base_url:
52-
raise ValueError("remote_base_url is required in RolloutProcessorConfig.kwargs for RemoteRolloutProcessor")
98+
if not self._remote_base_url:
99+
raise ValueError("remote_base_url is required for RemoteRolloutProcessor")
53100

54101
async def _process_row(row: EvaluationRow) -> EvaluationRow:
55102
start_time = time.perf_counter()
@@ -107,27 +154,31 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
107154
if row.execution_metadata.rollout_id is None:
108155
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
109156

157+
# Attach rollout metadata to model_base_url as query parameters
158+
# This passes along rollout context to the remote server for use in LLM calls
159+
enhanced_model_base_url = _attach_metadata_to_model_base_url(self._model_base_url, meta)
160+
110161
init_payload: InitRequest = InitRequest(
111162
model=model,
112163
messages=clean_messages,
113164
tools=row.tools,
114165
metadata=meta,
115-
model_base_url=config.kwargs.get("model_base_url", None),
166+
model_base_url=enhanced_model_base_url,
116167
)
117168

118169
# Fire-and-poll
119170
def _post_init() -> None:
120-
url = f"{remote_base_url}/init"
171+
url = f"{self._remote_base_url}/init"
121172
r = requests.post(url, json=init_payload.model_dump(), timeout=30)
122173
r.raise_for_status()
123174

124175
await asyncio.to_thread(_post_init)
125176

126177
terminated = False
127-
deadline = time.time() + timeout_seconds
178+
deadline = time.time() + self._timeout_seconds
128179

129180
def _get_status() -> Dict[str, Any]:
130-
url = f"{remote_base_url}/status"
181+
url = f"{self._remote_base_url}/status"
131182
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
132183
r.raise_for_status()
133184
return r.json()
@@ -141,7 +192,7 @@ def _get_status() -> Dict[str, Any]:
141192
except Exception:
142193
# transient errors; continue polling
143194
pass
144-
await asyncio.sleep(poll_interval)
195+
await asyncio.sleep(self._poll_interval)
145196

146197
# Update duration, regardless of termination
147198
row.execution_metadata.duration_seconds = time.perf_counter() - start_time

eval_protocol/types/remote_rollout_processor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,16 @@ class InitRequest(BaseModel):
2828
"""
2929
A Base URL that the remote server can use to make LLM calls. This is useful
3030
to configure on the eval-protocol side for flexibility in
31-
development/traning.
31+
development/training.
32+
33+
The RemoteRolloutProcessor automatically enhances this URL by attaching
34+
rollout metadata as query parameters (rollout_id, invocation_id, experiment_id,
35+
run_id, row_id) before sending it to the remote server. This passes along
36+
rollout context to the remote server for use in LLM API calls.
37+
38+
Example:
39+
If model_base_url is "https://api.openai.com/v1", it will be enhanced to:
40+
"https://api.openai.com/v1?rollout_id=abc123&invocation_id=def456&experiment_id=ghi789&run_id=jkl012&row_id=mno345"
3241
"""
3342

3443
metadata: RolloutMetadata

tests/remote_server/remote_server.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@ def _worker():
3939
if req.tools:
4040
completion_kwargs["tools"] = req.tools
4141

42-
completion = openai.chat.completions.create(**completion_kwargs)
42+
# Use the provided model_base_url if available
43+
if req.model_base_url:
44+
print(f"Using custom model_base_url: {req.model_base_url}")
45+
# Create a new Langfuse OpenAI client with the custom base URL
46+
custom_openai = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=req.model_base_url)
47+
completion = custom_openai.chat.completions.create(**completion_kwargs)
48+
else:
49+
print("Using default OpenAI base URL")
50+
completion = openai.chat.completions.create(**completion_kwargs)
4351

4452
except Exception as e:
4553
# Best-effort; mark as done even on error to unblock polling

tests/remote_server/test_remote_langfuse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def rows() -> List[EvaluationRow]:
6464
rollout_processor=RemoteRolloutProcessor(
6565
remote_base_url="http://127.0.0.1:3000",
6666
timeout_seconds=30,
67+
model_base_url="https://api.openai.com/v1",
6768
output_data_loader=langfuse_output_data_loader,
6869
),
6970
)

tests/remote_server/typescript-server/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"@opentelemetry/sdk-node": "^0.205.0",
2323
"cors": "^2.8.5",
2424
"dotenv": "^17.2.2",
25-
"eval-protocol": "^0.1.2",
25+
"eval-protocol": "^0.1.3",
2626
"express": "^5.1.0",
2727
"helmet": "^7.1.0",
2828
"openai": "^5.23.0"

tests/remote_server/typescript-server/pnpm-lock.yaml

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/remote_server/typescript-server/server.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,15 @@ async function simulateRolloutExecution(
147147

148148
const openai = new OpenAI({
149149
apiKey: process.env["OPENAI_API_KEY"],
150+
baseURL: initRequest.model_base_url || undefined,
150151
});
151152

153+
if (initRequest.model_base_url) {
154+
console.log(`Using custom model_base_url: ${initRequest.model_base_url}`);
155+
} else {
156+
console.log("Using default OpenAI base URL");
157+
}
158+
152159
const tracedOpenAI = observeOpenAI(openai, {
153160
tags: createLangfuseConfigTags(initRequest),
154161
});

0 commit comments

Comments
 (0)