Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,87 @@
from .types import RolloutProcessorConfig


def _attach_metadata_to_model_base_url(model_base_url: Optional[str], metadata: RolloutMetadata) -> Optional[str]:
"""
Attach rollout metadata as path segments to the model_base_url.

Args:
model_base_url: The base URL for the model API
metadata: The rollout metadata containing IDs to attach

Returns:
The model_base_url with path segments attached, or None if model_base_url is None
"""
if model_base_url is None:
return None

# Parse the URL to extract components
from urllib.parse import urlparse, urlunparse

parsed = urlparse(model_base_url)

# Build the path with metadata segments
# Format: /rollout_id/{rollout_id}/invocation_id/{invocation_id}/experiment_id/{experiment_id}/run_id/{run_id}/row_id/{row_id}
metadata_path = f"/rollout_id/{metadata.rollout_id}/invocation_id/{metadata.invocation_id}/experiment_id/{metadata.experiment_id}/run_id/{metadata.run_id}/row_id/{metadata.row_id}"

# Append metadata path to existing path, ensuring proper path joining
base_path = parsed.path.rstrip("/")
new_path = f"{base_path}{metadata_path}"

# Rebuild the URL with the new path
new_parsed = parsed._replace(path=new_path)
return urlunparse(new_parsed)


class RemoteRolloutProcessor(RolloutProcessor):
"""
Rollout processor that triggers a remote HTTP server to perform the rollout.

The processor automatically attaches rollout metadata (rollout_id, invocation_id,
experiment_id, run_id, row_id) as path segments to the model_base_url when
provided. This passes along rollout context to the remote server for use in
LLM API calls.

Example:
If model_base_url is "https://api.openai.com/v1" and rollout_id is "abc123",
the enhanced URL will be:
"https://api.openai.com/v1/rollout_id/abc123/invocation_id/def456/experiment_id/ghi789/run_id/jkl012/row_id/mno345"

See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
"""

def __init__(
    self,
    *,
    remote_base_url: Optional[str] = None,
    model_base_url: Optional[str] = None,
    poll_interval: float = 1.0,
    timeout_seconds: float = 120.0,
    output_data_loader: Callable[[str], DynamicDataLoader],
):
    """
    Initialize the remote rollout processor.

    Constructor-provided configuration is preferred; for backward
    compatibility, values can still be overridden via config.kwargs at call
    time.

    Args:
        remote_base_url: Base URL of the remote rollout server (required).
        model_base_url: Base URL for LLM API calls. Will be enhanced with
            rollout metadata attached as URL path segments to pass along
            rollout context to the remote server.
        poll_interval: Interval in seconds between status polls.
        timeout_seconds: Maximum time to wait for rollout completion.
        output_data_loader: Function to load rollout results by rollout_id.
    """
    # Store configuration parameters
    self._remote_base_url = remote_base_url
    self._model_base_url = model_base_url
    self._poll_interval = poll_interval
    self._timeout_seconds = timeout_seconds
    self._output_data_loader = output_data_loader

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []

# Start with constructor values
remote_base_url: Optional[str] = self._remote_base_url
poll_interval: float = self._poll_interval
timeout_seconds: float = self._timeout_seconds

# Backward compatibility: allow overrides via config.kwargs
if config.kwargs:
if remote_base_url is None:
remote_base_url = config.kwargs.get("remote_base_url", remote_base_url)
poll_interval = float(config.kwargs.get("poll_interval", poll_interval))
timeout_seconds = float(config.kwargs.get("timeout_seconds", timeout_seconds))

if not remote_base_url:
raise ValueError("remote_base_url is required in RolloutProcessorConfig.kwargs for RemoteRolloutProcessor")
if not self._remote_base_url:
raise ValueError("remote_base_url is required for RemoteRolloutProcessor")

async def _process_row(row: EvaluationRow) -> EvaluationRow:
start_time = time.perf_counter()
Expand Down Expand Up @@ -107,27 +149,31 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
if row.execution_metadata.rollout_id is None:
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")

# Attach rollout metadata to model_base_url as URL path segments
# This passes along rollout context to the remote server for use in LLM calls
enhanced_model_base_url = _attach_metadata_to_model_base_url(self._model_base_url, meta)

init_payload: InitRequest = InitRequest(
model=model,
messages=clean_messages,
tools=row.tools,
metadata=meta,
model_base_url=config.kwargs.get("model_base_url", None),
model_base_url=enhanced_model_base_url,
)

# Fire-and-poll
def _post_init() -> None:
url = f"{remote_base_url}/init"
url = f"{self._remote_base_url}/init"
r = requests.post(url, json=init_payload.model_dump(), timeout=30)
r.raise_for_status()

await asyncio.to_thread(_post_init)

terminated = False
deadline = time.time() + timeout_seconds
deadline = time.time() + self._timeout_seconds

def _get_status() -> Dict[str, Any]:
url = f"{remote_base_url}/status"
url = f"{self._remote_base_url}/status"
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
r.raise_for_status()
return r.json()
Expand All @@ -141,7 +187,7 @@ def _get_status() -> Dict[str, Any]:
except Exception:
# transient errors; continue polling
pass
await asyncio.sleep(poll_interval)
await asyncio.sleep(self._poll_interval)

# Update duration, regardless of termination
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
Expand Down
11 changes: 10 additions & 1 deletion eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ class InitRequest(BaseModel):
"""
A Base URL that the remote server can use to make LLM calls. This is useful
to configure on the eval-protocol side for flexibility in
development/traning.
development/training.

The RemoteRolloutProcessor automatically enhances this URL by attaching
rollout metadata as URL path segments (rollout_id, invocation_id, experiment_id,
run_id, row_id) before sending it to the remote server. This passes along
rollout context to the remote server for use in LLM API calls.

Example:
If model_base_url is "https://api.openai.com/v1", it will be enhanced to:
"https://api.openai.com/v1/rollout_id/abc123/invocation_id/def456/experiment_id/ghi789/run_id/jkl012/row_id/mno345/chat/completions"
"""

metadata: RolloutMetadata
Expand Down
11 changes: 10 additions & 1 deletion tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,16 @@ def _worker():
if req.tools:
completion_kwargs["tools"] = req.tools

completion = openai.chat.completions.create(**completion_kwargs)
# Use the provided model_base_url if available
if req.model_base_url:
print(f"Using custom model_base_url: {req.model_base_url}")
# Create a plain OpenAI client pointed at the custom base URL
# The URL already contains the metadata as path segments, so we can use it directly
custom_openai = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=req.model_base_url)
completion = custom_openai.chat.completions.create(**completion_kwargs)
else:
print("Using default OpenAI base URL")
completion = openai.chat.completions.create(**completion_kwargs)

except Exception as e:
# Best-effort; mark as done even on error to unblock polling
Expand Down
1 change: 1 addition & 0 deletions tests/remote_server/test_remote_langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url="http://127.0.0.1:3000",
timeout_seconds=30,
model_base_url="https://api.openai.com/v1",
output_data_loader=langfuse_output_data_loader,
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/remote_server/typescript-server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"@opentelemetry/sdk-node": "^0.205.0",
"cors": "^2.8.5",
"dotenv": "^17.2.2",
"eval-protocol": "^0.1.2",
"eval-protocol": "^0.1.3",
"express": "^5.1.0",
"helmet": "^7.1.0",
"openai": "^5.23.0"
Expand Down
10 changes: 5 additions & 5 deletions tests/remote_server/typescript-server/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions tests/remote_server/typescript-server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,15 @@ async function simulateRolloutExecution(

const openai = new OpenAI({
apiKey: process.env["OPENAI_API_KEY"],
baseURL: initRequest.model_base_url || undefined,
});

if (initRequest.model_base_url) {
console.log(`Using custom model_base_url: ${initRequest.model_base_url}`);
} else {
console.log("Using default OpenAI base URL");
}

const tracedOpenAI = observeOpenAI(openai, {
tags: createLangfuseConfigTags(initRequest),
});
Expand Down
Loading