diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 03e836f9..903c7734 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,6 @@ on: - "docs/**" - "*.md" pull_request: - branches: [main] paths-ignore: - "docs/**" - "*.md" @@ -110,6 +109,7 @@ jobs: --ignore=tests/test_tau_bench_airline_smoke.py \ --ignore=tests/pytest/test_svgbench.py \ --ignore=tests/pytest/test_livesvgbench.py \ + --ignore=tests/remote_server/test_remote_fireworks_propagate_status.py \ --ignore=eval_protocol/benchmarks/ \ --cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10 diff --git a/.github/workflows/remote-rollout-processor-propagate-status-test.yml b/.github/workflows/remote-rollout-processor-propagate-status-test.yml new file mode 100644 index 00000000..d8080777 --- /dev/null +++ b/.github/workflows/remote-rollout-processor-propagate-status-test.yml @@ -0,0 +1,45 @@ +name: RemoteRolloutProcessor Propagate Status Test + +on: + push: + branches: [main] + paths-ignore: + - "docs/**" + - "*.md" + pull_request: # Run on all pull requests + paths-ignore: + - "docs/**" + - "*.md" + workflow_dispatch: # Allow manual triggering + +jobs: + remote-rollout-processor-propagate-status-smoke-test: + name: Fireworks Propagate Status Smoke Test + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Install the project + run: uv sync --locked --all-extras --dev + + - name: Run RemoteRolloutProcessor Propagate Status Smoke Test + env: + FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }} + PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning" + run: | + uv run pytest 
tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \ + -v --tb=short diff --git a/eval_protocol/__init__.py b/eval_protocol/__init__.py index 33c48e95..9f17f8ac 100644 --- a/eval_protocol/__init__.py +++ b/eval_protocol/__init__.py @@ -40,6 +40,7 @@ RolloutMetadata, StatusResponse, create_langfuse_config_tags, + DataLoaderConfig, ) try: @@ -67,6 +68,7 @@ __all__ = [ "ElasticsearchDirectHttpHandler", "RolloutIdFilter", + "DataLoaderConfig", "Status", "RemoteRolloutProcessor", "InputMetadata", diff --git a/eval_protocol/pytest/elasticsearch_setup.py b/eval_protocol/pytest/elasticsearch_setup.py index 18574473..3e593cb8 100644 --- a/eval_protocol/pytest/elasticsearch_setup.py +++ b/eval_protocol/pytest/elasticsearch_setup.py @@ -76,7 +76,7 @@ def _setup_initialized_docker_elasticsearch(self, env_file_path: str) -> Elastic # Use set -o pipefail to ensure we get the return code of the first failing command process = subprocess.Popen( [ - "sh", + "bash", "-c", f"set -o pipefail; curl -fsSL https://elastic.co/start-local | sh -s -- --esonly | tee {temp_file_path}", ], diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 1d4b6553..cca6884f 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -262,12 +262,19 @@ def _get_status() -> Dict[str, Any]: hits = search_results["hits"]["hits"] if search_results else [] if hits: - # log all statuses found + # log all statuses found and update rollout status from the last hit for hit in hits: document = hit["_source"] logger.info( f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}" ) + # Update rollout status from the document + if "status_code" in document: + row.rollout_status = Status( + code=Status.Code(document["status_code"]), + message=document.get("status_message", ""), + 
details=document.get("status_details", []), + ) logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id) break diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 706a75f9..c582d4be 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -312,6 +312,16 @@ def deep_update_dict(base: dict[str, Any], override: dict[str, Any]) -> dict[str return base +def _set_rollout_status_to_finished(result: EvaluationRow) -> None: + # Only set to finished if execution finished while not + # updating status itself. In the case that the rollout + # processor set the status to an error, we want to + # preserve the error so we do nothing in this case. + # test_remote_fireworks_propagate_status.py verifies this. + if result.rollout_status.is_running(): + result.rollout_status = Status.rollout_finished() + + async def rollout_processor_with_retry( rollout_processor: RolloutProcessor, fresh_dataset: list[EvaluationRow], @@ -359,7 +369,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu try: # Try original task first result = await task # pyright: ignore[reportUnknownVariableType] - result.rollout_status = Status.rollout_finished() + + _set_rollout_status_to_finished(result) + return result # pyright: ignore[reportUnknownVariableType] except Exception as e: # NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails. 
@@ -372,7 +384,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu # Use shared backoff function for retryable exceptions try: result = await execute_row_with_backoff_retry(row) - result.rollout_status = Status.rollout_finished() + + _set_rollout_status_to_finished(result) + return result except Exception as retry_error: # Backoff gave up diff --git a/tests/remote_server/remote_server.py b/tests/remote_server/remote_server.py index 8f07a474..f13bc754 100644 --- a/tests/remote_server/remote_server.py +++ b/tests/remote_server/remote_server.py @@ -1,6 +1,7 @@ import os import random import threading +import argparse import uvicorn from fastapi import FastAPI @@ -17,6 +18,9 @@ logging.getLogger().addHandler(handler) +force_early_error_message = None + + @app.post("/init") def init(req: InitRequest): if req.elastic_search_config: @@ -46,24 +50,56 @@ def _worker(): completion = client.chat.completions.create(**completion_kwargs) logger.info(f"Completed response: {completion}") + # If force_early_error is set via command-line arg, log the error and return early + if force_early_error_message: + logger.error( + force_early_error_message, + extra={"status": Status.rollout_error(force_early_error_message)}, + ) + raise RuntimeError(force_early_error_message) + except Exception as e: # Best-effort; mark as done even on error to unblock polling print(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") pass finally: - logger.info( - f"Rollout {req.metadata.rollout_id} completed", - extra={"status": Status.rollout_finished()}, - ) + if not force_early_error_message: + logger.info( + f"Rollout {req.metadata.rollout_id} completed", + extra={"status": Status.rollout_finished()}, + ) t = threading.Thread(target=_worker, daemon=True) t.start() def main(): - host = os.getenv("REMOTE_SERVER_HOST", "127.0.0.1") - port = int(os.getenv("REMOTE_SERVER_PORT", "3000")) - uvicorn.run(app, host=host, port=port) + global force_early_error_message + + parser 
= argparse.ArgumentParser(description="Run the remote server for evaluation protocol") + parser.add_argument( + "--host", + type=str, + default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"), + help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)", + ) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("REMOTE_SERVER_PORT", "3000")), + help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)", + ) + parser.add_argument( + "--force-early-error", + type=str, + default=None, + help="If set, /init will immediately return after logging a rollout_error with this message", + ) + + args = parser.parse_args() + force_early_error_message = args.force_early_error + + uvicorn.run(app, host=args.host, port=args.port) if __name__ == "__main__": diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py new file mode 100644 index 00000000..27ac977b --- /dev/null +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -0,0 +1,96 @@ +# MANUAL SERVER STARTUP REQUIRED: +# +# For Python server testing, start: +# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) +# +# For TypeScript server testing, start: +# cd tests/remote_server/typescript-server +# npm install +# npm start +# +# The TypeScript server should be running on http://127.0.0.1:3000 +# You only need to start one of the servers! 
+ +import subprocess +import socket +import time +from typing import List + +import pytest +import requests + +from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.models import EvaluationRow, Message, Status +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor + + +def find_available_port() -> int: + """Find an available port on localhost""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +SERVER_PORT = find_available_port() + + +def wait_for_server_to_startup(timeout: int = 120): + start_time = time.time() + while True: + try: + requests.get(f"http://127.0.0.1:{SERVER_PORT}") + break + except requests.exceptions.RequestException: + time.sleep(1) + if time.time() - start_time > timeout: + raise TimeoutError(f"Server did not start within {timeout} seconds") + + +@pytest.fixture(autouse=True) +def setup_remote_server(): + """Start the remote server""" + # kill all Python processes matching "python -m tests.remote_server.remote_server" + subprocess.run(["pkill", "-f", "python -m tests.remote_server.remote_server"]) + + host = "127.0.0.1" + process = subprocess.Popen( + [ + "python", + "-m", + "tests.remote_server.remote_server", + "--host", + host, + "--port", + str(SERVER_PORT), + "--force-early-error", + "test error", + ] + ) + # wait for the server to startup by polling + wait_for_server_to_startup() + yield + process.terminate() + process.wait() + + +def rows() -> List[EvaluationRow]: + row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")]) + return [row] + + +@pytest.mark.parametrize("completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}]) +@evaluation_test( + data_loaders=DynamicDataLoader( + generators=[rows], + ), + rollout_processor=RemoteRolloutProcessor( +
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}", + timeout_seconds=30, + ), +) +async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow: + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "test error" + return row