From 341a0333403d99c15ee7ec7246461d9efa829cc6 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Sun, 5 Oct 2025 02:47:09 -0700 Subject: [PATCH 1/9] Fireworks Tracing --- eval_protocol/adapters/__init__.py | 8 + eval_protocol/adapters/fireworks_tracing.py | 394 ++++++++++++++++++ eval_protocol/adapters/langfuse.py | 2 +- .../pytest/remote_rollout_processor.py | 27 +- tests/remote_server/test_remote_fireworks.py | 84 ++++ tests/remote_server/test_remote_langfuse.py | 2 +- 6 files changed, 501 insertions(+), 16 deletions(-) create mode 100644 eval_protocol/adapters/fireworks_tracing.py create mode 100644 tests/remote_server/test_remote_fireworks.py diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index d338b6c2..39b312d8 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -6,6 +6,7 @@ Available adapters: - BaseAdapter: Abstract base class for all adapters - LangfuseAdapter: Pull data from Langfuse deployments +- FireworksTracingAdapter: Pull data from Langfuse via Fireworks tracing proxy - HuggingFaceAdapter: Load datasets from HuggingFace Hub - BigQueryAdapter: Query data from Google BigQuery - TRL integration (legacy) @@ -24,6 +25,13 @@ except ImportError: pass +try: + from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter + + __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) +except ImportError: + pass + try: from .huggingface import ( HuggingFaceAdapter, diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py new file mode 100644 index 00000000..f3155023 --- /dev/null +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -0,0 +1,394 @@ +"""Fireworks Tracing adapter for Eval Protocol. + +This adapter uses the Fireworks tracing proxy at tracing.fireworks.ai +to pull data from Langfuse deployments with simplified retry logic handling. +""" + +from __future__ import annotations +import logging +import requests +from datetime import datetime +from typing import Any, Dict, List, Optional, Protocol + +from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from .base import BaseAdapter +from .utils import extract_messages_from_data + +logger = logging.getLogger(__name__) + + +class TraceDictConverter(Protocol): + """Protocol for custom trace dictionary-to-EvaluationRow converter functions. + + A converter function should take a trace dictionary along with processing + options and return an EvaluationRow or None to skip the trace. + """ + + def __call__( + self, + trace: Dict[str, Any], + include_tool_calls: bool, + span_name: Optional[str], + ) -> Optional[EvaluationRow]: + """Convert a trace dictionary to an EvaluationRow. + + Args: + trace: The trace dictionary to convert + include_tool_calls: Whether to include tool calling information + span_name: Optional span name to extract messages from + + Returns: + EvaluationRow or None if the trace should be skipped + """ + ... + + +def convert_trace_dict_to_evaluation_row( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> Optional[EvaluationRow]: + """Convert a trace dictionary (from proxy API) to EvaluationRow format. + + Args: + trace: Trace dictionary from Fireworks proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + EvaluationRow or None if conversion fails + """ + try: + # Extract messages from trace input and output + messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name) + + # Extract tools if available + tools = None + if include_tool_calls and isinstance(trace.get("input"), dict) and "tools" in trace["input"]: + tools = trace["input"]["tools"] + + if not messages: + return None + + execution_metadata = ExecutionMetadata() + row_id = None + + # Extract metadata from tags + tags = trace.get("tags", []) + if tags: + for tag in tags: + if tag.startswith("invocation_id:"): + execution_metadata.invocation_id = tag.split(":", 1)[1] + elif tag.startswith("experiment_id:"): + execution_metadata.experiment_id = tag.split(":", 1)[1] + elif tag.startswith("rollout_id:"): + execution_metadata.rollout_id = tag.split(":", 1)[1] + elif tag.startswith("run_id:"): + execution_metadata.run_id = tag.split(":", 1)[1] + elif tag.startswith("row_id:"): + row_id = tag.split(":", 1)[1] + + if ( + execution_metadata.invocation_id + and execution_metadata.experiment_id + and execution_metadata.rollout_id + and execution_metadata.run_id + and row_id + ): + break # Break early if we've found all the metadata we need + + return EvaluationRow( + messages=messages, + tools=tools, + input_metadata=InputMetadata( + row_id=row_id, + session_data={ + "langfuse_trace_id": trace.get("id"), # Store the trace ID here + }, + ), + execution_metadata=execution_metadata, + ) + + except (AttributeError, ValueError, KeyError) as e: + logger.error("Error converting trace %s: %s", trace.get("id"), e) + return None + + +def extract_messages_from_trace_dict( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> List[Message]: + """Extract messages from trace dictionary. + + Args: + trace: Trace dictionary from proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + List of Message objects + """ + messages = [] + + if span_name: # Look for a generation tied to a span name + try: + # Find the final generation in the named span + gen = get_final_generation_in_span_dict(trace, span_name) + if not gen: + return messages + + # Extract messages from generation input and output + if gen.get("input"): + messages.extend(extract_messages_from_data(gen["input"], include_tool_calls)) + if gen.get("output"): + messages.extend(extract_messages_from_data(gen["output"], include_tool_calls)) + + return messages + + except Exception as e: + logger.error("Failed to extract messages from span '%s' in trace %s: %s", span_name, trace.get("id"), e) + return messages + + else: + try: + # Extract messages from trace input and output + if trace.get("input"): + messages.extend(extract_messages_from_data(trace["input"], include_tool_calls)) + if trace.get("output"): + messages.extend(extract_messages_from_data(trace["output"], include_tool_calls)) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Error processing trace %s: %s", trace.get("id"), e) + + # Fallback: use the last GENERATION observation which typically contains full chat history + if not messages: + try: + all_observations = trace.get("observations", []) + gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"] + if gens: + gens.sort(key=lambda x: x.get("start_time", "")) + last_gen = gens[-1] + if last_gen.get("input"): + messages.extend(extract_messages_from_data(last_gen["input"], include_tool_calls)) + if last_gen.get("output"): + messages.extend(extract_messages_from_data(last_gen["output"], include_tool_calls)) + except Exception as e: + logger.warning("Failed to extract from last generation for trace %s: %s", trace.get("id"), e) + + return messages + + +def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> Optional[Dict[str, Any]]: + """Get the final generation within a named span from trace dictionary. + + Args: + trace: Trace dictionary + span_name: Name of the span to search for + + Returns: + The final generation dictionary, or None if not found + """ + # Get all observations from the trace + all_observations = trace.get("observations", []) + + # Find a span with the given name that has generation children + parent_span = None + for obs in all_observations: + if obs.get("name") == span_name and obs.get("type") == "SPAN": + # Check if this span has generation children + has_generations = any( + child.get("type") == "GENERATION" and child.get("parent_observation_id") == obs.get("id") + for child in all_observations + ) + if has_generations: + parent_span = obs + break + + if not parent_span: + logger.warning("No span named '%s' found in trace %s", span_name, trace.get("id")) + return None + + # Find all generations within this span + generations = [] + for obs in all_observations: + if obs.get("type") == "GENERATION" and obs.get("parent_observation_id") == parent_span.get("id"): + generations.append(obs) + + if not generations: + logger.warning("No generations found in span '%s' in trace %s", span_name, trace.get("id")) + return None + + # Sort generations by start time for chronological order + generations.sort(key=lambda x: x.get("start_time", "")) + + # Return the final generation (contains full message history) + return generations[-1] + + +class FireworksTracingAdapter(BaseAdapter): + """Adapter to pull data from Langfuse via Fireworks tracing proxy. + + This adapter uses the Fireworks tracing proxy API which handles retry logic + and rate limiting internally, simplifying the client-side implementation. + + Examples: + Basic usage (default project): + >>> adapter = FireworksTracingAdapter() + >>> rows = list(adapter.get_evaluation_rows(tags=["rollout_id:xyz"], limit=10)) + + With explicit project ID: + >>> adapter = FireworksTracingAdapter( + ... project_id="your_project_id", + ... base_url="https://tracing.fireworks.ai" + ... ) + >>> rows = list(adapter.get_evaluation_rows(tags=["production"], limit=10)) + + Filter by specific criteria: + >>> rows = list(adapter.get_evaluation_rows( + ... tags=["production"], + ... limit=50, + ... hours_back=24 + ... )) + """ + + def __init__( + self, + project_id: Optional[str] = None, + base_url: str = "https://tracing.fireworks.ai", + timeout: int = 300, + ): + """Initialize the Fireworks Tracing adapter. + + Args: + project_id: Optional project ID. If not provided, uses the default project configured on the server. + base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai) + timeout: Request timeout in seconds (default: 300) + """ + self.project_id = project_id + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + def get_evaluation_rows( + self, + tags: List[str], + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[datetime] = None, + to_timestamp: Optional[datetime] = None, + include_tool_calls: bool = True, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + span_name: Optional[str] = None, + converter: Optional[TraceDictConverter] = None, + ) -> List[EvaluationRow]: + """Pull traces from Langfuse via proxy and convert to EvaluationRow format. + + Args: + tags: REQUIRED - Filter by specific tags (prevents fetching all traces). + Must provide at least one tag (e.g., ['rollout_id:xyz'], ['production']) + limit: Max number of trace summaries to collect via pagination + sample_size: Optional number of traces to randomly sample (if None, process all) + user_id: Filter by user ID + session_id: Filter by session ID + name: Filter by trace name + environment: Filter by environment (e.g., production, staging, development) + version: Filter by trace version + release: Filter by trace release + fields: Comma-separated list of fields to include + hours_back: Filter traces from this many hours ago + from_timestamp: Explicit start time (ISO format) + to_timestamp: Explicit end time (ISO format) + include_tool_calls: Whether to include tool calling traces + sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) + max_retries: Maximum retries for rate limit errors (handled by proxy) + span_name: If provided, extract messages from generations within this named span + converter: Optional custom converter implementing TraceDictConverter protocol. + If provided, this will be used instead of the default conversion logic. + + Returns: + List[EvaluationRow]: Converted evaluation rows + + Raises: + ValueError: If tags list is empty + """ + # Validate that tags are provided (security requirement) + if not tags or len(tags) == 0: + raise ValueError("At least one tag is required to fetch traces (security: prevents fetching all traces)") + + eval_rows = [] + + # Build request payload + payload = { + "limit": limit, + "sample_size": sample_size, + "tags": tags, + "user_id": user_id, + "session_id": session_id, + "name": name, + "environment": environment, + "version": version, + "release": release, + "fields": fields, + "hours_back": hours_back, + "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, + "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, + "include_tool_calls": include_tool_calls, + "sleep_between_gets": sleep_between_gets, + "max_retries": max_retries, + "span_name": span_name, + } + + # Remove None values + payload = {k: v for k, v in payload.items() if v is not None} + + # Make request to proxy + if self.project_id: + url = f"{self.base_url}/v1/project_id/{self.project_id}/langfuse/traces" + else: + url = f"{self.base_url}/v1/langfuse/traces" + + try: + response = requests.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + result = response.json() + except requests.exceptions.RequestException as e: + logger.error("Failed to fetch traces from proxy: %s", e) + return eval_rows + + # Extract traces from response + traces = result.get("traces", []) + + # Convert each trace to EvaluationRow + for trace in traces: + try: + if converter: + eval_row = converter(trace, include_tool_calls, span_name) + else: + eval_row = convert_trace_dict_to_evaluation_row(trace, include_tool_calls, span_name) + if eval_row: + eval_rows.append(eval_row) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Failed to convert trace %s: %s", trace.get("id"), e) + continue + + logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows)) + return eval_rows + + +def create_fireworks_tracing_adapter( + project_id: Optional[str] = None, base_url: str = "https://tracing.fireworks.ai" +) -> FireworksTracingAdapter: + """Factory function to create a Fireworks Tracing adapter. + + Args: + project_id: Optional project ID. If not provided, uses the default project configured on the server. + base_url: The base URL of the tracing proxy + + Returns: + FireworksTracingAdapter instance + """ + return FireworksTracingAdapter(project_id=project_id, base_url=base_url) diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 42eeee6f..0174922f 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -355,7 +355,7 @@ def get_evaluation_rows( # If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse if traces and traces.meta and traces.meta.total_items == 0 and page == 1: - raise Exception("Empty results - indexing delay") + raise Exception("Empty results") break except Exception as e: diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 1c96affa..c88ccc56 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -29,7 +29,7 @@ def __init__( self, *, remote_base_url: Optional[str] = None, - model_base_url: Optional[str] = None, + model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, output_data_loader: Callable[[str], DynamicDataLoader], @@ -42,7 +42,6 @@ def __init__( self._model_base_url = model_base_url if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") - self._model_base_url = model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds self._output_data_loader = output_data_loader @@ -67,7 +66,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> # Start with constructor values remote_base_url: Optional[str] = self._remote_base_url - model_base_url: Optional[str] = self._model_base_url + model_base_url: str = self._model_base_url poll_interval: float = self._poll_interval timeout_seconds: float = self._timeout_seconds @@ -138,7 +137,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") final_model_base_url = model_base_url - if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai/project_id/"): + if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"): final_model_base_url = ( f"{model_base_url}/rollout_id/{meta.rollout_id}" f"/invocation_id/{meta.invocation_id}" @@ -252,24 +251,24 @@ def _load_data(): output_rows: List[EvaluationRow] = [row for result in results for row in result.rows] - if len(output_rows) == 0: # Fallback to original row if no Langfuse data found - row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No Langfuse data found for rollout") + if len(output_rows) == 0: # Fallback to original row if no Remote data found + row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No remote data found for rollout") return row - elif len(output_rows) == 1: # Return the Langfuse row - langfuse_row = output_rows[0] + elif len(output_rows) == 1: # Return the remote row + remote_row = output_rows[0] - # if the langfuse_row has the same number of messages as the original row, + # if the remote_row has the same number of messages as the original row, # something went wrong - if len(langfuse_row.messages) == len(row.messages): + if len(remote_row.messages) == len(row.messages): row.rollout_status = Status.rollout_error( "Rollout finished with the same number of messages as the original row" ) return row - row.messages = langfuse_row.messages - row.tools = langfuse_row.tools - row.input_metadata.session_data = langfuse_row.input_metadata.session_data - row.execution_metadata = langfuse_row.execution_metadata + row.messages = remote_row.messages + row.tools = remote_row.tools + row.input_metadata.session_data = remote_row.input_metadata.session_data + row.execution_metadata = remote_row.execution_metadata return row else: raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.") diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py new file mode 100644 index 00000000..9471c956 --- /dev/null +++ b/tests/remote_server/test_remote_fireworks.py @@ -0,0 +1,84 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter +from eval_protocol.quickstart.utils import filter_longest_conversation + +ROLLOUT_IDS = set() + + +@pytest.fixture(autouse=True) +def check_rollout_coverage(): + """Ensure we processed all expected rollout_ids""" + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + yield + + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + + +def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: + global ROLLOUT_IDS # Track all rollout_ids we've seen + ROLLOUT_IDS.add(rollout_id) + + adapter = create_fireworks_tracing_adapter() + return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + + +def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: + return DynamicDataLoader( + generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + ) + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + output_data_loader=fireworks_output_data_loader, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + + assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( + f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" + ) + + return row diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index bf92d19a..1eb4f3e5 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,7 +1,7 @@ # MANUAL SERVER STARTUP REQUIRED: # # For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:7077) +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: # cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server From fd204bb19f93dde1fb0d9f2588b089b2934833f2 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Sun, 5 Oct 2025 02:48:48 -0700 Subject: [PATCH 2/9] update path --- tests/remote_server/test_remote_fireworks.py | 2 +- tests/remote_server/test_remote_langfuse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 9471c956..21fef2e0 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -4,7 +4,7 @@ # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: -# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# cd tests/remote_server/typescript-server # npm install # npm start # diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 1eb4f3e5..753518f0 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -4,7 +4,7 @@ # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: -# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# cd tests/remote_server/typescript-server # npm install # npm start # From 91f23782469000007d6feb8b013bdd09dd8b9e1f Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 10:39:28 -0700 Subject: [PATCH 3/9] various changes --- eval_protocol/adapters/__init__.py | 7 +-- eval_protocol/adapters/fireworks_tracing.py | 4 +- .../pytest/remote_rollout_processor.py | 28 ++++++++-- tests/remote_server/quickstart.py | 53 +++++++++++++++++++ tests/remote_server/test_remote_fireworks.py | 8 +-- 5 files changed, 86 insertions(+), 14 deletions(-) create mode 100644 tests/remote_server/quickstart.py diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index 39b312d8..d664b425 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -25,12 +25,9 @@ except ImportError: pass -try: - from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter +from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter - __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) -except ImportError: - pass +__all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) try: from .huggingface import ( diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index f3155023..108e85a5 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -347,9 +347,9 @@ def get_evaluation_rows( # Make request to proxy if self.project_id: - url = f"{self.base_url}/v1/project_id/{self.project_id}/langfuse/traces" + url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" else: - url = f"{self.base_url}/v1/langfuse/traces" + url = f"{self.base_url}/v1/traces" try: response = requests.post(url, json=payload, timeout=self.timeout) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index c88ccc56..d879cd17 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -8,6 +8,8 @@ from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata +from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter +from eval_protocol.quickstart.utils import filter_longest_conversation from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig from .elasticsearch_setup import ElasticsearchSetup @@ -18,10 +20,30 @@ logger = logging.getLogger(__name__) +def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: + """Default output data loader that fetches traces from Fireworks tracing proxy. + + Args: + rollout_id: The rollout ID to filter traces by + + Returns: + DynamicDataLoader configured to fetch and process traces + """ + + def fetch_traces() -> List[EvaluationRow]: + adapter = create_fireworks_tracing_adapter(base_url=base_url) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + + return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) + + class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. + By default, fetches traces from the Fireworks tracing proxy using rollout_id tags. + You can provide a custom output_data_loader for different tracing backends. + See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -32,7 +54,7 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Callable[[str], DynamicDataLoader], + output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None, disable_elastic_search: bool = False, elastic_search_config: Optional[ElasticsearchConfig] = None, ): @@ -44,7 +66,7 @@ def __init__( self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = output_data_loader + self._output_data_loader = output_data_loader or _default_output_data_loader self._disable_elastic_search = disable_elastic_search self._elastic_search_config = elastic_search_config @@ -242,7 +264,7 @@ def _get_status() -> Dict[str, Any]: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") - data_loader = self._output_data_loader(row.execution_metadata.rollout_id) + data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url) def _load_data(): return data_loader.load() diff --git a/tests/remote_server/quickstart.py b/tests/remote_server/quickstart.py new file mode 100644 index 00000000..1d204dd2 --- /dev/null +++ b/tests/remote_server/quickstart.py @@ -0,0 +1,53 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy (uses default FireworksTracingAdapter) + - FAIL if no traces found or rollout_id missing + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + assert row.execution_metadata.rollout_id, "Row should have a rollout_id from the remote rollout" + + return row diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 21fef2e0..a9505481 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -36,17 +36,17 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: +def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen ROLLOUT_IDS.add(rollout_id) - adapter = create_fireworks_tracing_adapter() + adapter = create_fireworks_tracing_adapter(base_url=base_url) return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation ) From 260f7216234d2f9c51c087d80385e1616871dc75 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 13:06:32 -0700 Subject: [PATCH 4/9] add dataloaderconfig --- .../pytest/remote_rollout_processor.py | 22 +++++++++++++------ .../types/remote_rollout_processor.py | 7 ++++++ tests/remote_server/test_remote_fireworks.py | 12 +++++----- tests/remote_server/test_remote_langfuse.py | 11 +++++----- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 7a116143..ceb9e7d4 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -7,7 +7,12 @@ from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata +from eval_protocol.types.remote_rollout_processor import ( + DataLoaderConfig, + ElasticsearchConfig, + InitRequest, + RolloutMetadata, +) from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter from eval_protocol.quickstart.utils import filter_longest_conversation from .rollout_processor import RolloutProcessor @@ -20,19 +25,20 @@ logger = logging.getLogger(__name__) -def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: +def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: """Default output data loader that fetches traces from Fireworks tracing proxy. Args: - rollout_id: The rollout ID to filter traces by + config: Configuration containing rollout_id and optional model_base_url Returns: DynamicDataLoader configured to fetch and process traces """ def fetch_traces() -> List[EvaluationRow]: + base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -54,7 +60,7 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None, + output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, disable_elastic_search: bool = False, elastic_search_config: Optional[ElasticsearchConfig] = None, ): @@ -64,7 +70,6 @@ def __init__( self._model_base_url = model_base_url if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") - self._model_base_url = model_base_url _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL") if _ep_model_base_url: self._model_base_url = _ep_model_base_url @@ -268,7 +273,10 @@ def _get_status() -> Dict[str, Any]: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") - data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url) + loader_config = DataLoaderConfig( + rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url + ) + data_loader = self._output_data_loader(loader_config) def _load_data(): return data_loader.load() diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 67c3158a..a972d2b5 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -34,6 +34,13 @@ class RolloutMetadata(BaseModel): row_id: str +class DataLoaderConfig(BaseModel): + """Configuration passed to output_data_loader functions.""" + + rollout_id: str + model_base_url: Optional[str] = None + + class InitRequest(BaseModel): """Request model for POST /init endpoint.""" diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index a9505481..7dc65cf7 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -22,6 +22,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -36,17 +37,18 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: +def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(rollout_id) + ROLLOUT_IDS.add(config.rollout_id) + base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: +def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation ) diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 753518f0..35828570 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -22,6 +22,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.langfuse import create_langfuse_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -36,17 +37,17 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]: +def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(rollout_id) + ROLLOUT_IDS.add(config.rollout_id) adapter = create_langfuse_adapter() - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) -def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_langfuse_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation ) From 03d3b0cb6642ece42b9c4d833779255badf32b06 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 14:48:46 -0700 Subject: [PATCH 5/9] use get --- eval_protocol/adapters/fireworks_tracing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 108e85a5..ba06c532 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -321,8 +321,8 @@ def get_evaluation_rows( eval_rows = [] - # Build request payload - payload = { + # Build query parameters for GET request + params = { "limit": limit, "sample_size": sample_size, "tags": tags, @@ -336,14 +336,12 @@ def get_evaluation_rows( "hours_back": hours_back, "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, - "include_tool_calls": include_tool_calls, "sleep_between_gets": sleep_between_gets, "max_retries": max_retries, - "span_name": span_name, } # Remove None values - payload = {k: v for k, v in payload.items() if v is not None} + params = {k: v for k, v in params.items() if v is not None} # Make request to proxy if self.project_id: @@ -352,7 +350,7 @@ def get_evaluation_rows( url = f"{self.base_url}/v1/traces" try: - response = requests.post(url, json=payload, timeout=self.timeout) + response = requests.get(url, params=params, timeout=self.timeout) response.raise_for_status() result = response.json() except requests.exceptions.RequestException as e: From aa85ed525e0ec0491e95a0aae9a822edadefe416 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 7 Oct 2025 01:38:22 -0700 Subject: [PATCH 6/9] validated using remote_server_multi_turn.py --- eval_protocol/adapters/fireworks_tracing.py | 64 ++++++++--- .../pytest/remote_rollout_processor.py | 7 +- .../remote_server/remote_server_multi_turn.py | 104 ++++++++++++++++++ tests/remote_server/test_remote_fireworks.py | 4 +- 4 files changed, 162 insertions(+), 17 deletions(-) create mode 100644 tests/remote_server/remote_server_multi_turn.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index ba06c532..1efc0579 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging import requests +import time from datetime import datetime from typing import Any, Dict, List, Optional, Protocol @@ -280,8 +281,9 @@ def get_evaluation_rows( from_timestamp: Optional[datetime] = None, to_timestamp: Optional[datetime] = None, include_tool_calls: bool = True, - sleep_between_gets: float = 2.5, - max_retries: int = 3, + backend_sleep_between_gets: float = 0.1, + backend_max_retries: int = 3, + proxy_max_retries: int = 3, span_name: Optional[str] = None, converter: Optional[TraceDictConverter] = None, ) -> List[EvaluationRow]: @@ -303,8 +305,9 @@ def get_evaluation_rows( from_timestamp: Explicit start time (ISO format) to_timestamp: Explicit end time (ISO format) include_tool_calls: Whether to include tool calling traces - sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) - max_retries: Maximum retries for rate limit errors (handled by proxy) + backend_sleep_between_gets: Sleep time between backend trace fetches (passed to proxy) + backend_max_retries: Maximum retries for backend operations (passed to proxy) + proxy_max_retries: Maximum retries when proxy returns 404 (client-side retries with exponential backoff) span_name: If provided, extract messages from generations within this named span converter: Optional custom converter implementing TraceDictConverter protocol. If provided, this will be used instead of the default conversion logic. @@ -336,25 +339,60 @@ def get_evaluation_rows( "hours_back": hours_back, "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, - "sleep_between_gets": sleep_between_gets, - "max_retries": max_retries, + "sleep_between_gets": backend_sleep_between_gets, + "max_retries": backend_max_retries, } # Remove None values params = {k: v for k, v in params.items() if v is not None} - # Make request to proxy + # Make request to proxy with retry logic if self.project_id: url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" else: url = f"{self.base_url}/v1/traces" - try: - response = requests.get(url, params=params, timeout=self.timeout) - response.raise_for_status() - result = response.json() - except requests.exceptions.RequestException as e: - logger.error("Failed to fetch traces from proxy: %s", e) + # Retry loop for handling backend indexing delays (proxy returns 404) + result = None + for attempt in range(proxy_max_retries): + try: + response = requests.get(url, params=params, timeout=self.timeout) + response.raise_for_status() + result = response.json() + break # Success, exit retry loop + except requests.exceptions.HTTPError as e: + error_msg = str(e) + should_retry = False + + # Try to extract detail message from response + if e.response is not None: + try: + error_detail = e.response.json().get("detail", "") + error_msg = error_detail or e.response.text + + # Retry on 404 if it's due to incomplete/missing traces (backend still indexing) + if e.response.status_code == 404 and ( + "Incomplete traces" in error_detail or "No traces found" in error_detail + ): + should_retry = True + except Exception: + error_msg = e.response.text + + if should_retry and attempt < proxy_max_retries - 1: + sleep_time = 2 ** (attempt + 1) + logger.warning(error_msg) + time.sleep(sleep_time) + else: + # Final retry or non-retryable error + logger.error("Failed to fetch traces from proxy: %s", error_msg) + return eval_rows + except requests.exceptions.RequestException as e: + # Non-HTTP errors (network issues, timeouts, etc.) + logger.error("Failed to fetch traces from proxy: %s", str(e)) + return eval_rows + + if result is None: + logger.error("Failed to fetch traces after %d retries", proxy_max_retries) return eval_rows # Extract traces from response diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index ceb9e7d4..bcdb44be 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -38,7 +38,7 @@ def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: def fetch_traces() -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -168,7 +168,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") final_model_base_url = model_base_url - if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"): + if model_base_url and ( + model_base_url.startswith("https://tracing.fireworks.ai") + or model_base_url.startswith("http://localhost") + ): final_model_base_url = ( f"{model_base_url}/rollout_id/{meta.rollout_id}" f"/invocation_id/{meta.invocation_id}" diff --git a/tests/remote_server/remote_server_multi_turn.py b/tests/remote_server/remote_server_multi_turn.py new file mode 100644 index 00000000..155a0a2a --- /dev/null +++ b/tests/remote_server/remote_server_multi_turn.py @@ -0,0 +1,104 @@ +import os +import random +import threading + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI +import logging + +from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter + + +app = FastAPI() + +# attach handler to root logger +handler = ElasticsearchDirectHttpHandler() +logging.getLogger().addHandler(handler) + + +@app.post("/init") +def init(req: InitRequest): + if req.elastic_search_config: + handler.configure(req.elastic_search_config) + + # attach rollout_id filter to logger + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + # Kick off worker thread that does a multi-turn chat (6 turns total) + def _worker(): + try: + if not req.messages: + raise ValueError("messages is required") + + client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) + + # Build up conversation over 6 turns (3 user messages + 3 assistant responses) + # Convert Message objects to dicts for OpenAI API + conversation_history = [{"role": m.role, "content": m.content} for m in req.messages] + + follow_up_questions = [ + "Tell me more about that.", + "What else can you share about this topic?", + ] + + # First completion (turns 1-2: initial user message + assistant response) + logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 2 response: {assistant_content[:100]}...") + + # Second completion (turns 3-4: follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[0]}) + logger.info(f"Turn 3: User asks: {follow_up_questions[0]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 4 response: {assistant_content[:100]}...") + + # Third completion (turns 5-6: another follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[1]}) + logger.info(f"Turn 5: User asks: {follow_up_questions[1]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 6 response: {assistant_content[:100]}...") + + logger.info(f"Completed 6-turn conversation with {len(conversation_history)} messages total") + + except Exception as e: + # Best-effort; mark as done even on error to unblock polling + print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") + pass + finally: + logger.info( + f"Rollout {req.metadata.rollout_id} completed", + extra={"status": Status.rollout_finished()}, + ) + + t = threading.Thread(target=_worker, daemon=True) + t.start() + + +def main(): + host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") + port = int(os.getenv("REMOTE_SERVER_PORT", "3000")) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + main() diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 7dc65cf7..ae6bf985 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -43,7 +43,7 @@ def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], proxy_max_retries=5) def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: @@ -65,7 +65,7 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", - timeout_seconds=30, + timeout_seconds=180, output_data_loader=fireworks_output_data_loader, ), ) From c8efd62c8d46f4c51c8e6e445cce1892e6866de1 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Sun, 5 Oct 2025 02:47:09 -0700 Subject: [PATCH 7/9] Fireworks Tracing --- eval_protocol/adapters/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index 0b9a91a7..09d51175 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -29,6 +29,13 @@ __all__.extend(["FireworksTracingAdapter"]) +try: + from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter + + __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) +except ImportError: + pass + try: from .huggingface import ( HuggingFaceAdapter, From 62f3c5b62a13381c72bf654bbab5e09eb3a46c91 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 10:39:28 -0700 Subject: [PATCH 8/9] various changes --- eval_protocol/adapters/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index 09d51175..0b9a91a7 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -29,13 +29,6 @@ __all__.extend(["FireworksTracingAdapter"]) -try: - from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter - - __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) -except ImportError: - pass - try: from .huggingface import ( HuggingFaceAdapter, From 645c2358414199215027e193a21ce9fbbc131690 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Tue, 7 Oct 2025 01:38:22 -0700 Subject: [PATCH 9/9] validated using remote_server_multi_turn.py --- eval_protocol/adapters/fireworks_tracing.py | 64 ++++++++--- .../remote_server/remote_server_multi_turn.py | 104 ++++++++++++++++++ tests/remote_server/test_remote_fireworks.py | 2 +- 3 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 tests/remote_server/remote_server_multi_turn.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 5ce8a436..b43df2b5 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging import requests +import time from datetime import datetime from typing import Any, Dict, List, Optional, Protocol @@ -280,8 +281,9 @@ def get_evaluation_rows( from_timestamp: Optional[datetime] = None, to_timestamp: Optional[datetime] = None, include_tool_calls: bool = True, - sleep_between_gets: float = 2.5, - max_retries: int = 3, + backend_sleep_between_gets: float = 0.1, + backend_max_retries: int = 3, + proxy_max_retries: int = 3, span_name: Optional[str] = None, converter: Optional[TraceDictConverter] = None, ) -> List[EvaluationRow]: @@ -303,8 +305,9 @@ def get_evaluation_rows( from_timestamp: Explicit start time (ISO format) to_timestamp: Explicit end time (ISO format) include_tool_calls: Whether to include tool calling traces - sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) - max_retries: Maximum retries for rate limit errors (handled by proxy) + backend_sleep_between_gets: Sleep time between backend trace fetches (passed to proxy) + backend_max_retries: Maximum retries for backend operations (passed to proxy) + proxy_max_retries: Maximum retries when proxy returns 404 (client-side retries with exponential backoff) span_name: If provided, extract messages from generations within this named span converter: Optional custom converter implementing TraceDictConverter protocol. If provided, this will be used instead of the default conversion logic. @@ -336,25 +339,60 @@ def get_evaluation_rows( "hours_back": hours_back, "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, - "sleep_between_gets": sleep_between_gets, - "max_retries": max_retries, + "sleep_between_gets": backend_sleep_between_gets, + "max_retries": backend_max_retries, } # Remove None values params = {k: v for k, v in params.items() if v is not None} - # Make request to proxy + # Make request to proxy with retry logic if self.project_id: url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" else: url = f"{self.base_url}/v1/traces" - try: - response = requests.get(url, params=params, timeout=self.timeout) - response.raise_for_status() - result = response.json() - except requests.exceptions.RequestException as e: - logger.error("Failed to fetch traces from proxy: %s", e) + # Retry loop for handling backend indexing delays (proxy returns 404) + result = None + for attempt in range(proxy_max_retries): + try: + response = requests.get(url, params=params, timeout=self.timeout) + response.raise_for_status() + result = response.json() + break # Success, exit retry loop + except requests.exceptions.HTTPError as e: + error_msg = str(e) + should_retry = False + + # Try to extract detail message from response + if e.response is not None: + try: + error_detail = e.response.json().get("detail", "") + error_msg = error_detail or e.response.text + + # Retry on 404 if it's due to incomplete/missing traces (backend still indexing) + if e.response.status_code == 404 and ( + "Incomplete traces" in error_detail or "No traces found" in error_detail + ): + should_retry = True + except Exception: + error_msg = e.response.text + + if should_retry and attempt < proxy_max_retries - 1: + sleep_time = 2 ** (attempt + 1) + logger.warning(error_msg) + time.sleep(sleep_time) + else: + # Final retry or non-retryable error + logger.error("Failed to fetch traces from proxy: %s", error_msg) + return eval_rows + except requests.exceptions.RequestException as e: + # Non-HTTP errors (network issues, timeouts, etc.) + logger.error("Failed to fetch traces from proxy: %s", str(e)) + return eval_rows + + if result is None: + logger.error("Failed to fetch traces after %d retries", proxy_max_retries) return eval_rows # Extract traces from response diff --git a/tests/remote_server/remote_server_multi_turn.py b/tests/remote_server/remote_server_multi_turn.py new file mode 100644 index 00000000..155a0a2a --- /dev/null +++ b/tests/remote_server/remote_server_multi_turn.py @@ -0,0 +1,104 @@ +import os +import random +import threading + +import uvicorn +from fastapi import FastAPI +from openai import OpenAI +import logging + +from eval_protocol import Status, InitRequest, ElasticsearchDirectHttpHandler, RolloutIdFilter + + +app = FastAPI() + +# attach handler to root logger +handler = ElasticsearchDirectHttpHandler() +logging.getLogger().addHandler(handler) + + +@app.post("/init") +def init(req: InitRequest): + if req.elastic_search_config: + handler.configure(req.elastic_search_config) + + # attach rollout_id filter to logger + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + # Kick off worker thread that does a multi-turn chat (6 turns total) + def _worker(): + try: + if not req.messages: + raise ValueError("messages is required") + + client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY")) + + # Build up conversation over 6 turns (3 user messages + 3 assistant responses) + # Convert Message objects to dicts for OpenAI API + conversation_history = [{"role": m.role, "content": m.content} for m in req.messages] + + follow_up_questions = [ + "Tell me more about that.", + "What else can you share about this topic?", + ] + + # First completion (turns 1-2: initial user message + assistant response) + logger.info(f"Turn 1-2: Sending initial completion request to model {req.model}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 2 response: {assistant_content[:100]}...") + + # Second completion (turns 3-4: follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[0]}) + logger.info(f"Turn 3: User asks: {follow_up_questions[0]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 4 response: {assistant_content[:100]}...") + + # Third completion (turns 5-6: another follow-up user message + assistant response) + conversation_history.append({"role": "user", "content": follow_up_questions[1]}) + logger.info(f"Turn 5: User asks: {follow_up_questions[1]}") + completion = client.chat.completions.create( + model=req.model, + messages=conversation_history, # type: ignore + ) + assistant_message = completion.choices[0].message + assistant_content = assistant_message.content or "" + conversation_history.append({"role": "assistant", "content": assistant_content}) + logger.info(f"Turn 6 response: {assistant_content[:100]}...") + + logger.info(f"Completed 6-turn conversation with {len(conversation_history)} messages total") + + except Exception as e: + # Best-effort; mark as done even on error to unblock polling + print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") + pass + finally: + logger.info( + f"Rollout {req.metadata.rollout_id} completed", + extra={"status": Status.rollout_finished()}, + ) + + t = threading.Thread(target=_worker, daemon=True) + t.start() + + +def main(): + host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") + port = int(os.getenv("REMOTE_SERVER_PORT", "3000")) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + main() diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index f647fe61..3050b1f5 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -65,7 +65,7 @@ def rows() -> List[EvaluationRow]: ), rollout_processor=RemoteRolloutProcessor( remote_base_url="http://127.0.0.1:3000", - timeout_seconds=30, + timeout_seconds=180, output_data_loader=fireworks_output_data_loader, ), )