Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions .github/workflows/rollout.yml
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
name: Eval Protocol Rollout

run-name: rollout:${{ inputs.rollout_id }}
run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}

on:
workflow_dispatch:
inputs:
model:
description: 'Model to use for the rollout'
description: 'Model to use'
required: true
type: string
rollout_id:
description: 'Rollout ID for tracking'
metadata:
description: 'JSON serialized metadata object'
required: true
type: string
prompt:
description: 'User prompt for the rollout'
model_base_url:
description: 'Base URL for the model API'
required: true
type: string

jobs:
rollout:
runs-on: ubuntu-latest
name: rollout-${{ inputs.rollout_id }}

steps:
- name: Checkout code
Expand All @@ -43,13 +42,5 @@ jobs:
run: |
python tests/github_actions/rollout_worker.py \
--model "${{ inputs.model }}" \
--rollout-id "${{ inputs.rollout_id }}" \
--prompt "${{ inputs.prompt }}"

- name: Upload rollout trace
uses: actions/upload-artifact@v4
if: always() # Upload even if the rollout failed
with:
name: rollout-trace-${{ inputs.rollout_id }}
path: rollout_trace_${{ inputs.rollout_id }}.json
retention-days: 7
--metadata '${{ inputs.metadata }}' \
--model-base-url "${{ inputs.model_base_url }}"
3 changes: 2 additions & 1 deletion eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .reward_function import RewardFunction
from .typed_interface import reward_function
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
from .pytest.parameterize import DefaultParameterIdGenerator
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
Expand Down Expand Up @@ -85,6 +85,7 @@
"DataLoaderConfig",
"Status",
"RemoteRolloutProcessor",
"GithubActionRolloutProcessor",
"InputMetadata",
"EvaluationRow",
"DefaultParameterIdGenerator",
Expand Down
4 changes: 3 additions & 1 deletion eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,9 @@ class EvaluationRow(BaseModel):
model_config = ConfigDict(extra="allow")

# Core OpenAI ChatCompletion compatible conversation data
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
messages: List[Message] = Field(
default_factory=list, description="List of messages in the conversation. Also known as a trajectory."
)

# Tool and function call information
tools: Optional[List[Dict[str, Any]]] = Field(
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .default_no_op_rollout_processor import NoOpRolloutProcessor
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
from .remote_rollout_processor import RemoteRolloutProcessor
from .github_action_rollout_processor import GithubActionRolloutProcessor
from .evaluation_test import evaluation_test
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
from .rollout_processor import RolloutProcessor
Expand Down Expand Up @@ -33,6 +34,7 @@
"RolloutProcessor",
"SingleTurnRolloutProcessor",
"RemoteRolloutProcessor",
"GithubActionRolloutProcessor",
"NoOpRolloutProcessor",
"default_dataset_adapter",
"RolloutProcessorConfig",
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)


from eval_protocol.pytest.utils import (
from eval_protocol.pytest.evaluation_test_utils import (
AggregationMethod,
add_cost_metrics,
log_eval_status_and_rows,
Expand Down
7 changes: 6 additions & 1 deletion eval_protocol/pytest/evaluation_test_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from eval_protocol.models import CompletionParams, EvaluationRow, EvaluationThreshold, Status
from eval_protocol.pytest.handle_persist_flow import handle_persist_flow
from eval_protocol.pytest.types import EvaluationTestMode
from eval_protocol.pytest.utils import AggregationMethod, aggregate, extract_effort_tag, sanitize_filename
from eval_protocol.pytest.evaluation_test_utils import (
AggregationMethod,
aggregate,
extract_effort_tag,
sanitize_filename,
)
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci


Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/generate_parameter_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from eval_protocol.data_loader.models import EvaluationDataLoader
from eval_protocol.models import CompletionParams, EvaluationRow
from eval_protocol.pytest.types import Dataset, DatasetPathParam, EvaluationInputParam, InputMessagesParam
from eval_protocol.pytest.utils import parse_ep_max_rows
from eval_protocol.pytest.evaluation_test_utils import parse_ep_max_rows
from collections.abc import Sequence


Expand Down
223 changes: 223 additions & 0 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import asyncio
import os
import time
from typing import Any, Callable, Dict, List, Optional

import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace


class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, metadata (JSON), model_base_url
    - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
    - Traces are fetched later via output_data_loader using rollout_id tags

    NOTE: GHA has a rate limit of 5000 requests per hour.
    """

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        model_base_url: str = "https://tracing.fireworks.ai",
        poll_interval: float = 10.0,
        timeout_seconds: float = 1800.0,
        max_find_workflow_retries: int = 5,
        github_token: Optional[str] = None,
        output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
    ):
        """
        Args:
            owner: GitHub repository owner (user or org).
            repo: GitHub repository name.
            workflow_id: Workflow file name or numeric ID to dispatch.
            ref: Git ref (branch/tag) to run the workflow on.
            model_base_url: Base URL passed to the workflow for model calls;
                overridden by the EP_MODEL_BASE_URL environment variable if set.
            poll_interval: Seconds between status polls of a dispatched run.
            timeout_seconds: Max seconds to wait for a run to complete.
            max_find_workflow_retries: Attempts (with exponential backoff) to
                locate the freshly dispatched run in the runs listing.
            github_token: GitHub token; falls back to the GITHUB_TOKEN env var.
            output_data_loader: Factory for the loader that fetches traces after
                the run completes; defaults to the Fireworks loader.
        """
        self.owner = owner
        self.repo = repo
        self.workflow_id = workflow_id
        self.ref = ref
        self.model_base_url = model_base_url
        # Environment override takes precedence over the constructor argument.
        _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
        if _ep_model_base_url:
            self.model_base_url = _ep_model_base_url
        self.poll_interval = poll_interval
        self.timeout_seconds = timeout_seconds
        self.max_find_workflow_retries = max_find_workflow_retries
        self.github_token = github_token
        self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

    def _headers(self) -> Dict[str, str]:
        """Build GitHub API request headers, raising if no token is available."""
        headers = {"Accept": "application/vnd.github+json"}
        token = self.github_token or os.getenv("GITHUB_TOKEN")
        if not token:
            raise ValueError(
                "GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
            )
        headers["Authorization"] = f"Bearer {token}"
        return headers

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row and return tasks that resolve when each run completes.

        Each task dispatches the workflow, locates the resulting run by its
        ``rollout:<rollout_id>`` run-name, polls it to completion, then pulls the
        remote trace into the row. Concurrency is bounded by ``config.semaphore``.
        """
        # Calculate max_pages based on number of rows we're processing
        num_rows = len(rows)
        max_pages = (num_rows + 99) // 100  # Round up pages

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            # All of these IDs are required to tag and later retrieve the trace.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in GithubActionRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in GithubActionRolloutProcessor")

            init_request = build_init_request(row, config, self.model_base_url)

            def _dispatch_workflow() -> None:
                """Fire the workflow_dispatch event for this row (blocking; run in a thread)."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
                payload = {
                    "ref": self.ref,
                    "inputs": {
                        "model": init_request.model,
                        "metadata": init_request.metadata.model_dump_json(),
                        "model_base_url": init_request.model_base_url,
                    },
                }
                r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_dispatch_workflow)

            run = None
            # The workflow sets run-name to "rollout:<rollout_id>", which is how we find it.
            target_name = f"rollout:{row.execution_metadata.rollout_id}"

            # Look for runs created in the last 15 minutes (we just dispatched it)
            cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
            cutoff_iso = cutoff_time.isoformat()

            def _list_runs(page_number: int) -> Dict[str, Any]:
                """List recent workflow_dispatch runs for one page (blocking; run in a thread)."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
                params = {
                    "event": "workflow_dispatch",
                    "branch": self.ref,
                    "per_page": 100,  # Max per_page is 100, minimize total number of pages
                    "page": page_number,
                    "created": f">={cutoff_iso}",  # Only look at recent runs
                }
                r = requests.get(url, params=params, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            for attempt in range(self.max_find_workflow_retries):
                try:
                    page = 1
                    while page <= max_pages:
                        runs_data = await asyncio.to_thread(_list_runs, page)
                        workflow_runs = runs_data.get("workflow_runs", [])

                        # Search for our target run in this page
                        for candidate_run in workflow_runs:
                            if candidate_run.get("name") == target_name:
                                run = candidate_run
                                break
                        if run is not None:
                            break

                        # If we got fewer results than 100, we've reached the end, since we paginate in chunks of 100
                        if len(workflow_runs) < 100:
                            break

                        page += 1

                    if run is not None:
                        # Found our run; stop retrying so we don't burn API quota or sleep needlessly.
                        break

                    # If no run found, GHA might still be populating it, retry
                    if attempt < self.max_find_workflow_retries - 1:
                        delay = 2**attempt  # Exponential backoff
                        await asyncio.sleep(delay)

                except requests.exceptions.HTTPError as e:
                    # Retry on rate limits (HTTP 429). NOTE: compare against None explicitly;
                    # requests.Response is falsy for error statuses, so a bare truthiness
                    # check would never take this branch.
                    if e.response is not None and e.response.status_code == 429:
                        if attempt < self.max_find_workflow_retries - 1:
                            delay = 2**attempt  # Exponential backoff
                            await asyncio.sleep(delay)
                        else:
                            # Give up after max attempts
                            raise
                    else:
                        raise

            if not run:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            run_id = run.get("id")
            if not run_id:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            # Poll the specific run until completion. Use the monotonic clock for the
            # deadline so wall-clock adjustments can't extend or truncate the timeout.
            deadline = time.monotonic() + self.timeout_seconds

            def _get_run() -> Dict[str, Any]:
                """Get status of a specific workflow run."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/runs/{run_id}"
                r = requests.get(url, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            while time.monotonic() < deadline:
                run_data = await asyncio.to_thread(_get_run)

                if run_data.get("status") == "completed":
                    break

                await asyncio.sleep(self.poll_interval)
            else:
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self.timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            def _update_with_trace() -> None:
                """Fetch the remote trace and merge it into the row (blocking; run in a thread)."""
                return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)

            await asyncio.to_thread(_update_with_trace)

            # Add GitHub Actions run URL to session data
            if run_id:
                github_run_url = f"https://github.com/{self.owner}/{self.repo}/actions/runs/{run_id}"
                if not row.input_metadata.session_data:
                    row.input_metadata.session_data = {}
                row.input_metadata.session_data["github_actions_run_url"] = github_run_url

            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    def cleanup(self) -> None:
        """No persistent resources are held; nothing to clean up."""
        return None
Loading
Loading