Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import base64
import time
from typing import Any, Dict, List, Optional, Callable

Expand All @@ -25,7 +26,9 @@
logger = logging.getLogger(__name__)


def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> str:
def _build_fireworks_tracing_url(
base_url: str, metadata: RolloutMetadata, completion_params_base_url: Optional[str] = None
) -> str:
"""Build a Fireworks tracing URL by appending rollout metadata to the base URL path,
allowing the Fireworks tracing proxy to automatically tag traces.

Expand All @@ -35,15 +38,24 @@ def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> st
base_url: Fireworks tracing proxy URL (we expect this to be https://tracing.fireworks.ai or
https://tracing.fireworks.ai/project_id/{project_id})
metadata: Rollout metadata containing IDs to embed in the URL
completion_params_base_url: Optional LLM base URL to encode and append to the final URL
"""
return (
url = (
f"{base_url}/rollout_id/{metadata.rollout_id}"
f"/invocation_id/{metadata.invocation_id}"
f"/experiment_id/{metadata.experiment_id}"
f"/run_id/{metadata.run_id}"
f"/row_id/{metadata.row_id}"
)

if (
completion_params_base_url
): # The final URL is both tracing.fireworks.ai and the actual LLM base URL we want to use
encoded_base_url = base64.urlsafe_b64encode(completion_params_base_url.encode()).decode()
url = f"{url}/encoded_base_url/{encoded_base_url}"

return url


def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
"""Default output data loader that fetches traces from Fireworks tracing proxy.
Expand Down Expand Up @@ -164,6 +176,13 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
"Model must be provided in row.input_metadata.completion_params or config.completion_params"
)

# Extract base_url from completion_params if provided. If we're using tracing.fireworks.ai, this base_url gets encoded and passed to LiteLLM inside the proxy.
completion_params_base_url: Optional[str] = None
if row.input_metadata and row.input_metadata.completion_params:
completion_params_base_url = row.input_metadata.completion_params.get("base_url")
if completion_params_base_url is None and config.completion_params:
completion_params_base_url = config.completion_params.get("base_url")

# Strip non-OpenAI fields from messages before sending to remote
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
clean_messages = []
Expand Down Expand Up @@ -192,7 +211,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
model_base_url.startswith("https://tracing.fireworks.ai")
or model_base_url.startswith("http://localhost")
):
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta)
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)

init_payload: InitRequest = InitRequest(
model=model,
Expand Down
Loading