From 341a0333403d99c15ee7ec7246461d9efa829cc6 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Sun, 5 Oct 2025 02:47:09 -0700 Subject: [PATCH 01/25] Fireworks Tracing --- eval_protocol/adapters/__init__.py | 8 + eval_protocol/adapters/fireworks_tracing.py | 394 ++++++++++++++++++ eval_protocol/adapters/langfuse.py | 2 +- .../pytest/remote_rollout_processor.py | 27 +- tests/remote_server/test_remote_fireworks.py | 84 ++++ tests/remote_server/test_remote_langfuse.py | 2 +- 6 files changed, 501 insertions(+), 16 deletions(-) create mode 100644 eval_protocol/adapters/fireworks_tracing.py create mode 100644 tests/remote_server/test_remote_fireworks.py diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index d338b6c2..39b312d8 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -6,6 +6,7 @@ Available adapters: - BaseAdapter: Abstract base class for all adapters - LangfuseAdapter: Pull data from Langfuse deployments +- FireworksTracingAdapter: Pull data from Langfuse via Fireworks tracing proxy - HuggingFaceAdapter: Load datasets from HuggingFace Hub - BigQueryAdapter: Query data from Google BigQuery - TRL integration (legacy) @@ -24,6 +25,13 @@ except ImportError: pass +try: + from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter + + __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) +except ImportError: + pass + try: from .huggingface import ( HuggingFaceAdapter, diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py new file mode 100644 index 00000000..f3155023 --- /dev/null +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -0,0 +1,394 @@ +"""Fireworks Tracing adapter for Eval Protocol. + +This adapter uses the Fireworks tracing proxy at tracing.fireworks.ai +to pull data from Langfuse deployments with simplified retry logic handling. +""" + +from __future__ import annotations +import logging +import requests +from datetime import datetime +from typing import Any, Dict, List, Optional, Protocol + +from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from .base import BaseAdapter +from .utils import extract_messages_from_data + +logger = logging.getLogger(__name__) + + +class TraceDictConverter(Protocol): + """Protocol for custom trace dictionary-to-EvaluationRow converter functions. + + A converter function should take a trace dictionary along with processing + options and return an EvaluationRow or None to skip the trace. + """ + + def __call__( + self, + trace: Dict[str, Any], + include_tool_calls: bool, + span_name: Optional[str], + ) -> Optional[EvaluationRow]: + """Convert a trace dictionary to an EvaluationRow. + + Args: + trace: The trace dictionary to convert + include_tool_calls: Whether to include tool calling information + span_name: Optional span name to extract messages from + + Returns: + EvaluationRow or None if the trace should be skipped + """ + ... + + +def convert_trace_dict_to_evaluation_row( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> Optional[EvaluationRow]: + """Convert a trace dictionary (from proxy API) to EvaluationRow format. + + Args: + trace: Trace dictionary from Fireworks proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + EvaluationRow or None if conversion fails + """ + try: + # Extract messages from trace input and output + messages = extract_messages_from_trace_dict(trace, include_tool_calls, span_name) + + # Extract tools if available + tools = None + if include_tool_calls and isinstance(trace.get("input"), dict) and "tools" in trace["input"]: + tools = trace["input"]["tools"] + + if not messages: + return None + + execution_metadata = ExecutionMetadata() + row_id = None + + # Extract metadata from tags + tags = trace.get("tags", []) + if tags: + for tag in tags: + if tag.startswith("invocation_id:"): + execution_metadata.invocation_id = tag.split(":", 1)[1] + elif tag.startswith("experiment_id:"): + execution_metadata.experiment_id = tag.split(":", 1)[1] + elif tag.startswith("rollout_id:"): + execution_metadata.rollout_id = tag.split(":", 1)[1] + elif tag.startswith("run_id:"): + execution_metadata.run_id = tag.split(":", 1)[1] + elif tag.startswith("row_id:"): + row_id = tag.split(":", 1)[1] + + if ( + execution_metadata.invocation_id + and execution_metadata.experiment_id + and execution_metadata.rollout_id + and execution_metadata.run_id + and row_id + ): + break # Break early if we've found all the metadata we need + + return EvaluationRow( + messages=messages, + tools=tools, + input_metadata=InputMetadata( + row_id=row_id, + session_data={ + "langfuse_trace_id": trace.get("id"), # Store the trace ID here + }, + ), + execution_metadata=execution_metadata, + ) + + except (AttributeError, ValueError, KeyError) as e: + logger.error("Error converting trace %s: %s", trace.get("id"), e) + return None + + +def extract_messages_from_trace_dict( + trace: Dict[str, Any], include_tool_calls: bool = True, span_name: Optional[str] = None +) -> List[Message]: + """Extract messages from trace dictionary. + + Args: + trace: Trace dictionary from proxy API + include_tool_calls: Whether to include tool calling information + span_name: If provided, extract messages from generations within this named span + + Returns: + List of Message objects + """ + messages = [] + + if span_name: # Look for a generation tied to a span name + try: + # Find the final generation in the named span + gen = get_final_generation_in_span_dict(trace, span_name) + if not gen: + return messages + + # Extract messages from generation input and output + if gen.get("input"): + messages.extend(extract_messages_from_data(gen["input"], include_tool_calls)) + if gen.get("output"): + messages.extend(extract_messages_from_data(gen["output"], include_tool_calls)) + + return messages + + except Exception as e: + logger.error("Failed to extract messages from span '%s' in trace %s: %s", span_name, trace.get("id"), e) + return messages + + else: + try: + # Extract messages from trace input and output + if trace.get("input"): + messages.extend(extract_messages_from_data(trace["input"], include_tool_calls)) + if trace.get("output"): + messages.extend(extract_messages_from_data(trace["output"], include_tool_calls)) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Error processing trace %s: %s", trace.get("id"), e) + + # Fallback: use the last GENERATION observation which typically contains full chat history + if not messages: + try: + all_observations = trace.get("observations", []) + gens = [obs for obs in all_observations if obs.get("type") == "GENERATION"] + if gens: + gens.sort(key=lambda x: x.get("start_time", "")) + last_gen = gens[-1] + if last_gen.get("input"): + messages.extend(extract_messages_from_data(last_gen["input"], include_tool_calls)) + if last_gen.get("output"): + messages.extend(extract_messages_from_data(last_gen["output"], include_tool_calls)) + except Exception as e: + logger.warning("Failed to extract from last generation for trace %s: %s", trace.get("id"), e) + + return messages + + +def get_final_generation_in_span_dict(trace: Dict[str, Any], span_name: str) -> Optional[Dict[str, Any]]: + """Get the final generation within a named span from trace dictionary. + + Args: + trace: Trace dictionary + span_name: Name of the span to search for + + Returns: + The final generation dictionary, or None if not found + """ + # Get all observations from the trace + all_observations = trace.get("observations", []) + + # Find a span with the given name that has generation children + parent_span = None + for obs in all_observations: + if obs.get("name") == span_name and obs.get("type") == "SPAN": + # Check if this span has generation children + has_generations = any( + child.get("type") == "GENERATION" and child.get("parent_observation_id") == obs.get("id") + for child in all_observations + ) + if has_generations: + parent_span = obs + break + + if not parent_span: + logger.warning("No span named '%s' found in trace %s", span_name, trace.get("id")) + return None + + # Find all generations within this span + generations = [] + for obs in all_observations: + if obs.get("type") == "GENERATION" and obs.get("parent_observation_id") == parent_span.get("id"): + generations.append(obs) + + if not generations: + logger.warning("No generations found in span '%s' in trace %s", span_name, trace.get("id")) + return None + + # Sort generations by start time for chronological order + generations.sort(key=lambda x: x.get("start_time", "")) + + # Return the final generation (contains full message history) + return generations[-1] + + +class FireworksTracingAdapter(BaseAdapter): + """Adapter to pull data from Langfuse via Fireworks tracing proxy. + + This adapter uses the Fireworks tracing proxy API which handles retry logic + and rate limiting internally, simplifying the client-side implementation. + + Examples: + Basic usage (default project): + >>> adapter = FireworksTracingAdapter() + >>> rows = list(adapter.get_evaluation_rows(tags=["rollout_id:xyz"], limit=10)) + + With explicit project ID: + >>> adapter = FireworksTracingAdapter( + ... project_id="your_project_id", + ... base_url="https://tracing.fireworks.ai" + ... ) + >>> rows = list(adapter.get_evaluation_rows(tags=["production"], limit=10)) + + Filter by specific criteria: + >>> rows = list(adapter.get_evaluation_rows( + ... tags=["production"], + ... limit=50, + ... hours_back=24 + ... )) + """ + + def __init__( + self, + project_id: Optional[str] = None, + base_url: str = "https://tracing.fireworks.ai", + timeout: int = 300, + ): + """Initialize the Fireworks Tracing adapter. + + Args: + project_id: Optional project ID. If not provided, uses the default project configured on the server. + base_url: The base URL of the tracing proxy (default: https://tracing.fireworks.ai) + timeout: Request timeout in seconds (default: 300) + """ + self.project_id = project_id + self.base_url = base_url.rstrip("/") + self.timeout = timeout + + def get_evaluation_rows( + self, + tags: List[str], + limit: int = 100, + sample_size: Optional[int] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + name: Optional[str] = None, + environment: Optional[str] = None, + version: Optional[str] = None, + release: Optional[str] = None, + fields: Optional[str] = None, + hours_back: Optional[int] = None, + from_timestamp: Optional[datetime] = None, + to_timestamp: Optional[datetime] = None, + include_tool_calls: bool = True, + sleep_between_gets: float = 2.5, + max_retries: int = 3, + span_name: Optional[str] = None, + converter: Optional[TraceDictConverter] = None, + ) -> List[EvaluationRow]: + """Pull traces from Langfuse via proxy and convert to EvaluationRow format. + + Args: + tags: REQUIRED - Filter by specific tags (prevents fetching all traces). + Must provide at least one tag (e.g., ['rollout_id:xyz'], ['production']) + limit: Max number of trace summaries to collect via pagination + sample_size: Optional number of traces to randomly sample (if None, process all) + user_id: Filter by user ID + session_id: Filter by session ID + name: Filter by trace name + environment: Filter by environment (e.g., production, staging, development) + version: Filter by trace version + release: Filter by trace release + fields: Comma-separated list of fields to include + hours_back: Filter traces from this many hours ago + from_timestamp: Explicit start time (ISO format) + to_timestamp: Explicit end time (ISO format) + include_tool_calls: Whether to include tool calling traces + sleep_between_gets: Sleep time between trace.get() calls (handled by proxy) + max_retries: Maximum retries for rate limit errors (handled by proxy) + span_name: If provided, extract messages from generations within this named span + converter: Optional custom converter implementing TraceDictConverter protocol. + If provided, this will be used instead of the default conversion logic. + + Returns: + List[EvaluationRow]: Converted evaluation rows + + Raises: + ValueError: If tags list is empty + """ + # Validate that tags are provided (security requirement) + if not tags or len(tags) == 0: + raise ValueError("At least one tag is required to fetch traces (security: prevents fetching all traces)") + + eval_rows = [] + + # Build request payload + payload = { + "limit": limit, + "sample_size": sample_size, + "tags": tags, + "user_id": user_id, + "session_id": session_id, + "name": name, + "environment": environment, + "version": version, + "release": release, + "fields": fields, + "hours_back": hours_back, + "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, + "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, + "include_tool_calls": include_tool_calls, + "sleep_between_gets": sleep_between_gets, + "max_retries": max_retries, + "span_name": span_name, + } + + # Remove None values + payload = {k: v for k, v in payload.items() if v is not None} + + # Make request to proxy + if self.project_id: + url = f"{self.base_url}/v1/project_id/{self.project_id}/langfuse/traces" + else: + url = f"{self.base_url}/v1/langfuse/traces" + + try: + response = requests.post(url, json=payload, timeout=self.timeout) + response.raise_for_status() + result = response.json() + except requests.exceptions.RequestException as e: + logger.error("Failed to fetch traces from proxy: %s", e) + return eval_rows + + # Extract traces from response + traces = result.get("traces", []) + + # Convert each trace to EvaluationRow + for trace in traces: + try: + if converter: + eval_row = converter(trace, include_tool_calls, span_name) + else: + eval_row = convert_trace_dict_to_evaluation_row(trace, include_tool_calls, span_name) + if eval_row: + eval_rows.append(eval_row) + except (AttributeError, ValueError, KeyError) as e: + logger.warning("Failed to convert trace %s: %s", trace.get("id"), e) + continue + + logger.info("Successfully converted %d traces to evaluation rows", len(eval_rows)) + return eval_rows + + +def create_fireworks_tracing_adapter( + project_id: Optional[str] = None, base_url: str = "https://tracing.fireworks.ai" +) -> FireworksTracingAdapter: + """Factory function to create a Fireworks Tracing adapter. + + Args: + project_id: Optional project ID. If not provided, uses the default project configured on the server. + base_url: The base URL of the tracing proxy + + Returns: + FireworksTracingAdapter instance + """ + return FireworksTracingAdapter(project_id=project_id, base_url=base_url) diff --git a/eval_protocol/adapters/langfuse.py b/eval_protocol/adapters/langfuse.py index 42eeee6f..0174922f 100644 --- a/eval_protocol/adapters/langfuse.py +++ b/eval_protocol/adapters/langfuse.py @@ -355,7 +355,7 @@ def get_evaluation_rows( # If no results, possible due to indexing delay--remote rollout processor just finished pushing rows to Langfuse if traces and traces.meta and traces.meta.total_items == 0 and page == 1: - raise Exception("Empty results - indexing delay") + raise Exception("Empty results") break except Exception as e: diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 1c96affa..c88ccc56 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -29,7 +29,7 @@ def __init__( self, *, remote_base_url: Optional[str] = None, - model_base_url: Optional[str] = None, + model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, output_data_loader: Callable[[str], DynamicDataLoader], @@ -42,7 +42,6 @@ def __init__( self._model_base_url = model_base_url if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") - self._model_base_url = model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds self._output_data_loader = output_data_loader @@ -67,7 +66,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> # Start with constructor values remote_base_url: Optional[str] = self._remote_base_url - model_base_url: Optional[str] = self._model_base_url + model_base_url: str = self._model_base_url poll_interval: float = self._poll_interval timeout_seconds: float = self._timeout_seconds @@ -138,7 +137,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") final_model_base_url = model_base_url - if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai/project_id/"): + if model_base_url and model_base_url.startswith("https://tracing.fireworks.ai"): final_model_base_url = ( f"{model_base_url}/rollout_id/{meta.rollout_id}" f"/invocation_id/{meta.invocation_id}" @@ -252,24 +251,24 @@ def _load_data(): output_rows: List[EvaluationRow] = [row for result in results for row in result.rows] - if len(output_rows) == 0: # Fallback to original row if no Langfuse data found - row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No Langfuse data found for rollout") + if len(output_rows) == 0: # Fallback to original row if no Remote data found + row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No remote data found for rollout") return row - elif len(output_rows) == 1: # Return the Langfuse row - langfuse_row = output_rows[0] + elif len(output_rows) == 1: # Return the remote row + remote_row = output_rows[0] - # if the langfuse_row has the same number of messages as the original row, + # if the remote_row has the same number of messages as the original row, # something went wrong - if len(langfuse_row.messages) == len(row.messages): + if len(remote_row.messages) == len(row.messages): row.rollout_status = Status.rollout_error( "Rollout finished with the same number of messages as the original row" ) return row - row.messages = langfuse_row.messages - row.tools = langfuse_row.tools - row.input_metadata.session_data = langfuse_row.input_metadata.session_data - row.execution_metadata = langfuse_row.execution_metadata + row.messages = remote_row.messages + row.tools = remote_row.tools + row.input_metadata.session_data = remote_row.input_metadata.session_data + row.execution_metadata = remote_row.execution_metadata return row else: raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.") diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py new file mode 100644 index 00000000..9471c956 --- /dev/null +++ b/tests/remote_server/test_remote_fireworks.py @@ -0,0 +1,84 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter +from eval_protocol.quickstart.utils import filter_longest_conversation + +ROLLOUT_IDS = set() + + +@pytest.fixture(autouse=True) +def check_rollout_coverage(): + """Ensure we processed all expected rollout_ids""" + global ROLLOUT_IDS + ROLLOUT_IDS.clear() + yield + + assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" + + +def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: + global ROLLOUT_IDS # Track all rollout_ids we've seen + ROLLOUT_IDS.add(rollout_id) + + adapter = create_fireworks_tracing_adapter() + return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + + +def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: + return DynamicDataLoader( + generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + ) + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + output_data_loader=fireworks_output_data_loader, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + + assert row.execution_metadata.rollout_id in ROLLOUT_IDS, ( + f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}" + ) + + return row diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index bf92d19a..1eb4f3e5 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -1,7 +1,7 @@ # MANUAL SERVER STARTUP REQUIRED: # # For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:7077) +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: # cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server From fd204bb19f93dde1fb0d9f2588b089b2934833f2 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Sun, 5 Oct 2025 02:48:48 -0700 Subject: [PATCH 02/25] update path --- tests/remote_server/test_remote_fireworks.py | 2 +- tests/remote_server/test_remote_langfuse.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 9471c956..21fef2e0 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -4,7 +4,7 @@ # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: -# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# cd tests/remote_server/typescript-server # npm install # npm start # diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 1eb4f3e5..753518f0 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -4,7 +4,7 @@ # python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) # # For TypeScript server testing, start: -# cd /Users/derekxu/Documents/code/python-sdk/tests/remote_server/typescript-server +# cd tests/remote_server/typescript-server # npm install # npm start # From 77f9906257655524cf8a8cf5ca52eaf96488daee Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 09:51:53 -0700 Subject: [PATCH 03/25] add status handling from ECS --- eval_protocol/pytest/remote_rollout_processor.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 89adb59e..592e726f 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -220,12 +220,19 @@ def _get_status() -> Dict[str, Any]: hits = search_results["hits"]["hits"] if search_results else [] if hits: - # log all statuses found + # log all statuses found and update rollout status from the last hit for hit in hits: document = hit["_source"] logger.info( f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}" ) + # Update rollout status from the document + if "status_code" in document: + row.rollout_status = Status( + code=Status.Code(document["status_code"]), + message=document.get("status_message", ""), + details=document.get("status_details", []), + ) logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id) break From 91f23782469000007d6feb8b013bdd09dd8b9e1f Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 10:39:28 -0700 Subject: [PATCH 04/25] various changes --- eval_protocol/adapters/__init__.py | 7 +-- eval_protocol/adapters/fireworks_tracing.py | 4 +- .../pytest/remote_rollout_processor.py | 28 ++++++++-- tests/remote_server/quickstart.py | 53 +++++++++++++++++++ tests/remote_server/test_remote_fireworks.py | 8 +-- 5 files changed, 86 insertions(+), 14 deletions(-) create mode 100644 tests/remote_server/quickstart.py diff --git a/eval_protocol/adapters/__init__.py b/eval_protocol/adapters/__init__.py index 39b312d8..d664b425 100644 --- a/eval_protocol/adapters/__init__.py +++ b/eval_protocol/adapters/__init__.py @@ -25,12 +25,9 @@ except ImportError: pass -try: - from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter +from .fireworks_tracing import FireworksTracingAdapter, create_fireworks_tracing_adapter - __all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) -except ImportError: - pass +__all__.extend(["FireworksTracingAdapter", "create_fireworks_tracing_adapter"]) try: from .huggingface import ( diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index f3155023..108e85a5 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -347,9 +347,9 @@ def get_evaluation_rows( # Make request to proxy if self.project_id: - url = f"{self.base_url}/v1/project_id/{self.project_id}/langfuse/traces" + url = f"{self.base_url}/v1/project_id/{self.project_id}/traces" else: - url = f"{self.base_url}/v1/langfuse/traces" + url = f"{self.base_url}/v1/traces" try: response = requests.post(url, json=payload, timeout=self.timeout) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index c88ccc56..d879cd17 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -8,6 +8,8 @@ from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata +from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter +from eval_protocol.quickstart.utils import filter_longest_conversation from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig from .elasticsearch_setup import ElasticsearchSetup @@ -18,10 +20,30 @@ logger = logging.getLogger(__name__) +def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: + """Default output data loader that fetches traces from Fireworks tracing proxy. + + Args: + rollout_id: The rollout ID to filter traces by + + Returns: + DynamicDataLoader configured to fetch and process traces + """ + + def fetch_traces() -> List[EvaluationRow]: + adapter = create_fireworks_tracing_adapter(base_url=base_url) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + + return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) + + class RemoteRolloutProcessor(RolloutProcessor): """ Rollout processor that triggers a remote HTTP server to perform the rollout. + By default, fetches traces from the Fireworks tracing proxy using rollout_id tags. + You can provide a custom output_data_loader for different tracing backends. + See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation. """ @@ -32,7 +54,7 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Callable[[str], DynamicDataLoader], + output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None, disable_elastic_search: bool = False, elastic_search_config: Optional[ElasticsearchConfig] = None, ): @@ -44,7 +66,7 @@ def __init__( self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds - self._output_data_loader = output_data_loader + self._output_data_loader = output_data_loader or _default_output_data_loader self._disable_elastic_search = disable_elastic_search self._elastic_search_config = elastic_search_config @@ -242,7 +264,7 @@ def _get_status() -> Dict[str, Any]: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") - data_loader = self._output_data_loader(row.execution_metadata.rollout_id) + data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url) def _load_data(): return data_loader.load() diff --git a/tests/remote_server/quickstart.py b/tests/remote_server/quickstart.py new file mode 100644 index 00000000..1d204dd2 --- /dev/null +++ b/tests/remote_server/quickstart.py @@ -0,0 +1,53 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +from typing import List + +import pytest + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row, row, row] + + +@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)") +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url="http://127.0.0.1:3000", + timeout_seconds=30, + ), +) +async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow: + """ + End-to-end test: + - REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server + - trigger remote rollout via RemoteRolloutProcessor (calls init/status) + - fetch traces from Langfuse via Fireworks tracing proxy (uses default FireworksTracingAdapter) + - FAIL if no traces found or rollout_id missing + """ + assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content" + assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row." + assert row.execution_metadata.rollout_id, "Row should have a rollout_id from the remote rollout" + + return row diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index 21fef2e0..a9505481 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -36,17 +36,17 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: +def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen ROLLOUT_IDS.add(rollout_id) - adapter = create_fireworks_tracing_adapter() + adapter = create_fireworks_tracing_adapter(base_url=base_url) return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation ) From d0b35ed31b6b392e9774a0ee9411ff7caf021060 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 10:40:04 -0700 Subject: [PATCH 05/25] Refactor remote server startup to use argparse for host and port configuration; add tests for fireworks status propagation. --- tests/remote_server/remote_server.py | 21 +++- .../test_remote_fireworks_propagate_status.py | 97 +++++++++++++++++++ 2 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 tests/remote_server/test_remote_fireworks_propagate_status.py diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index 8f07a474..32546a7d 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -1,6 +1,7 @@ import os import random import threading +import argparse import uvicorn from fastapi import FastAPI @@ -61,9 +62,23 @@ def _worker(): def main(): - host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") - port = int(os.getenv("REMOTE_SERVER_PORT", "3000")) - uvicorn.run(app, host=host, port=port) + parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol") + parser.add_argument( + "--host", + type=str, + default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"), + help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)", + ) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("REMOTE_SERVER_PORT", "3000")), + help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)", + ) + + args = parser.parse_args() + + uvicorn.run(app, host=args.host, port=args.port) if __name__ == "__main__": diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py new file mode 100644 index 00000000..842ed0a7 --- /dev/null +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -0,0 +1,97 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! + +import os +import random +import subprocess +import socket +import time +from typing import List + +import pytest +import requests + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor +from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter +from eval_protocol.quickstart.utils import filter_longest_conversation + + +def find_available_port() -> int: + """Find an available port on localhost""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +SERVER_PORT = find_available_port() + + +def wait_for_server_to_startup(timeout: int = 10): + start_time = time.time() + while True: + try: + requests.get(f"http://127.0.0.1:{SERVER_PORT}") + break + except requests.exceptions.RequestException: + time.sleep(1) + if time.time() - start_time > timeout: + raise TimeoutError(f"Server did not start within {timeout} seconds") + + +@pytest.fixture(autouse=True) +def setup_remote_server(): + """Start the remote server""" + host = "127.0.0.1" + process = subprocess.Popen( + ["python", "-m", "tests.remote_server.remote_server", "--host", host, "--port", str(SERVER_PORT)] + ) + # wait for the server to startup by pollingK + wait_for_server_to_startup() + yield + process.terminate() + process.wait() + + +def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: + adapter = create_fireworks_tracing_adapter() + return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + + +def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: + return DynamicDataLoader( + generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + ) + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row] + + +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( + remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", + timeout_seconds=30, + output_data_loader=fireworks_output_data_loader, + ), +) +async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: + return row From 0400e215976017712f1ce64383739a23d5ad90be Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 12:04:34 -0700 Subject: [PATCH 06/25] fix test --- .../test_remote_fireworks_propagate_status.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 842ed0a7..15dc7665 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -66,14 +66,14 @@ def setup_remote_server(): process.wait() -def fetch_fireworks_traces(rollout_id: str) -> List[EvaluationRow]: - adapter = create_fireworks_tracing_adapter() +def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: + adapter = create_fireworks_tracing_adapter(base_url=base_url) return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation ) From 260f7216234d2f9c51c087d80385e1616871dc75 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 13:06:32 -0700 Subject: [PATCH 07/25] add dataloaderconfig --- .../pytest/remote_rollout_processor.py | 22 +++++++++++++------ .../types/remote_rollout_processor.py | 7 ++++++ tests/remote_server/test_remote_fireworks.py | 12 +++++----- tests/remote_server/test_remote_langfuse.py | 11 +++++----- 4 files changed, 35 insertions(+), 17 deletions(-) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 7a116143..ceb9e7d4 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -7,7 +7,12 @@ from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata +from eval_protocol.types.remote_rollout_processor import ( + DataLoaderConfig, + ElasticsearchConfig, + InitRequest, + RolloutMetadata, +) from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter from eval_protocol.quickstart.utils import filter_longest_conversation from .rollout_processor import RolloutProcessor @@ -20,19 +25,20 @@ logger = logging.getLogger(__name__) -def _default_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: +def _default_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: """Default output data loader that fetches traces from Fireworks tracing proxy. Args: - rollout_id: The rollout ID to filter traces by + config: Configuration containing rollout_id and optional model_base_url Returns: DynamicDataLoader configured to fetch and process traces """ def fetch_traces() -> List[EvaluationRow]: + base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -54,7 +60,7 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, - output_data_loader: Optional[Callable[[str, str], DynamicDataLoader]] = None, + output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None, disable_elastic_search: bool = False, elastic_search_config: Optional[ElasticsearchConfig] = None, ): @@ -64,7 +70,6 @@ def __init__( self._model_base_url = model_base_url if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") - self._model_base_url = model_base_url _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL") if _ep_model_base_url: self._model_base_url = _ep_model_base_url @@ -268,7 +273,10 @@ def _get_status() -> Dict[str, Any]: if row.execution_metadata.rollout_id is None: raise ValueError("Rollout ID is required in RemoteRolloutProcessor") - data_loader = self._output_data_loader(row.execution_metadata.rollout_id, model_base_url) + loader_config = DataLoaderConfig( + rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url + ) + data_loader = self._output_data_loader(loader_config) def _load_data(): return data_loader.load() diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 67c3158a..a972d2b5 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -34,6 +34,13 @@ class RolloutMetadata(BaseModel): row_id: str +class DataLoaderConfig(BaseModel): + """Configuration passed to output_data_loader functions.""" + + rollout_id: str + model_base_url: Optional[str] = None + + class InitRequest(BaseModel): """Request model for POST /init endpoint.""" diff --git a/tests/remote_server/test_remote_fireworks.py b/tests/remote_server/test_remote_fireworks.py index a9505481..7dc65cf7 100644 --- a/tests/remote_server/test_remote_fireworks.py +++ b/tests/remote_server/test_remote_fireworks.py @@ -22,6 +22,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -36,17 +37,18 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: +def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(rollout_id) + ROLLOUT_IDS.add(config.rollout_id) + base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: +def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation ) diff --git a/tests/remote_server/test_remote_langfuse.py b/tests/remote_server/test_remote_langfuse.py index 753518f0..35828570 100644 --- a/tests/remote_server/test_remote_langfuse.py +++ b/tests/remote_server/test_remote_langfuse.py @@ -22,6 +22,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.langfuse import create_langfuse_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig ROLLOUT_IDS = set() @@ -36,17 +37,17 @@ def check_rollout_coverage(): assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}" -def fetch_langfuse_traces(rollout_id: str) -> List[EvaluationRow]: +def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]: global ROLLOUT_IDS # Track all rollout_ids we've seen - ROLLOUT_IDS.add(rollout_id) + ROLLOUT_IDS.add(config.rollout_id) adapter = create_langfuse_adapter() - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) -def langfuse_output_data_loader(rollout_id: str) -> DynamicDataLoader: +def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_langfuse_traces(rollout_id)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation ) From aff3200ea09e3811ea5010796f30aff11e55d701 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 13:26:20 -0700 Subject: [PATCH 08/25] test_remote_rollout_and_fetch_fireworks_propagate_status --- tests/remote_server/remote_server.py | 29 ++++++++++++++++--- .../test_remote_fireworks_propagate_status.py | 21 +++++++++++--- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index 32546a7d..f13bc754 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -18,6 +18,9 @@ logging.getLogger().addHandler(handler) +force_early_error_message = None + + @app.post("/init") def init(req: InitRequest): if req.elastic_search_config: @@ -47,21 +50,32 @@ def _worker(): completion = client.chat.completions.create(**completion_kwargs) logger.info(f"Completed response: {completion}") + # If force_early_error is set via command-line arg, log the error and return early + if force_early_error_message: + logger.error( + force_early_error_message, + extra={"status": Status.rollout_error(force_early_error_message)}, + ) + raise RuntimeError(force_early_error_message) + except Exception as e: # Best-effort; mark as done even on error to unblock polling print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") pass finally: - logger.info( - f"Rollout {req.metadata.rollout_id} completed", - extra={"status": Status.rollout_finished()}, - ) + if not force_early_error_message: + logger.info( + f"Rollout {req.metadata.rollout_id} completed", + extra={"status": Status.rollout_finished()}, + ) t = threading.Thread(target=_worker, daemon=True) t.start() def main(): + global force_early_error_message + parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol") parser.add_argument( "--host", @@ -75,8 +89,15 @@ def main(): default=int(os.getenv("REMOTE_SERVER_PORT", "3000")), help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)", ) + parser.add_argument( + "--force-early-error", + type=str, + default=None, + help="If set, /init will immediately return after logging a rollout_error with this message", + ) args = parser.parse_args() + force_early_error_message = args.force_early_error uvicorn.run(app, host=args.host, port=args.port) diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 15dc7665..478dcd1c 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -11,8 +11,6 @@ # The TypeScript server should be running on http://127.0.0.1:3000 # You only need to start one of the servers! -import os -import random import subprocess import socket import time @@ -22,7 +20,7 @@ import requests from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader -from eval_protocol.models import EvaluationRow, Message +from eval_protocol.models import EvaluationRow, Message, Status from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter @@ -55,9 +53,22 @@ def wait_for_server_to_startup(timeout: int = 10): @pytest.fixture(autouse=True) def setup_remote_server(): """Start the remote server""" + # kill all Python processes matching "python -m tests.remote_server.remote_server" + subprocess.run(["pkill", "-f", "python -m tests.remote_server.remote_server"]) + host = "127.0.0.1" process = subprocess.Popen( - ["python", "-m", "tests.remote_server.remote_server", "--host", host, "--port", str(SERVER_PORT)] + [ + "python", + "-m", + "tests.remote_server.remote_server", + "--host", + host, + "--port", + str(SERVER_PORT), + "--force-early-error", + "test error", + ] ) # wait for the server to startup by pollingK wait_for_server_to_startup() @@ -94,4 +105,6 @@ def rows() -> List[EvaluationRow]: ), ) async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "test error" return row From db28b968b9305771883da7db655e565a553551ac Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 13:32:24 -0700 Subject: [PATCH 09/25] sync on latest --- eval_protocol/__init__.py | 2 ++ .../test_remote_fireworks_propagate_status.py | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 33c48e95..9f17f8ac 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -40,6 +40,7 @@ RolloutMetadata, StatusResponse, create_langfuse_config_tags, + DataLoaderConfig, ) try: @@ -67,6 +68,7 @@ __all__ = [ "ElasticsearchDirectHttpHandler", "RolloutIdFilter", + "DataLoaderConfig", "Status", "RemoteRolloutProcessor", "InputMetadata", diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 478dcd1c..4ebc50f7 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -25,6 +25,7 @@ from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter from eval_protocol.quickstart.utils import filter_longest_conversation +from eval_protocol.types.remote_rollout_processor import DataLoaderConfig def find_available_port() -> int: @@ -77,14 +78,16 @@ def setup_remote_server(): process.wait() -def fetch_fireworks_traces(rollout_id: str, base_url: str) -> List[EvaluationRow]: +def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: + rollout_id = config.rollout_id + base_url = config.model_base_url or "https://tracing.fireworks.ai" adapter = create_fireworks_tracing_adapter(base_url=base_url) return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) -def fireworks_output_data_loader(rollout_id: str, base_url: str) -> DynamicDataLoader: +def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(rollout_id, base_url)], preprocess_fn=filter_longest_conversation + generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation ) From 03d3b0cb6642ece42b9c4d833779255badf32b06 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 6 Oct 2025 14:48:46 -0700 Subject: [PATCH 10/25] use get --- eval_protocol/adapters/fireworks_tracing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 108e85a5..ba06c532 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -321,8 +321,8 @@ def get_evaluation_rows( eval_rows = [] - # Build request payload - payload = { + # Build query parameters for GET request + params = { "limit": limit, "sample_size": sample_size, "tags": tags, @@ -336,14 +336,12 @@ def get_evaluation_rows( "hours_back": hours_back, "from_timestamp": from_timestamp.isoformat() if from_timestamp else None, "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, - "include_tool_calls": include_tool_calls, "sleep_between_gets": sleep_between_gets, "max_retries": max_retries, - "span_name": span_name, } # Remove None values - payload = {k: v for k, v in payload.items() if v is not None} + params = {k: v for k, v in params.items() if v is not None} # Make request to proxy if self.project_id: @@ -352,7 +350,7 @@ def get_evaluation_rows( url = f"{self.base_url}/v1/traces" try: - response = requests.post(url, json=payload, timeout=self.timeout) + response = requests.get(url, params=params, timeout=self.timeout) response.raise_for_status() result = response.json() except requests.exceptions.RequestException as e: From 26c45e0b7f034ed772a6de5d7112ad46a0ffd394 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 15:48:39 -0700 Subject: [PATCH 11/25] run CI when parent is another PR --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 03e836f9..0865ff42 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,6 @@ on: - "docs/**" - "*.md" pull_request: - branches: [main] paths-ignore: - "docs/**" - "*.md" From 3e4d8b712f0035789f9d68efaf32a36594778f84 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 6 Oct 2025 15:48:48 -0700 Subject: [PATCH 12/25] Implement rollout status handling in rollout processor; add helper function to preserve error status during updates. --- eval_protocol/pytest/utils.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 706a75f9..c582d4be 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -312,6 +312,16 @@ def deep_update_dict(base: dict[str, Any], override: dict[str, Any]) -> dict[str return base +def _set_rollout_status_to_finished(result: EvaluationRow) -> None: + # Only set to finished if execution finished while not + # updating status itself. In the case that the rollout + # processor set the status to an error, we want to + # preserve the error so we do nothing in this case. + # test_remote_fireworks_propagate_status.py verifies this. + if result.rollout_status.is_running(): + result.rollout_status = Status.rollout_finished() + + async def rollout_processor_with_retry( rollout_processor: RolloutProcessor, fresh_dataset: list[EvaluationRow], @@ -359,7 +369,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu try: # Try original task first result = await task # pyright: ignore[reportUnknownVariableType] - result.rollout_status = Status.rollout_finished() + + _set_rollout_status_to_finished(result) + return result # pyright: ignore[reportUnknownVariableType] except Exception as e: # NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails. @@ -372,7 +384,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu # Use shared backoff function for retryable exceptions try: result = await execute_row_with_backoff_retry(row) - result.rollout_status = Status.rollout_finished() + + _set_rollout_status_to_finished(result) + return result except Exception as retry_error: # Backoff gave up From d014324c571435030666c3de2b3d95e5501e1f38 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:06:05 -0700 Subject: [PATCH 13/25] make work for GH action (test) --- eval_protocol/pytest/elasticsearch_setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval_protocol/pytest/elasticsearch_setup.py b/eval_protocol/pytest/elasticsearch_setup.py index 18574473..3e593cb8 100644 --- a/eval_protocol/pytest/elasticsearch_setup.py +++ b/eval_protocol/pytest/elasticsearch_setup.py @@ -76,7 +76,7 @@ def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> Elastic # Use set -o pipefail to ensure we get the return code of the first failing command process = subprocess.Popen( [ - "sh", + "bash", "-c", f"set -o pipefail; curl -fsSL https://elastic.co/start-local | sh -s -- --esonly | tee {temp_file_path}", ], From 1e5137d109515e3506a29a4e3d874399c3f8873a Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:33:17 -0700 Subject: [PATCH 14/25] disable test in regulaR CI / increase setup timeout --- tests/remote_server/test_remote_fireworks_propagate_status.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 4ebc50f7..9dccba0c 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -39,7 +39,7 @@ def find_available_port() -> int: SERVER_PORT = find_available_port() -def wait_for_server_to_startup(timeout: int = 10): +def wait_for_server_to_startup(timeout: int = 120): start_time = time.time() while True: try: @@ -96,6 +96,7 @@ def rows() -> List[EvaluationRow]: return [row] +@pytest.mark.skip(reason="Smoke test - only runs in scheduled smoke test workflow") @pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) @evaluation_test( data_loaders=DynamicDataLoader( From 5fed8d043ccc52aa17d26dff34a7d88b5abb60a5 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:33:24 -0700 Subject: [PATCH 15/25] smoke test --- .../fireworks-propagate-status-smoke-test.yml | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 .github/workflows/fireworks-propagate-status-smoke-test.yml diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml new file mode 100644 index 00000000..47cdd67b --- /dev/null +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -0,0 +1,58 @@ +name: Fireworks Propagate Status Smoke Test + +# Run every 6 hours: at 00:00, 06:00, 12:00, and 18:00 UTC +on: + schedule: + - cron: '0 */6 * * *' + workflow_dispatch: # Allow manual triggering + inputs: + debug_mode: + description: 'Enable debug output' + required: false + default: false + type: boolean + +jobs: + fireworks-propagate-status-smoke-test: + name: Fireworks Propagate Status Smoke Test + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Install the project + run: uv sync --locked --all-extras --dev + + - name: Run Fireworks Propagate Status Smoke Test + env: + FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} + FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }} + PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" + run: | + uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ + -v --tb=short \ + --ignore-markers=skip + + - name: Send failure notification to Slack + uses: act10ns/slack@v1 + if: failure() + with: + status: failure + message: | + Fireworks Propagate Status Smoke Test failed + Job: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} From bbdd4246c4b1ba68b05bd5c3ea7f4a5e401aea8a Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:36:49 -0700 Subject: [PATCH 16/25] for testing --- .github/workflows/fireworks-propagate-status-smoke-test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml index 47cdd67b..af9e1975 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -4,6 +4,8 @@ name: Fireworks Propagate Status Smoke Test on: schedule: - cron: '0 */6 * * *' + pull_request: # Temporarily enable for PR testing + branches: [propagate-error-status] workflow_dispatch: # Allow manual triggering inputs: debug_mode: From 3bb7367edfb30bbe593a8f4108f56912e7c85f02 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:39:11 -0700 Subject: [PATCH 17/25] test correctly --- .github/workflows/fireworks-propagate-status-smoke-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml index af9e1975..07e23b03 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -5,7 +5,7 @@ on: schedule: - cron: '0 */6 * * *' pull_request: # Temporarily enable for PR testing - branches: [propagate-error-status] + branches: [main] workflow_dispatch: # Allow manual triggering inputs: debug_mode: From 5578700390d9b5cfd6b94a4070fe5f8fb2d16ff1 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:41:19 -0700 Subject: [PATCH 18/25] udpate --- .github/workflows/fireworks-propagate-status-smoke-test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml index 07e23b03..bd716a21 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -45,8 +45,7 @@ jobs: PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" run: | uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ - -v --tb=short \ - --ignore-markers=skip + -v --tb=short -m "skip" - name: Send failure notification to Slack uses: act10ns/slack@v1 From 37caf087729c7bc103b116b9c5682a7f13f7ca47 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:44:19 -0700 Subject: [PATCH 19/25] fix test --- .../test_remote_fireworks_propagate_status.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 9dccba0c..8351bcf7 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -23,9 +23,6 @@ from eval_protocol.models import EvaluationRow, Message, Status from eval_protocol.pytest import evaluation_test from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor -from eval_protocol.adapters.fireworks_tracing import create_fireworks_tracing_adapter -from eval_protocol.quickstart.utils import filter_longest_conversation -from eval_protocol.types.remote_rollout_processor import DataLoaderConfig def find_available_port() -> int: @@ -78,19 +75,6 @@ def setup_remote_server(): process.wait() -def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]: - rollout_id = config.rollout_id - base_url = config.model_base_url or "https://tracing.fireworks.ai" - adapter = create_fireworks_tracing_adapter(base_url=base_url) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{rollout_id}"], max_retries=5) - - -def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader: - return DynamicDataLoader( - generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation - ) - - def rows() -> List[EvaluationRow]: row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) return [row] @@ -105,7 +89,6 @@ def rows() -> List[EvaluationRow]: rollout_processor=RemoteRolloutProcessor( remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", timeout_seconds=30, - output_data_loader=fireworks_output_data_loader, ), ) async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: From c43ecd063f247c84676531752db25250f98709da Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:44:46 -0700 Subject: [PATCH 20/25] update test name --- .github/workflows/fireworks-propagate-status-smoke-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml index bd716a21..d97f28bd 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -1,4 +1,4 @@ -name: Fireworks Propagate Status Smoke Test +name: RemoteRolloutProcessor Propagate Status Smoke Test # Run every 6 hours: at 00:00, 06:00, 12:00, and 18:00 UTC on: From a3970f5999091a6531206115e5eabe606e1c77dd Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:45:19 -0700 Subject: [PATCH 21/25] remove unnecessary secret --- .github/workflows/fireworks-propagate-status-smoke-test.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/fireworks-propagate-status-smoke-test.yml index d97f28bd..0080eb75 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/fireworks-propagate-status-smoke-test.yml @@ -41,7 +41,6 @@ jobs: - name: Run Fireworks Propagate Status Smoke Test env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} - FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }} PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" run: | uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ From c75a5ca232cdf27f6b46d746f3bce51b831392dd Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:49:55 -0700 Subject: [PATCH 22/25] ensure it runs --- .github/workflows/ci.yml | 1 + ...=> remote-rollout-processor-propagate-status-smoke-test.yml} | 2 +- tests/remote_server/test_remote_fireworks_propagate_status.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{fireworks-propagate-status-smoke-test.yml => remote-rollout-processor-propagate-status-smoke-test.yml} (97%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0865ff42..903c7734 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -109,6 +109,7 @@ jobs: --ignore=tests/test_tau_bench_airline_smoke.py \ --ignore=tests/pytest/test_svgbench.py \ --ignore=tests/pytest/test_livesvgbench.py \ + --ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \ --ignore=eval_protocol/benchmarks/ \ --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10 diff --git a/.github/workflows/fireworks-propagate-status-smoke-test.yml b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml similarity index 97% rename from .github/workflows/fireworks-propagate-status-smoke-test.yml rename to .github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml index 0080eb75..9e9a195a 100644 --- a/.github/workflows/fireworks-propagate-status-smoke-test.yml +++ b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml @@ -44,7 +44,7 @@ jobs: PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" run: | uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ - -v --tb=short -m "skip" + -v --tb=short - name: Send failure notification to Slack uses: act10ns/slack@v1 diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index 8351bcf7..27ac977b 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -80,7 +80,6 @@ def rows() -> List[EvaluationRow]: return [row] -@pytest.mark.skip(reason="Smoke test - only runs in scheduled smoke test workflow") @pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) @evaluation_test( data_loaders=DynamicDataLoader( From b2c3e5e531d9aaac26ab3aa1c95b26ebb4fce04b Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 13:59:37 -0700 Subject: [PATCH 23/25] remove from PRs --- .../remote-rollout-processor-propagate-status-smoke-test.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml index 9e9a195a..93482478 100644 --- a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml +++ b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml @@ -4,8 +4,6 @@ name: RemoteRolloutProcessor Propagate Status Smoke Test on: schedule: - cron: '0 */6 * * *' - pull_request: # Temporarily enable for PR testing - branches: [main] workflow_dispatch: # Allow manual triggering inputs: debug_mode: From d8f02d1893e9fbbf1d24f4eadfee6d30ddf831dd Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 14:36:05 -0700 Subject: [PATCH 24/25] run on all pull requests --- ...-processor-propagate-status-smoke-test.yml | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml index 93482478..8cd17d4b 100644 --- a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml +++ b/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml @@ -1,19 +1,19 @@ name: RemoteRolloutProcessor Propagate Status Smoke Test -# Run every 6 hours: at 00:00, 06:00, 12:00, and 18:00 UTC on: - schedule: - - cron: '0 */6 * * *' + push: + branches: [main] + paths-ignore: + - "docs/**" + - "*.md" + pull_request: # Run on all pull requests + paths-ignore: + - "docs/**" + - "*.md" workflow_dispatch: # Allow manual triggering - inputs: - debug_mode: - description: 'Enable debug output' - required: false - default: false - type: boolean jobs: - fireworks-propagate-status-smoke-test: + remote-rollout-processor-propagate-status-smoke-test: name: Fireworks Propagate Status Smoke Test runs-on: ubuntu-latest @@ -36,21 +36,10 @@ jobs: - name: Install the project run: uv sync --locked --all-extras --dev - - name: Run Fireworks Propagate Status Smoke Test + - name: Run RemoteRolloutProcessor Propagate Status Smoke Test env: FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" run: | uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ -v --tb=short - - - name: Send failure notification to Slack - uses: act10ns/slack@v1 - if: failure() - with: - status: failure - message: | - Fireworks Propagate Status Smoke Test failed - Job: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - env: - SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} From 7f93159d2e8b996dc853470a55cdbe09d7f19b77 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Tue, 7 Oct 2025 14:37:14 -0700 Subject: [PATCH 25/25] update name --- ...t.yml => remote-rollout-processor-propagate-status-test.yml} | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename .github/workflows/{remote-rollout-processor-propagate-status-smoke-test.yml => remote-rollout-processor-propagate-status-test.yml} (95%) diff --git a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml b/.github/workflows/remote-rollout-processor-propagate-status-test.yml similarity index 95% rename from .github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml rename to .github/workflows/remote-rollout-processor-propagate-status-test.yml index 8cd17d4b..d8080777 100644 --- a/.github/workflows/remote-rollout-processor-propagate-status-smoke-test.yml +++ b/.github/workflows/remote-rollout-processor-propagate-status-test.yml @@ -1,4 +1,4 @@ -name: RemoteRolloutProcessor Propagate Status Smoke Test +name: RemoteRolloutProcessor Propagate Status Test on: push: