diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index a5a00d1d..5031f6fe 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -1,4 +1,5 @@ import asyncio +import base64 import time from typing import Any, Dict, List, Optional, Callable @@ -25,7 +26,9 @@ logger = logging.getLogger(__name__) -def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> str: +def _build_fireworks_tracing_url( + base_url: str, metadata: RolloutMetadata, completion_params_base_url: Optional[str] = None +) -> str: """Build a Fireworks tracing URL by appending rollout metadata to the base URL path, allowing the Fireworks tracing proxy to automatically tag traces. @@ -35,8 +38,9 @@ def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> st base_url: Fireworks tracing proxy URL (we expect this to be https://tracing.fireworks.ai or https://tracing.fireworks.ai/project_id/{project_id}) metadata: Rollout metadata containing IDs to embed in the URL + completion_params_base_url: Optional LLM base URL to encode and append to the final URL """ - return ( + url = ( f"{base_url}/rollout_id/{metadata.rollout_id}" f"/invocation_id/{metadata.invocation_id}" f"/experiment_id/{metadata.experiment_id}" @@ -44,6 +48,14 @@ def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> st f"/row_id/{metadata.row_id}" ) + if ( + completion_params_base_url + ): # The final URL is both tracing.fireworks.ai and the actual LLM base URL we want to use + encoded_base_url = base64.urlsafe_b64encode(completion_params_base_url.encode()).decode() + url = f"{url}/encoded_base_url/{encoded_base_url}" + + return url + def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: """Default output data loader that fetches traces from Fireworks tracing proxy. @@ -164,6 +176,13 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: "Model must be provided in row.input_metadata.completion_params or config.completion_params" ) + # Extract base_url from completion_params if provided. If we're using tracing.fireworks.ai, this base_url gets encoded and passed to LiteLLM inside the proxy. + completion_params_base_url: Optional[str] = None + if row.input_metadata and row.input_metadata.completion_params: + completion_params_base_url = row.input_metadata.completion_params.get("base_url") + if completion_params_base_url is None and config.completion_params: + completion_params_base_url = config.completion_params.get("base_url") + # Strip non-OpenAI fields from messages before sending to remote allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"} clean_messages = [] @@ -192,7 +211,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: model_base_url.startswith("https://tracing.fireworks.ai") or model_base_url.startswith("http://localhost") ): - final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta) + final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url) init_payload: InitRequest = InitRequest( model=model,