diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 13ca5422..58c67b6d 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -11,10 +11,52 @@ from .types import RolloutProcessorConfig +def _attach_metadata_to_model_base_url(model_base_url: Optional[str], metadata: RolloutMetadata) -> Optional[str]: + """ + Attach rollout metadata as path segments to the model_base_url. + + Args: + model_base_url: The base URL for the model API + metadata: The rollout metadata containing IDs to attach + + Returns: + The model_base_url with path segments attached, or None if model_base_url is None + """ + if model_base_url is None: + return None + + # Parse the URL to extract components + from urllib.parse import urlparse, urlunparse + + parsed = urlparse(model_base_url) + + # Build the path with metadata segments + # Format: /rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id} + metadata_path = f"/rollout_id/{metadata.rollout_id}/invocation_id/{metadata.invocation_id}/experiment_id/{metadata.experiment_id}/run_id/{metadata.run_id}/row_id/{metadata.row_id}" + + # Append metadata path to existing path, ensuring proper path joining + base_path = parsed.path.rstrip("/") + new_path = f"{base_path}{metadata_path}" + + # Rebuild the URL with the new path + new_parsed = parsed._replace(path=new_path) + return urlunparse(new_parsed) + + class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. + The processor automatically attaches rollout metadata (rollout_id, invocation_id, + experiment_id, run_id, row_id) as path segments to the model_base_url when + provided. This passes along rollout context to the remote server for use in + LLM API calls. 
+ + Example: + If model_base_url is "https://api.openai.com/v1" and rollout_id is "abc123", + the enhanced URL will be: + "https://api.openai.com/v1/rollout_id/abc123/invocation_id/def456/experiment_id/ghi789/run_id/jkl012/row_id/mno345" + See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -22,13 +64,25 @@ def __init__( self, *, remote_base_url: Optional[str] = None, + model_base_url: Optional[str] = None, poll_interval: float = 1.0, timeout_seconds: float = 120.0, output_data_loader: Callable[[str], DynamicDataLoader], ): - # Prefer constructor-provided configuration. These can be overridden via - # config.kwargs at call time for backward compatibility. + """ + Initialize the remote rollout processor. + + Args: + remote_base_url: Base URL of the remote rollout server (required) + model_base_url: Base URL for LLM API calls. Will be enhanced with rollout + metadata as path segments to pass along rollout context to the remote server. + poll_interval: Interval in seconds between status polls + timeout_seconds: Maximum time to wait for rollout completion + output_data_loader: Function to load rollout results by rollout_id + """ + # Store configuration parameters self._remote_base_url = remote_base_url + self._model_base_url = model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds self._output_data_loader = output_data_loader @@ -36,20 +90,8 @@ def __init__( def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: tasks: List[asyncio.Task[EvaluationRow]] = [] - # Start with constructor values - remote_base_url: Optional[str] = self._remote_base_url - poll_interval: float = self._poll_interval - timeout_seconds: float = self._timeout_seconds - - # Backward compatibility: allow overrides via config.kwargs - if config.kwargs: - if remote_base_url is None: - remote_base_url = config.kwargs.get("remote_base_url", remote_base_url) - 
poll_interval = float(config.kwargs.get("poll_interval", poll_interval)) - timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds)) - - if not remote_base_url: - raise ValueError("remote_base_url is required in RolloutProcessorConfig.kwargs for RemoteRolloutProcessor") + if not self._remote_base_url: + raise ValueError("remote_base_url is required for RemoteRolloutProcessor") async def _process_row(row: EvaluationRow) -> EvaluationRow: start_time = time.perf_counter() @@ -107,27 +149,31 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") + # Attach rollout metadata to model_base_url as path segments + # This passes along rollout context to the remote server for use in LLM calls + enhanced_model_base_url = _attach_metadata_to_model_base_url(self._model_base_url, meta) + init_payload: InitRequest = InitRequest( model=model, messages=clean_messages, tools=row.tools, metadata=meta, - model_base_url=config.kwargs.get("model_base_url", None), + model_base_url=enhanced_model_base_url, ) # Fire-and-poll def _post_init() -> None: - url = f"{remote_base_url}/init" + url = f"{self._remote_base_url}/init" r = requests.post(url, json=init_payload.model_dump(), timeout=30) r.raise_for_status() await asyncio.to_thread(_post_init) terminated = False - deadline = time.time() + timeout_seconds + deadline = time.time() + self._timeout_seconds def _get_status() -> Dict[str, Any]: - url = f"{remote_base_url}/status" + url = f"{self._remote_base_url}/status" r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) r.raise_for_status() return r.json() @@ -141,7 +187,7 @@ def _get_status() -> Dict[str, Any]: except Exception: # transient errors; continue polling pass - await asyncio.sleep(poll_interval) + await asyncio.sleep(self._poll_interval) # Update duration, regardless of termination 
row.execution_metadata.duration_seconds = time.perf_counter() - start_time diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 89967729..946fb1a4 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -28,7 +28,16 @@ class InitRequest(BaseModel): """ A Base URL that the remote server can use to make LLM calls. This is useful to configure on the eval-protocol side for flexibility in - development/traning. + development/training. + + The RemoteRolloutProcessor automatically enhances this URL by attaching + rollout metadata as path segments (rollout_id, invocation_id, experiment_id, + run_id, row_id) before sending it to the remote server. This passes along + rollout context to the remote server for use in LLM API calls. + + Example: + If model_base_url is "https://api.openai.com/v1", it will be enhanced to: + "https://api.openai.com/v1/rollout_id/abc123/invocation_id/def456/experiment_id/ghi789/run_id/jkl012/row_id/mno345" """ metadata: RolloutMetadata diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index ea831f51..6c4cb21a 100644 @@ -39,7 +39,16 @@ def _worker(): if req.tools: completion_kwargs["tools"] = req.tools - completion = openai.chat.completions.create(**completion_kwargs) + # Use the provided model_base_url if available + if req.model_base_url: + print(f"Using custom model_base_url: {req.model_base_url}") + # Create a new OpenAI client with the custom base URL + # The URL already contains the metadata as path segments, so we can use it directly + custom_openai = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=req.model_base_url) + completion = custom_openai.chat.completions.create(**completion_kwargs) + else: + print("Using default OpenAI base URL") + completion = 
openai.chat.completions.create(**completion_kwargs) except Exception as e: # Best-effort; mark as done even on error to unblock polling diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 78cde359..e657498c 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -64,6 +64,7 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", timeout_seconds=30, + model_base_url="https://api.openai.com/v1", output_data_loader=langfuse_output_data_loader, ), ) diff --git a/tests/remote_server/typescript-server/package.json b/tests/remote_server/typescript-server/package.json index 7e64fee5..a6353dce 100644 --- a/tests/remote_server/typescript-server/package.json +++ b/tests/remote_server/typescript-server/package.json @@ -22,7 +22,7 @@ "@opentelemetry/sdk-node": "^0.205.0", "cors": "^2.8.5", "dotenv": "^17.2.2", - "eval-protocol": "^0.1.2", + "eval-protocol": "^0.1.3", "express": "^5.1.0", "helmet": "^7.1.0", "openai": "^5.23.0" diff --git a/tests/remote_server/typescript-server/pnpm-lock.yaml b/tests/remote_server/typescript-server/pnpm-lock.yaml index a105871d..9a34485b 100644 --- a/tests/remote_server/typescript-server/pnpm-lock.yaml +++ b/tests/remote_server/typescript-server/pnpm-lock.yaml @@ -27,8 +27,8 @@ importers: specifier: ^17.2.2 version: 17.2.2 eval-protocol: - specifier: ^0.1.2 - version: 0.1.2(typescript@5.9.2) + specifier: ^0.1.3 + version: 0.1.3(typescript@5.9.2) express: specifier: ^5.1.0 version: 5.1.0 @@ -608,8 +608,8 @@ packages: resolution: {integrity: sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==} engines: {node: '>= 0.6'} - eval-protocol@0.1.2: - resolution: {integrity: sha512-YmEjRUy/MnYPudZpsCRzbQrBD3ZAKlK+jb+E5RklkKz7eDTLvhGY63Ynn5OoKcNW0+o9j9eV7SSHRVye6Sjbaw==} + eval-protocol@0.1.3: + resolution: {integrity: 
sha512-Mq+4c9cAJSC+ScO3xqko9WgLsZS9BG+p49wokgL6t/VUOS0o65RCOVZHelOKxcHNo3nlpUwBBA60kPtg72RJzw==} peerDependencies: typescript: ^5 @@ -1472,7 +1472,7 @@ snapshots: etag@1.8.1: {} - eval-protocol@0.1.2(typescript@5.9.2): + eval-protocol@0.1.3(typescript@5.9.2): dependencies: typescript: 5.9.2 zod: 4.1.11 diff --git a/tests/remote_server/typescript-server/server.ts b/tests/remote_server/typescript-server/server.ts index 95362cd0..893e341b 100644 --- a/tests/remote_server/typescript-server/server.ts +++ b/tests/remote_server/typescript-server/server.ts @@ -147,8 +147,15 @@ async function simulateRolloutExecution( const openai = new OpenAI({ apiKey: process.env["OPENAI_API_KEY"], + baseURL: initRequest.model_base_url || undefined, }); + if (initRequest.model_base_url) { + console.log(`Using custom model_base_url: ${initRequest.model_base_url}`); + } else { + console.log("Using default OpenAI base URL"); + } + const tracedOpenAI = observeOpenAI(openai, { tags: createLangfuseConfigTags(initRequest), });