Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions .github/workflows/rollout.yml
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
name: Eval Protocol Rollout

run-name: rollout:${{ inputs.rollout_id }}
run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}

on:
workflow_dispatch:
inputs:
model:
description: 'Model to use for the rollout'
description: 'Model to use'
required: true
type: string
rollout_id:
description: 'Rollout ID for tracking'
metadata:
description: 'JSON serialized metadata object'
required: true
type: string
prompt:
description: 'User prompt for the rollout'
model_base_url:
description: 'Base URL for the model API'
required: true
type: string

jobs:
rollout:
runs-on: ubuntu-latest
name: rollout-${{ inputs.rollout_id }}

steps:
- name: Checkout code
Expand All @@ -43,13 +42,5 @@ jobs:
run: |
python tests/github_actions/rollout_worker.py \
--model "${{ inputs.model }}" \
--rollout-id "${{ inputs.rollout_id }}" \
--prompt "${{ inputs.prompt }}"

- name: Upload rollout trace
uses: actions/upload-artifact@v4
if: always() # Upload even if the rollout failed
with:
name: rollout-trace-${{ inputs.rollout_id }}
path: rollout_trace_${{ inputs.rollout_id }}.json
retention-days: 7
--metadata '${{ inputs.metadata }}' \
--model-base-url "${{ inputs.model_base_url }}"
3 changes: 2 additions & 1 deletion eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .reward_function import RewardFunction
from .typed_interface import reward_function
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
from .pytest.parameterize import DefaultParameterIdGenerator
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
Expand Down Expand Up @@ -85,6 +85,7 @@
"DataLoaderConfig",
"Status",
"RemoteRolloutProcessor",
"GithubActionRolloutProcessor",
"InputMetadata",
"EvaluationRow",
"DefaultParameterIdGenerator",
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,9 @@ class EvaluationRow(BaseModel):
model_config = ConfigDict(extra="allow")

# Core OpenAI ChatCompletion compatible conversation data
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
messages: List[Message] = Field(
Comment thread
xzrderek marked this conversation as resolved.
default_factory=list, description="List of messages in the conversation. Also known as a trajectory."
)

# Tool and function call information
tools: Optional[List[Dict[str, Any]]] = Field(
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .remote_rollout_processor import RemoteRolloutProcessor
from .github_action_rollout_processor import GithubActionRolloutProcessor
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
Expand Down Expand Up @@ -33,6 +34,7 @@
"RolloutProcessor",
"SingleTurnRolloutProcessor",
"RemoteRolloutProcessor",
"GithubActionRolloutProcessor",
"NoOpRolloutProcessor",
"default_dataset_adapter",
"RolloutProcessorConfig",
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)


from eval_protocol.pytest.utils import (
from eval_protocol.pytest.evaluation_test_utils import (
AggregationMethod,
add_cost_metrics,
log_eval_status_and_rows,
Expand Down
7 changes: 6 additions & 1 deletion eval_protocol/pytest/evaluation_test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from eval_protocol.models import CompletionParams, EvaluationRow, EvaluationThreshold, Status
from eval_protocol.pytest.handle_persist_flow import handle_persist_flow
from eval_protocol.pytest.types import EvaluationTestMode
from eval_protocol.pytest.utils import AggregationMethod, aggregate, extract_effort_tag, sanitize_filename
from eval_protocol.pytest.evaluation_test_utils import (
AggregationMethod,
aggregate,
extract_effort_tag,
sanitize_filename,
)
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci


Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/generate_parameter_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from eval_protocol.data_loader.models import EvaluationDataLoader
from eval_protocol.models import CompletionParams, EvaluationRow
from eval_protocol.pytest.types import Dataset, DatasetPathParam, EvaluationInputParam, InputMessagesParam
from eval_protocol.pytest.utils import parse_ep_max_rows
from eval_protocol.pytest.evaluation_test_utils import parse_ep_max_rows
from collections.abc import Sequence


Expand Down
219 changes: 219 additions & 0 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import asyncio
import os
import time
from typing import Any, Callable, Dict, List, Optional

import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace


class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Accepts ``workflow_dispatch`` inputs: ``model``, ``metadata`` (JSON-serialized
      init-request metadata), and ``model_base_url`` — matching the payload built in
      ``_dispatch_workflow`` below.
    - Sets its ``run-name`` to ``rollout:{rollout_id}`` so the dispatched run can be
      located by name (workflow_dispatch returns no run id, so we must search for it).
    - The rollout trace is retrieved remotely via ``update_row_with_remote_trace``
      using the configured output data loader.

    NOTE: GHA has a rate limit of 5000 requests per hour.
    """

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        model_base_url: str = "https://tracing.fireworks.ai",
        poll_interval: float = 3.0,
        timeout_seconds: float = 1800.0,
        max_retry_attempts: int = 5,
        output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
    ):
        """
        Args:
            owner: GitHub repository owner (user or org).
            repo: GitHub repository name.
            workflow_id: Workflow file name or numeric id to dispatch.
            ref: Git ref (branch/tag) to dispatch the workflow on.
            model_base_url: Base URL for the tracing/model API. Overridden by the
                ``EP_MODEL_BASE_URL`` environment variable when set.
            poll_interval: Seconds between status polls of the dispatched run.
            timeout_seconds: Max seconds to wait for the run to complete.
            max_retry_attempts: Attempts when locating the dispatched run and when
                retrying rate-limited (HTTP 429) list calls, with exponential backoff.
            output_data_loader: Factory for the data loader used to fetch the remote
                trace; defaults to the Fireworks loader.
        """
        self.owner = owner
        self.repo = repo
        self.workflow_id = workflow_id
        self.ref = ref
        self.model_base_url = model_base_url
        # Environment override takes precedence over the constructor argument.
        _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
        if _ep_model_base_url:
            self.model_base_url = _ep_model_base_url
        self.poll_interval = poll_interval
        self.timeout_seconds = timeout_seconds
        self.max_retry_attempts = max_retry_attempts
        self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

    def _headers(self) -> Dict[str, str]:
        """Build GitHub REST API headers; requires GITHUB_TOKEN in the environment."""
        headers = {"Accept": "application/vnd.github+json"}
        token = os.getenv("GITHUB_TOKEN")
        if not token:
            raise ValueError("GITHUB_TOKEN environment variable is required")
        headers["Authorization"] = f"Bearer {token}"
        return headers

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row and return tasks resolving to updated rows."""
        # Calculate max_pages based on number of rows we're processing: at most
        # `num_rows` recent dispatches can be ours, at 100 runs per API page.
        num_rows = len(rows)
        max_pages = (num_rows + 99) // 100  # Round up pages

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            # All tracking ids must be present: they key the remote trace lookup
            # and the run-name used to find the dispatched workflow run.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in GithubActionRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in GithubActionRolloutProcessor")

            init_request = build_init_request(row, config, self.model_base_url)

            def _dispatch_workflow():
                # POST /repos/{owner}/{repo}/actions/workflows/{workflow_id}/dispatches
                # returns 204 with no body — the run must be located by name afterwards.
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
                payload = {
                    "ref": self.ref,
                    "inputs": {
                        "model": init_request.model,
                        "metadata": init_request.metadata.model_dump_json(),
                        "model_base_url": init_request.model_base_url,
                    },
                }
                r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_dispatch_workflow)

            run = None
            # The workflow sets run-name to "rollout:{rollout_id}" (see class docstring).
            target_name = f"rollout:{row.execution_metadata.rollout_id}"

            # Look for runs created in the last 15 minutes (we just dispatched it)
            cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
            cutoff_iso = cutoff_time.isoformat()

            for attempt in range(self.max_retry_attempts):
                try:
                    page = 1
                    while page <= max_pages:

                        def _list_runs():
                            url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
                            params = {
                                "event": "workflow_dispatch",
                                "branch": self.ref,
                                "per_page": 100,
                                "page": page,
                                "created": f">={cutoff_iso}",  # Only look at recent runs
                            }

                            r = requests.get(url, params=params, headers=self._headers(), timeout=30)
                            r.raise_for_status()
                            return r.json()

                        runs_data = await asyncio.to_thread(_list_runs)
                        workflow_runs = runs_data.get("workflow_runs", [])

                        # Search for our target run in this page.
                        for candidate_run in workflow_runs:
                            if candidate_run.get("name") == target_name:
                                run = candidate_run
                                break  # BUGFIX: stop scanning once matched

                        # Stop paging when found, or when a short page signals the end.
                        if run is not None or len(workflow_runs) < 100:
                            break

                        page += 1

                    if run is not None:
                        # BUGFIX: previously the code fell through to the backoff
                        # sleep and kept retrying even after the run was found.
                        break

                    # If no run found, GHA might still be populating it, retry
                    if attempt < self.max_retry_attempts - 1:
                        delay = 2**attempt  # Exponential backoff
                        await asyncio.sleep(delay)

                except requests.exceptions.HTTPError as e:
                    # Retry on rate limits (HTTP 429).
                    # BUGFIX: requests.Response is falsy for 4xx/5xx status codes,
                    # so `if e.response` never passed here — compare against None.
                    if e.response is not None and e.response.status_code == 429:
                        if attempt < self.max_retry_attempts - 1:
                            delay = 2**attempt  # Exponential backoff
                            await asyncio.sleep(delay)
                        else:
                            # Give up after max attempts
                            raise
                    else:
                        raise

            if not run:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            run_id = run.get("id")
            if not run_id:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            # Poll the specific run until completion
            deadline = time.time() + self.timeout_seconds

            def _get_run() -> Dict[str, Any]:
                """Get status of a specific workflow run."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/runs/{run_id}"
                r = requests.get(url, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            while time.time() < deadline:
                run_data = await asyncio.to_thread(_get_run)

                if run_data.get("status") == "completed":
                    break

                await asyncio.sleep(self.poll_interval)
            else:
                # while-else: the deadline expired without the run completing.
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self.timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            def _update_with_trace() -> None:
                return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)

            await asyncio.to_thread(_update_with_trace)

            # Add GitHub Actions run URL to session data
            if run_id:
                github_run_url = f"https://github.com/{self.owner}/{self.repo}/actions/runs/{run_id}"
                if not row.input_metadata.session_data:
                    row.input_metadata.session_data = {}
                row.input_metadata.session_data["github_actions_run_url"] = github_run_url

            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrent dispatches/polls by the caller-supplied semaphore.
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    def cleanup(self) -> None:
        """No persistent resources to release."""
        return None
Loading
Loading