Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
341a033
Fireworks Tracing
xzrderek Oct 5, 2025
fd204bb
update path
xzrderek Oct 5, 2025
77f9906
add status handling from ECS
Oct 6, 2025
f4527eb
Merge branch 'main' into derekx/fireworks-tracing
Oct 6, 2025
69cd1cf
Merge branch 'derekx/fireworks-tracing' into propagate-error-status
Oct 6, 2025
91f2378
various changes
xzrderek Oct 6, 2025
712a37d
Merge branch 'derekx/fireworks-tracing' of https://github.com/eval-pr…
xzrderek Oct 6, 2025
d0b35ed
Refactor remote server startup to use argparse for host and port conf…
Oct 6, 2025
ea97cd4
Merge branch 'derekx/fireworks-tracing' into propagate-error-status
Oct 6, 2025
0400e21
fix test
Oct 6, 2025
260f721
add dataloaderconfig
xzrderek Oct 6, 2025
aff3200
test_remote_rollout_and_fetch_fireworks_propagate_status
Oct 6, 2025
faf5f7c
Merge branch 'derekx/fireworks-tracing' into propagate-error-status
Oct 6, 2025
db28b96
sync on latest
Oct 6, 2025
03d3b0c
use get
xzrderek Oct 6, 2025
b36366e
Merge branch 'derekx/fireworks-tracing' into propagate-error-status
Oct 6, 2025
26c45e0
run CI when parent is another PR
Oct 6, 2025
3e4d8b7
Implement rollout status handling in rollout processor; add helper fu…
Oct 6, 2025
d014324
make work for GH action (test)
Oct 7, 2025
1e5137d
disable test in regulaR CI / increase setup timeout
Oct 7, 2025
5fed8d0
smoke test
Oct 7, 2025
215eb26
Merge branch 'main' into propagate-error-status
Oct 7, 2025
bbdd424
for testing
Oct 7, 2025
3bb7367
test correctly
Oct 7, 2025
5578700
udpate
Oct 7, 2025
37caf08
fix test
Oct 7, 2025
c43ecd0
update test name
Oct 7, 2025
a3970f5
remove unnecessary secret
Oct 7, 2025
c75a5ca
ensure it runs
Oct 7, 2025
b2c3e5e
remove from PRs
Oct 7, 2025
d8f02d1
run on all pull requests
Oct 7, 2025
7f93159
update name
Oct 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ on:
- "docs/**"
- "*.md"
pull_request:
branches: [main]
paths-ignore:
- "docs/**"
- "*.md"
Expand Down Expand Up @@ -110,6 +109,7 @@ jobs:
--ignore=tests/test_tau_bench_airline_smoke.py \
--ignore=tests/pytest/test_svgbench.py \
--ignore=tests/pytest/test_livesvgbench.py \
--ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \
--ignore=eval_protocol/benchmarks/ \
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: RemoteRolloutProcessor Propagate Status Test

on:
push:
branches: [main]
paths-ignore:
- "docs/**"
- "*.md"
pull_request: # Run on all pull requests
paths-ignore:
- "docs/**"
- "*.md"
workflow_dispatch: # Allow manual triggering

jobs:
remote-rollout-processor-propagate-status-smoke-test:
name: Fireworks Propagate Status Smoke Test
runs-on: ubuntu-latest

steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true

- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run RemoteRolloutProcessor Propagate Status Smoke Test
env:
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
run: |
uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \
-v --tb=short
2 changes: 2 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
RolloutMetadata,
StatusResponse,
create_langfuse_config_tags,
DataLoaderConfig,
)

try:
Expand Down Expand Up @@ -67,6 +68,7 @@
__all__ = [
"ElasticsearchDirectHttpHandler",
"RolloutIdFilter",
"DataLoaderConfig",
"Status",
"RemoteRolloutProcessor",
"InputMetadata",
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/elasticsearch_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> Elastic
# Use set -o pipefail to ensure we get the return code of the first failing command
process = subprocess.Popen(
[
"sh",
"bash",
"-c",
f"set -o pipefail; curl -fsSL https://elastic.co/start-local | sh -s -- --esonly | tee {temp_file_path}",
],
Expand Down
9 changes: 8 additions & 1 deletion eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,19 @@ def _get_status() -> Dict[str, Any]:
hits = search_results["hits"]["hits"] if search_results else []

if hits:
# log all statuses found
# log all statuses found and update rollout status from the last hit
for hit in hits:
document = hit["_source"]
logger.info(
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
)
# Update rollout status from the document
if "status_code" in document:
row.rollout_status = Status(
code=Status.Code(document["status_code"]),
message=document.get("status_message", ""),
details=document.get("status_details", []),
)
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
break

Expand Down
18 changes: 16 additions & 2 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,16 @@ def deep_update_dict(base: dict[str, Any], override: dict[str, Any]) -> dict[str
return base


def _set_rollout_status_to_finished(result: EvaluationRow) -> None:
# Only set to finished if execution finished while not
# updating status itself. In the case that the rollout
# processor set the status to an error, we want to
# preserve the error so we do nothing in this case.
# test_remote_fireworks_propagate_status.py verifies this.
if result.rollout_status.is_running():
result.rollout_status = Status.rollout_finished()


async def rollout_processor_with_retry(
rollout_processor: RolloutProcessor,
fresh_dataset: list[EvaluationRow],
Expand Down Expand Up @@ -359,7 +369,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
try:
# Try original task first
result = await task # pyright: ignore[reportUnknownVariableType]
result.rollout_status = Status.rollout_finished()

_set_rollout_status_to_finished(result)

return result # pyright: ignore[reportUnknownVariableType]
except Exception as e:
# NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails.
Expand All @@ -372,7 +384,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
# Use shared backoff function for retryable exceptions
try:
result = await execute_row_with_backoff_retry(row)
result.rollout_status = Status.rollout_finished()

_set_rollout_status_to_finished(result)

return result
except Exception as retry_error:
# Backoff gave up
Expand Down
50 changes: 43 additions & 7 deletions tests/remote_server/remote_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import random
import threading
import argparse

import uvicorn
from fastapi import FastAPI
Expand All @@ -17,6 +18,9 @@
logging.getLogger().addHandler(handler)


force_early_error_message = None


@app.post("/init")
def init(req: InitRequest):
if req.elastic_search_config:
Expand Down Expand Up @@ -46,24 +50,56 @@ def _worker():
completion = client.chat.completions.create(**completion_kwargs)
logger.info(f"Completed response: {completion}")

# If force_early_error is set via command-line arg, log the error and return early
if force_early_error_message:
logger.error(
force_early_error_message,
extra={"status": Status.rollout_error(force_early_error_message)},
)
raise RuntimeError(force_early_error_message)

except Exception as e:
# Best-effort; mark as done even on error to unblock polling
print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}")
pass
finally:
logger.info(
f"Rollout {req.metadata.rollout_id} completed",
extra={"status": Status.rollout_finished()},
)
if not force_early_error_message:
logger.info(
f"Rollout {req.metadata.rollout_id} completed",
extra={"status": Status.rollout_finished()},
)

t = threading.Thread(target=_worker, daemon=True)
t.start()


def main():
host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1")
port = int(os.getenv("REMOTE_SERVER_PORT", "3000"))
uvicorn.run(app, host=host, port=port)
global force_early_error_message

parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol")
parser.add_argument(
"--host",
type=str,
default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"),
help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)",
)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("REMOTE_SERVER_PORT", "3000")),
help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)",
)
parser.add_argument(
"--force-early-error",
type=str,
default=None,
help="If set, /init will immediately return after logging a rollout_error with this message",
)

args = parser.parse_args()
force_early_error_message = args.force_early_error

uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
Expand Down
96 changes: 96 additions & 0 deletions tests/remote_server/test_remote_fireworks_propagate_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# MANUAL SERVER STARTUP REQUIRED:
#
# For Python server testing, start:
# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
#
# For TypeScript server testing, start:
# cd tests/remote_server/typescript-server
# npm install
# npm start
#
# The TypeScript server should be running on http://127.0.0.1:3000
# You only need to start one of the servers!

import subprocess
import socket
import time
from typing import List

import pytest
import requests

from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.models import EvaluationRow, Message, Status
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor


def find_available_port() -> int:
"""Find an available port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
port = s.getsockname()[1]
return port


SERVER_PORT = find_available_port()


def wait_for_server_to_startup(timeout: int = 120):
start_time = time.time()
while True:
try:
requests.get(f"http://127.0.0.1:{SERVER_PORT}")
break
except requests.exceptions.RequestException:
time.sleep(1)
if time.time() - start_time > timeout:
raise TimeoutError(f"Server did not start within {timeout} seconds")


@pytest.fixture(autouse=True)
def setup_remote_server():
"""Start the remote server"""
# kill all Python processes matching "python -m tests.remote_server.remote_server"
subprocess.run(["pkill", "-f", "python -m tests.remote_server.remote_server"])

host = "127.0.0.1"
process = subprocess.Popen(
[
"python",
"-m",
"tests.remote_server.remote_server",
"--host",
host,
"--port",
str(SERVER_PORT),
"--force-early-error",
"test error",
]
)
# wait for the server to startup by pollingK
wait_for_server_to_startup()
yield
process.terminate()
process.wait()


def rows() -> List[EvaluationRow]:
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
return [row]


@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}])
@evaluation_test(
data_loaders=DynamicDataLoader(
generators=[rows],
),
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=30,
),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
assert row.rollout_status.code == Status.Code.INTERNAL
assert row.rollout_status.message == "test error"
return row
Loading