diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index d338b6c2..0b9a91a7 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -6,6 +6,7 @@ Available adapters: - BaseAdapter: Abstract base class for all adapters - LangfuseAdapter: Pull data from Langfuse deployments +- FireworksTracingAdapter: Pull data from Langfuse via Fireworks tracing proxy - HuggingFaceAdapter: Load datasets from HuggingFace Hub - BigQueryAdapter: Query data from Google BigQuery - TRL integration (legacy) @@ -24,6 +25,10 @@ except ImportError: pass +from .fireworks_tracing import FireworksTracingAdapter + +__all__.extend(["FireworksTracingAdapter"]) + try: from .huggingface import ( HuggingFaceAdapter, diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py new file mode 100644 index 00000000..5ce8a436 --- /dev/null +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -0,0 +1,377 @@ +"""Fireworks Tracing adapter for Eval Protocol. + +This adapter uses the Fireworks tracing proxy at tracing.fireworks.ai +to pull data from Langfuse deployments with simplified retry logic handling. +""" + +from __future__ import annotations +import logging +import requests +from datetime import datetime +from typing import Any, Dict, List, Optional, Protocol + +from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from .base import BaseAdapter +from .utils import extract_messages_from_data + +logger = logging.getLogger(__name__) + + +class TraceDictConverter(Protocol): + """Protocol for custom trace dictionary-to-EvaluationRow converter functions. + + A converter function should take a trace dictionary along with processing + options and return an EvaluationRow or None to skip the trace. + """ + + def __call__( + self, + trace: Dict[str, Any], + include_tool_calls: bool, + span_name: Optional[str], + ) -> Optional[EvaluationRow]: + """Convert a trace dictionary to an EvaluationRow. + + Args: + trace: The trace dictionary to convert + include_tool_calls: Whether to include tool calling information + span_name: Optional span name to extract messages from + + Returns: + EvaluationRow or None if the trace should be skipped + """ + ... + + +def convert_trace_dict_to_evaluation_row( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> Optional[EvaluationRow]: + """Convert a trace dictionary (from proxy API) to EvaluationRow format. + + Args: + trace: Trace dictionary from Fireworks proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + EvaluationRow or None if conversion fails + """ + try: + # Extract messages from trace input and output + messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name) + + # Extract tools if available + tools = None + if include_tool_calls and isinstance(trace.get("input"), dict) and "tools" in trace["input"]: + tools = trace["input"]["tools"] + + if not messages: + return None + + execution_metadata = ExecutionMetadata() + row_id = None + + # Extract metadata from tags + tags = trace.get("tags", []) + if tags: + for tag in tags: + if tag.startswith("invocation_id:"): + execution_metadata.invocation_id = tag.split(":", 1)[1] + elif tag.startswith("experiment_id:"): + execution_metadata.experiment_id = tag.split(":", 1)[1] + elif tag.startswith("rollout_id:"): + execution_metadata.rollout_id = tag.split(":", 1)[1] + elif tag.startswith("run_id:"): + execution_metadata.run_id = tag.split(":", 1)[1] + elif tag.startswith("row_id:"): + row_id = tag.split(":", 1)[1] + + if ( + execution_metadata.invocation_id + and execution_metadata.experiment_id + and execution_metadata.rollout_id + and execution_metadata.run_id + and row_id + ): + break # Break early if we've found all the metadata we need + + return EvaluationRow( + messages=messages, + tools=tools, + input_metadata=InputMetadata( + row_id=row_id, + session_data={ + "langfuse_trace_id": trace.get("id"), # Store the trace ID here + }, + ), + execution_metadata=execution_metadata, + ) + + except (AttributeError, ValueError, KeyError) as e: + logger.error("Error converting trace %s: %s", trace.get("id"), e) + return None + + +def extract_messages_from_trace_dict( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> List[Message]: + """Extract messages from trace dictionary. + + Args: + trace: Trace dictionary from proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + List of Message objects + """ + messages = [] + + if span_name: # Look for a generation tied to a span name + try: + # Find the final generation in the named span + gen = get_final_generation_in_span_dict(trace, span_name) + if not gen: + return messages + + # Extract messages from generation input and output + if gen.get("input"): + messages.extend(extract_messages_from_data(gen["input"], include_tool_calls)) + if gen.get("output"): + messages.extend(extract_messages_from_data(gen["output"], include_tool_calls)) + + return messages + + except Exception as e: + logger.error("Failed to extract messages from span '%s' in trace %s: %s", span_name, trace.get("id"), e) + return messages + + else: + try: + # Extract messages from trace input and output + if trace.get("input"): + messages.extend(extract_messages_from_data(trace["input"], include_tool_calls)) + if trace.get("output"): + messages.extend(extract_messages_from_data(trace["output"], include_tool_calls)) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Error processing trace %s: %s", trace.get("id"), e) + + # Fallback: use the last GENERATION observation which typically contains full chat history + if not messages: + try: + all_observations = trace.get("observations", []) + gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"] + if gens: + gens.sort(key=lambda x: x.get("start_time", "")) + last_gen = gens[-1] + if last_gen.get("input"): + messages.extend(extract_messages_from_data(last_gen["input"], include_tool_calls)) + if last_gen.get("output"): + messages.extend(extract_messages_from_data(last_gen["output"], include_tool_calls)) + except Exception as e: + logger.warning("Failed to extract from last generation for trace %s: %s", trace.get("id"), e) + + return messages + + +def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> Optional[Dict[str, Any]]: + """Get the final generation within a named span from trace dictionary. + + Args: + trace: Trace dictionary + span_name: Name of the span to search for + + Returns: + The final generation dictionary, or None if not found + """ + # Get all observations from the trace + all_observations = trace.get("observations", []) + + # Find a span with the given name that has generation children + parent_span = None + for obs in all_observations: + if obs.get("name") == span_name and obs.get("type") == "SPAN": + # Check if this span has generation children + has_generations = any( + child.get("type") == "GENERATION" and child.get("parent_observation_id") == obs.get("id") + for child in all_observations + ) + if has_generations: + parent_span = obs + break + + if not parent_span: + logger.warning("No span named '%s' found in trace %s", span_name, trace.get("id")) + return None + + # Find all generations within this span + generations = [] + for obs in all_observations: + if obs.get("type") == "GENERATION" and obs.get("parent_observation_id") == parent_span.get("id"): + generations.append(obs) + + if not generations: + logger.warning("No generations found in span '%s' in trace %s", span_name, trace.get("id")) + return None + + # Sort generations by start time for chronological order + generations.sort(key=lambda x: x.get("start_time", "")) + + # Return the final generation (contains full message history) + return generations[-1] + + +class FireworksTracingAdapter(BaseAdapter): + """Adapter to pull data from Langfuse via Fireworks tracing proxy. + + This adapter uses the Fireworks tracing proxy API which handles retry logic + and rate limiting internally, simplifying the client-side implementation. + + Examples: + Basic usage (default project): + >>> adapter = FireworksTracingAdapter() + >>> rows = list(adapter.get_evaluation_rows(tags=["rollout_id:xyz"], limit=10)) + + With explicit project ID: + >>> adapter = FireworksTracingAdapter( + ... project_id="your_project_id", + ... base_url="https://tracing.fireworks.ai" + ... ) + >>> rows = list(adapter.get_evaluation_rows(tags=["production"], limit=10)) + + Filter by specific criteria: + >>> rows = list(adapter.get_evaluation_rows( + ... tags=["production"], + ... limit=50, + ... hours_back=24 + ... )) + """ + + def __init__( + self, + project_id: Optional[str] = None, + base_url: str = "https://tracing.fireworks.ai", + timeout: int = 300, + ): + """Initialize the Fireworks Tracing adapter. + + Args: + project_id: Optional project ID. If not provided, uses the default project configured on the server. + base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai) + timeout: Request timeout in seconds (default: 300) + """ + self.project_id = project_id + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + def get_evaluation_rows( + self, + tags: List[str], + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[datetime] = None, + to_timestamp: Optional[datetime] = None, + include_tool_calls: bool = True, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + span_name: Optional[str] = None, + converter: Optional[TraceDictConverter] = None, + ) -> List[EvaluationRow]: + """Pull traces from Langfuse via proxy and convert to EvaluationRow format. + + Args: + tags: REQUIRED - Filter by specific tags (prevents fetching all traces). + Must provide at least one tag (e.g., ['rollout_id:xyz'], ['production']) + limit: Max number of trace summaries to collect via pagination + sample_size: Optional number of traces to randomly sample (if None, process all) + user_id: Filter by user ID + session_id: Filter by session ID + name: Filter by trace name + environment: Filter by environment (e.g., production, staging, development) + version: Filter by trace version + release: Filter by trace release + fields: Comma-separated list of fields to include + hours_back: Filter traces from this many hours ago + from_timestamp: Explicit start time (ISO format) + to_timestamp: Explicit end time (ISO format) + include_tool_calls: Whether to include tool calling traces + sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) + max_retries: Maximum retries for rate limit errors (handled by proxy) + span_name: If provided, extract messages from generations within this named span + converter: Optional custom converter implementing TraceDictConverter protocol. + If provided, this will be used instead of the default conversion logic. + + Returns: + List[EvaluationRow]: Converted evaluation rows + + Raises: + ValueError: If tags list is empty + """ + # Validate that tags are provided (security requirement) + if not tags or len(tags) == 0: + raise ValueError("At least one tag is required to fetch traces (security: prevents fetching all traces)") + + eval_rows = [] + + # Build query parameters for GET request + params = { + "limit": limit, + "sample_size": sample_size, + "tags": tags, + "user_id": user_id, + "session_id": session_id, + "name": name, + "environment": environment, + "version": version, + "release": release, + "fields": fields, + "hours_back": hours_back, + "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, + "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, + "sleep_between_gets": sleep_between_gets, + "max_retries": max_retries, + } + + # Remove None values + params = {k: v for k, v in params.items() if v is not None} + + # Make request to proxy + if self.project_id: + url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" + else: + url = f"{self.base_url}/v1/traces" + + try: + response = requests.get(url, params=params, timeout=self.timeout) + response.raise_for_status() + result = response.json() + except requests.exceptions.RequestException as e: + logger.error("Failed to fetch traces from proxy: %s", e) + return eval_rows + + # Extract traces from response + traces = result.get("traces", []) + + # Convert each trace to EvaluationRow + for trace in traces: + try: + if converter: + eval_row = converter(trace, include_tool_calls, span_name) + else: + eval_row = convert_trace_dict_to_evaluation_row(trace, include_tool_calls, span_name) + if eval_row: + eval_rows.append(eval_row) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Failed to convert trace %s: %s", trace.get("id"), e) + continue + + logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows)) + return eval_rows diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 42eeee6f..0174922f 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -355,7 +355,7 @@ def get_evaluation_rows( # If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse if traces and traces.meta and traces.meta.total_items == 0 and page == 1: - raise Exception("Empty results - indexing delay") + raise Exception("Empty results") break except Exception as e: diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 89adb59e..1d4b6553 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -7,7 +7,14 @@ from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata +from eval_protocol.types.remote_rollout_processor import ( + DataLoaderConfig, + ElasticsearchConfig, + InitRequest, + RolloutMetadata, +) +from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter +from eval_protocol.quickstart.utils import filter_longest_conversation from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig from .elasticsearch_setup import ElasticsearchSetup @@ -18,10 +25,51 @@ logger = logging.getLogger(__name__) +def _build_fireworks_tracing_url(base_url: str, metadata: RolloutMetadata) -> str: + """Build a Fireworks tracing URL by appending rollout metadata to the base URL path, + allowing the Fireworks tracing proxy to automatically tag traces. + + Format: {base_url}/rollout_id/{id}/invocation_id/{id}/experiment_id/{id}/run_id/{id}/row_id/{id} + + Args: + base_url: Fireworks tracing proxy URL (we expect this to be https://tracing.fireworks.ai or + https://tracing.fireworks.ai/project_id/{project_id}) + metadata: Rollout metadata containing IDs to embed in the URL + """ + return ( + f"{base_url}/rollout_id/{metadata.rollout_id}" + f"/invocation_id/{metadata.invocation_id}" + f"/experiment_id/{metadata.experiment_id}" + f"/run_id/{metadata.run_id}" + f"/row_id/{metadata.row_id}" + ) + + +def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: + """Default output data loader that fetches traces from Fireworks tracing proxy. + + Args: + config: Configuration containing rollout_id and optional model_base_url + + Returns: + DynamicDataLoader configured to fetch and process traces + """ + + def fetch_traces() -> List[EvaluationRow]: + base_url = config.model_base_url or "https://tracing.fireworks.ai" + adapter = FireworksTracingAdapter(base_url=base_url) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + + return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) + + class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. + By default, fetches traces from the Fireworks tracing proxy using rollout_id tags. + You can provide a custom output_data_loader for different tracing backends. + See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -29,10 +77,10 @@ def __init__( self, *, remote_base_url: Optional[str] = None, - model_base_url: Optional[str] = None, + model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Callable[[str], DynamicDataLoader], + output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, disable_elastic_search: bool = False, elastic_search_config: Optional[ElasticsearchConfig] = None, ): @@ -42,12 +90,12 @@ def __init__( self._model_base_url = model_base_url if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") - self._model_base_url = model_base_url - if os.getenv("EP_MODEL_BASE_URL"): - self._model_base_url = os.getenv("EP_MODEL_BASE_URL") + _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL") + if _ep_model_base_url: + self._model_base_url = _ep_model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = output_data_loader + self._output_data_loader = output_data_loader or _default_output_data_loader self._disable_elastic_search = disable_elastic_search self._elastic_search_config = elastic_search_config @@ -69,7 +117,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> # Start with constructor values remote_base_url: Optional[str] = self._remote_base_url - model_base_url: Optional[str] = self._model_base_url + model_base_url: str = self._model_base_url poll_interval: float = self._poll_interval timeout_seconds: float = self._timeout_seconds @@ -140,14 +188,8 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") final_model_base_url = model_base_url - if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai/project_id/"): - final_model_base_url = ( - f"{model_base_url}/rollout_id/{meta.rollout_id}" - f"/invocation_id/{meta.invocation_id}" - f"/experiment_id/{meta.experiment_id}" - f"/run_id/{meta.run_id}" - f"/row_id/{meta.row_id}" - ) + if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"): + final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta) init_payload: InitRequest = InitRequest( model=model, @@ -245,7 +287,10 @@ def _get_status() -> Dict[str, Any]: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") - data_loader = self._output_data_loader(row.execution_metadata.rollout_id) + loader_config = DataLoaderConfig( + rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url + ) + data_loader = self._output_data_loader(loader_config) def _load_data(): return data_loader.load() @@ -254,24 +299,24 @@ def _load_data(): output_rows: List[EvaluationRow] = [row for result in results for row in result.rows] - if len(output_rows) == 0: # Fallback to original row if no Langfuse data found - row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No Langfuse data found for rollout") + if len(output_rows) == 0: # Fallback to original row if no Remote data found + row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No remote data found for rollout") return row - elif len(output_rows) == 1: # Return the Langfuse row - langfuse_row = output_rows[0] + elif len(output_rows) == 1: # Return the remote row + remote_row = output_rows[0] - # if the langfuse_row has the same number of messages as the original row, + # if the remote_row has the same number of messages as the original row, # something went wrong - if len(langfuse_row.messages) == len(row.messages): + if len(remote_row.messages) == len(row.messages): row.rollout_status = Status.rollout_error( "Rollout finished with the same number of messages as the original row" ) return row - row.messages = langfuse_row.messages - row.tools = langfuse_row.tools - row.input_metadata.session_data = langfuse_row.input_metadata.session_data - row.execution_metadata = langfuse_row.execution_metadata + row.messages = remote_row.messages + row.tools = remote_row.tools + row.input_metadata.session_data = remote_row.input_metadata.session_data + row.execution_metadata = remote_row.execution_metadata return row else: raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.") diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 67c3158a..a972d2b5 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -34,6 +34,13 @@ class RolloutMetadata(BaseModel): row_id: str +class DataLoaderConfig(BaseModel): + """Configuration passed to output_data_loader functions.""" + + rollout_id: str + model_base_url: Optional[str] = None + + class InitRequest(BaseModel): """Request model for POST /init endpoint.""" diff --git a/tests/remote_server/quickstart.py b/tests/remote_server/quickstart.py new file mode 100644 index 00000000..1d204dd2 --- /dev/null +++ b/tests/remote_server/quickstart.py @@ -0,0 +1,53 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy (uses default FireworksTracingAdapter) + - FAIL if no traces found or rollout_id missing + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + assert row.execution_metadata.rollout_id, "Row should have a rollout_id from the remote rollout" + + return row diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py new file mode 100644 index 00000000..f647fe61 --- /dev/null +++ b/tests/remote_server/test_remote_fireworks.py @@ -0,0 +1,86 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter +from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig + +ROLLOUT_IDS = set() + + +@pytest.fixture(autouse=True) +def check_rollout_coverage(): + """Ensure we processed all expected rollout_ids""" + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + yield + + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + + +def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: + global ROLLOUT_IDS # Track all rollout_ids we've seen + ROLLOUT_IDS.add(config.rollout_id) + + base_url = config.model_base_url or "https://tracing.fireworks.ai" + adapter = FireworksTracingAdapter(base_url=base_url) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + + +def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: + return DynamicDataLoader( + generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation + ) + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + output_data_loader=fireworks_output_data_loader, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + + assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( + f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" + ) + + return row diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index bf92d19a..35828570 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,10 +1,10 @@ # MANUAL SERVER STARTUP REQUIRED: # # For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:7077) +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: -# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# cd tests/remote_server/typescript-server # npm install # npm start # @@ -22,6 +22,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.langfuse import create_langfuse_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -36,17 +37,17 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]: +def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(rollout_id) + ROLLOUT_IDS.add(config.rollout_id) adapter = create_langfuse_adapter() - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) -def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_langfuse_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation )