Skip to content
64 changes: 51 additions & 13 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations
import logging
import requests
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol

Expand Down Expand Up @@ -280,8 +281,9 @@ def get_evaluation_rows(
from_timestamp: Optional[datetime] = None,
to_timestamp: Optional[datetime] = None,
include_tool_calls: bool = True,
sleep_between_gets: float = 2.5,
max_retries: int = 3,
backend_sleep_between_gets: float = 0.1,
backend_max_retries: int = 3,
proxy_max_retries: int = 3,
span_name: Optional[str] = None,
converter: Optional[TraceDictConverter] = None,
) -> List[EvaluationRow]:
Expand All @@ -303,8 +305,9 @@ def get_evaluation_rows(
from_timestamp: Explicit start time (ISO format)
to_timestamp: Explicit end time (ISO format)
include_tool_calls: Whether to include tool calling traces
sleep_between_gets: Sleep time between trace.get() calls (handled by proxy)
max_retries: Maximum retries for rate limit errors (handled by proxy)
backend_sleep_between_gets: Sleep time between backend trace fetches (passed to proxy)
backend_max_retries: Maximum retries for backend operations (passed to proxy)
proxy_max_retries: Maximum retries when proxy returns 404 (client-side retries with exponential backoff)
span_name: If provided, extract messages from generations within this named span
converter: Optional custom converter implementing TraceDictConverter protocol.
If provided, this will be used instead of the default conversion logic.
Expand Down Expand Up @@ -336,25 +339,60 @@ def get_evaluation_rows(
"hours_back": hours_back,
"from_timestamp": from_timestamp.isoformat() if from_timestamp else None,
"to_timestamp": to_timestamp.isoformat() if to_timestamp else None,
"sleep_between_gets": sleep_between_gets,
"max_retries": max_retries,
"sleep_between_gets": backend_sleep_between_gets,
"max_retries": backend_max_retries,
}

# Remove None values
params = {k: v for k, v in params.items() if v is not None}

# Make request to proxy
# Make request to proxy with retry logic
if self.project_id:
url = f"{self.base_url}/v1/project_id/{self.project_id}/traces"
else:
url = f"{self.base_url}/v1/traces"

try:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()
result = response.json()
except requests.exceptions.RequestException as e:
logger.error("Failed to fetch traces from proxy: %s", e)
# Retry loop for handling backend indexing delays (proxy returns 404)
result = None
for attempt in range(proxy_max_retries):
try:
response = requests.get(url, params=params, timeout=self.timeout)
response.raise_for_status()
result = response.json()
break # Success, exit retry loop
except requests.exceptions.HTTPError as e:
error_msg = str(e)
should_retry = False

# Try to extract detail message from response
if e.response is not None:
try:
error_detail = e.response.json().get("detail", "")
error_msg = error_detail or e.response.text

# Retry on 404 if it's due to incomplete/missing traces (backend still indexing)
if e.response.status_code == 404 and (
"Incomplete traces" in error_detail or "No traces found" in error_detail
):
should_retry = True
except Exception:
error_msg = e.response.text

if should_retry and attempt < proxy_max_retries - 1:
sleep_time = 2 ** (attempt + 1)
logger.warning(error_msg)
time.sleep(sleep_time)
else:
# Final retry or non-retryable error
logger.error("Failed to fetch traces from proxy: %s", error_msg)
return eval_rows
except requests.exceptions.RequestException as e:
# Non-HTTP errors (network issues, timeouts, etc.)
logger.error("Failed to fetch traces from proxy: %s", str(e))
return eval_rows

if result is None:
logger.error("Failed to fetch traces after %d retries", proxy_max_retries)
return eval_rows

# Extract traces from response
Expand Down
7 changes: 5 additions & 2 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
def fetch_traces() -> List[EvaluationRow]:
base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5)

return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation)

Expand Down Expand Up @@ -188,7 +188,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")

final_model_base_url = model_base_url
if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"):
if model_base_url and (
model_base_url.startswith("https://tracing.fireworks.ai")
or model_base_url.startswith("http://localhost")
):
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta)

init_payload: InitRequest = InitRequest(
Expand Down
104 changes: 104 additions & 0 deletions tests/remote_server/remote_server_multi_turn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import random
import threading

import uvicorn
from fastapi import FastAPI
from openai import OpenAI
import logging

from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter


# FastAPI application exposing the /init endpoint used by the remote rollout tests.
app = FastAPI()

# attach handler to root logger — records from the per-rollout child loggers
# propagate up to the root, so one shared handler captures all rollouts once
# it is configured via the /init payload's elastic_search_config
handler = ElasticsearchDirectHttpHandler()
logging.getLogger().addHandler(handler)


@app.post("/init")
def init(req: InitRequest):
    """Kick off a 6-turn chat rollout in a background thread and return immediately.

    The worker performs three chat completions (3 user turns + 3 assistant
    turns), logging each turn through a per-rollout logger. Completion —
    success or failure — is signalled by a final log record carrying
    ``Status.rollout_finished()`` so the caller's polling never hangs.
    """
    if req.elastic_search_config:
        handler.configure(req.elastic_search_config)

    # Attach a rollout_id filter to a per-rollout logger. logging.getLogger()
    # caches loggers by name, so guard against stacking duplicate filters when
    # /init is invoked more than once for the same rollout_id.
    logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
    if not any(isinstance(f, RolloutIdFilter) for f in logger.filters):
        logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))

    # Worker thread that drives a multi-turn chat (6 turns total).
    def _worker():
        try:
            if not req.messages:
                raise ValueError("messages is required")

            client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

            # Seed the conversation from the request; Message objects must be
            # converted to plain dicts for the OpenAI API.
            conversation_history = [{"role": m.role, "content": m.content} for m in req.messages]

            follow_up_questions = [
                "Tell me more about that.",
                "What else can you share about this topic?",
            ]

            # Three completion rounds build up the conversation over 6 turns
            # (3 user messages + 3 assistant responses). Round 0 uses the
            # initial request messages; later rounds append a follow-up first.
            for round_idx in range(3):
                if round_idx == 0:
                    logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}")
                else:
                    question = follow_up_questions[round_idx - 1]
                    conversation_history.append({"role": "user", "content": question})
                    logger.info(f"Turn {2 * round_idx + 1}: User asks: {question}")

                completion = client.chat.completions.create(
                    model=req.model,
                    messages=conversation_history,  # type: ignore
                )
                assistant_content = completion.choices[0].message.content or ""
                conversation_history.append({"role": "assistant", "content": assistant_content})
                logger.info(f"Turn {2 * round_idx + 2} response: {assistant_content[:100]}...")

            logger.info(f"Completed 6-turn conversation with {len(conversation_history)} messages total")

        except Exception as e:
            # Best-effort: swallow the error so the finally block still emits
            # the completion status and unblocks polling.
            print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}")
        finally:
            logger.info(
                f"Rollout {req.metadata.rollout_id} completed",
                extra={"status": Status.rollout_finished()},
            )

    t = threading.Thread(target=_worker, daemon=True)
    t.start()


def main():
    """Serve the app on the host/port taken from the environment.

    Defaults to 127.0.0.1:3000 when REMOTE_SERVER_HOST / REMOTE_SERVER_PORT
    are unset.
    """
    bind_host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
    bind_port = os.getenv("REMOTE_SERVER_PORT", "3000")
    uvicorn.run(app, host=bind_host, port=int(bind_port))


# Script entry point: run the test server directly.
if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:

base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
Expand All @@ -65,7 +65,7 @@ def rows() -> List[EvaluationRow]:
),
rollout_processor=RemoteRolloutProcessor(
remote_base_url="http://127.0.0.1:3000",
timeout_seconds=30,
timeout_seconds=180,
output_data_loader=fireworks_output_data_loader,
),
)
Expand Down
Loading