Skip to content

Commit 0ca752b

Browse files
committed
gha test
1 parent a3baa0a commit 0ca752b

File tree

7 files changed

+497
-1
lines changed

7 files changed

+497
-1
lines changed

.github/workflows/rollout.yml

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
name: Eval Protocol Rollout
# Give every run a title that carries the rollout id so a dispatched run can be
# told apart from sibling runs of the same workflow when listing runs via the API.
run-name: rollout-${{ inputs.rollout_id }}

on:
  workflow_dispatch:
    inputs:
      model:
        description: 'Model to use for the rollout'
        required: true
        type: string
      rollout_id:
        description: 'Rollout ID for tracking'
        required: true
        type: string
      messages_b64:
        description: 'Base64 encoded JSON messages array'
        required: true
        type: string
      tools_b64:
        description: 'Base64 encoded JSON tools array (optional)'
        required: false
        type: string

jobs:
  rollout:
    runs-on: ubuntu-latest
    name: rollout-${{ inputs.rollout_id }}

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e .

      - name: Run rollout script
        env:
          FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
          # Dispatch inputs are caller-controlled text. Hand them to the shell as
          # environment variables instead of expanding ${{ ... }} inside the script,
          # otherwise a crafted input value is executed as shell code.
          ROLLOUT_MODEL: ${{ inputs.model }}
          ROLLOUT_ID: ${{ inputs.rollout_id }}
          MESSAGES_B64: ${{ inputs.messages_b64 }}
          TOOLS_B64: ${{ inputs.tools_b64 }}
        run: |
          args=(--model "$ROLLOUT_MODEL" --rollout-id "$ROLLOUT_ID" --messages-b64 "$MESSAGES_B64")
          # tools_b64 is optional; only forward the flag when a value was supplied.
          if [ -n "$TOOLS_B64" ]; then
            args+=(--tools-b64 "$TOOLS_B64")
          fi
          python tests/github_actions/rollout_worker.py "${args[@]}"

      - name: Upload rollout trace
        uses: actions/upload-artifact@v4
        if: always()  # Upload even if the rollout failed
        with:
          name: rollout-trace-${{ inputs.rollout_id }}
          path: rollout_trace_${{ inputs.rollout_id }}.json
          retention-days: 7

eval_protocol/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .reward_function import RewardFunction
3131
from .typed_interface import reward_function
3232
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
33-
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
33+
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
3434
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
3535
from .pytest.parameterize import DefaultParameterIdGenerator
3636
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
@@ -85,6 +85,7 @@
8585
"DataLoaderConfig",
8686
"Status",
8787
"RemoteRolloutProcessor",
88+
"GithubActionRolloutProcessor",
8889
"InputMetadata",
8990
"EvaluationRow",
9091
"DefaultParameterIdGenerator",

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .default_no_op_rollout_processor import NoOpRolloutProcessor
55
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
66
from .remote_rollout_processor import RemoteRolloutProcessor
7+
from .github_action_rollout_processor import GithubActionRolloutProcessor
78
from .evaluation_test import evaluation_test
89
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
910
from .rollout_processor import RolloutProcessor
@@ -33,6 +34,7 @@
3334
"RolloutProcessor",
3435
"SingleTurnRolloutProcessor",
3536
"RemoteRolloutProcessor",
37+
"GithubActionRolloutProcessor",
3638
"NoOpRolloutProcessor",
3739
"default_dataset_adapter",
3840
"RolloutProcessorConfig",
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import asyncio
2+
import base64
3+
import json
4+
import os
5+
import tempfile
6+
import time
7+
import zipfile
8+
from typing import Any, Dict, List, Optional
9+
10+
import requests
11+
12+
from eval_protocol.models import EvaluationRow, Message, Status
13+
14+
from .rollout_processor import RolloutProcessor
15+
from .types import RolloutProcessorConfig
16+
17+
18+
class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
    - Workflow run should be titled "rollout-{rollout_id}" so the run belonging to a row can be
      identified; without such a title the processor falls back to the newest run that appeared
      after it started dispatching, which is only a best-effort guess when rows run concurrently.
    - Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
    - Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
    """

    _API_ROOT = "https://api.github.com"
    # Message keys forwarded to the workflow; everything else is dropped before encoding.
    _ALLOWED_MESSAGE_FIELDS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
    # Page size used when listing workflow runs while polling.
    _RUNS_PER_PAGE = 100

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        github_token: Optional[str] = None,
        poll_interval: float = 3.0,
        timeout_seconds: float = 1800.0,
    ):
        """
        Args:
            owner: Repository owner used in the GitHub API URLs.
            repo: Repository name used in the GitHub API URLs.
            workflow_id: Workflow identifier placed in the ``/actions/workflows/{id}`` URLs.
            ref: Git ref the workflow is dispatched on; also used as the ``branch`` filter when polling.
            github_token: API token; when omitted, ``GITHUB_TOKEN`` then ``GH_TOKEN`` are read from the environment.
            poll_interval: Seconds to sleep between polls of the run list.
            timeout_seconds: Maximum seconds to wait for the run to complete.
        """
        self._owner = owner
        self._repo = repo
        self._workflow_id = workflow_id
        self._ref = ref
        self._poll_interval = poll_interval
        self._timeout_seconds = timeout_seconds
        self._token = github_token or os.getenv("GITHUB_TOKEN") or os.getenv("GH_TOKEN")

    def _headers(self) -> Dict[str, str]:
        """Return the request headers, adding a bearer token when one is configured."""
        headers = {"Accept": "application/vnd.github+json"}
        if self._token:
            headers["Authorization"] = f"Bearer {self._token}"
        return headers

    def _repo_url(self, suffix: str) -> str:
        """Build an API URL below ``/repos/{owner}/{repo}/``."""
        return f"{self._API_ROOT}/repos/{self._owner}/{self._repo}/{suffix}"

    def _dispatch(self, inputs: Dict[str, Any]) -> None:
        """Trigger a ``workflow_dispatch`` run with ``inputs``; raises on a non-2xx response."""
        url = self._repo_url(f"actions/workflows/{self._workflow_id}/dispatches")
        payload = {"ref": self._ref, "inputs": inputs}
        r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
        r.raise_for_status()

    def _list_runs(self) -> List[Dict[str, Any]]:
        """Return the first page of ``workflow_dispatch`` runs of the workflow on the configured ref."""
        url = self._repo_url(f"actions/workflows/{self._workflow_id}/runs")
        params = {"event": "workflow_dispatch", "branch": self._ref, "per_page": self._RUNS_PER_PAGE}
        r = requests.get(url, params=params, headers=self._headers(), timeout=30)
        r.raise_for_status()
        return r.json().get("workflow_runs", [])

    def _find_trace_artifact(self, run_id: Any, artifact_name: str) -> Optional[Dict[str, Any]]:
        """Return the artifact of ``run_id`` named ``artifact_name``, or ``None`` when it is absent."""
        url = self._repo_url(f"actions/runs/{run_id}/artifacts")
        r = requests.get(url, headers=self._headers(), timeout=30)
        r.raise_for_status()
        for artifact in r.json().get("artifacts", []):
            if artifact.get("name") == artifact_name:
                return artifact
        return None

    def _download_trace(self, artifact: Dict[str, Any], trace_filename: str) -> Any:
        """Download the artifact zip and return the parsed JSON of ``trace_filename`` (``None`` if missing)."""
        r = requests.get(artifact["archive_download_url"], headers=self._headers(), timeout=60)
        r.raise_for_status()

        # Hand the open file object to ZipFile instead of reopening the file by name:
        # a NamedTemporaryFile cannot be opened a second time on Windows while still open.
        with tempfile.TemporaryFile() as buffer:
            buffer.write(r.content)
            buffer.seek(0)
            with zipfile.ZipFile(buffer) as archive:
                if trace_filename not in archive.namelist():
                    return None
                with archive.open(trace_filename) as trace_file:
                    return json.loads(trace_file.read().decode("utf-8"))

    @staticmethod
    def _resolve_model(row: EvaluationRow, config: RolloutProcessorConfig) -> str:
        """Pick the model from the row's completion params, falling back to the config's.

        Raises:
            ValueError: If neither source provides a model.
        """
        model: Optional[str] = None
        if row.input_metadata and row.input_metadata.completion_params:
            model = row.input_metadata.completion_params.get("model")
        if model is None and config.completion_params:
            model = config.completion_params.get("model")
        if model is None:
            raise ValueError("Model must be provided")
        return model

    @classmethod
    def _clean_messages(cls, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert messages to plain dicts holding only the allowed, non-``None`` fields."""
        cleaned = []
        for m in messages:
            if hasattr(m, "model_dump"):
                md = m.model_dump()
            elif isinstance(m, dict):
                md = m
            else:
                md = {field: getattr(m, field, None) for field in cls._ALLOWED_MESSAGE_FIELDS}
            cleaned.append({k: v for k, v in md.items() if k in cls._ALLOWED_MESSAGE_FIELDS and v is not None})
        return cleaned

    @staticmethod
    def _select_run(
        runs: List[Dict[str, Any]], preferred_name: str, known_run_ids: set
    ) -> Optional[Dict[str, Any]]:
        """Pick the run that belongs to the current row.

        A run titled ``preferred_name`` wins. Otherwise the newest run that was not
        already listed before this row was dispatched is used, so a run that finished
        earlier is never mistaken for the one just triggered.
        """
        for run in runs:
            # Check both fields that can carry the run title.
            if preferred_name in (run.get("name"), run.get("display_title")):
                return run
        fresh_runs = [run for run in runs if run.get("id") not in known_run_ids]
        return max(fresh_runs, key=lambda run: run.get("id", 0), default=None)

    @staticmethod
    def _apply_trace(row: EvaluationRow, trace_data: Any) -> None:
        """Copy a successful trace onto ``row``, or record a rollout error on it."""
        if not isinstance(trace_data, dict):
            # Missing trace file or an unexpected JSON payload: nothing usable was produced.
            row.rollout_status = Status.rollout_error("Rollout failed: No trace data found")
            return

        if trace_data.get("status") != "success":
            error_msg = trace_data.get("error", "Unknown error")
            row.rollout_status = Status.rollout_error(f"Rollout failed: {error_msg}")
            return

        trace_messages = trace_data.get("messages", [])
        if len(trace_messages) > len(row.messages):
            row.messages = [Message(**msg) if isinstance(msg, dict) else msg for msg in trace_messages]
            if trace_data.get("tools"):
                row.tools = trace_data["tools"]
        else:
            # A rollout that added no message produced nothing to evaluate.
            row.rollout_status = Status.rollout_error("Rollout finished with same number of messages")

    async def _process_row(self, row: EvaluationRow, config: RolloutProcessorConfig) -> EvaluationRow:
        """Dispatch the workflow for ``row``, wait for its run, and merge the uploaded trace into the row."""
        start_time = time.perf_counter()

        model = self._resolve_model(row, config)
        rollout_id = row.execution_metadata.rollout_id

        # Prepare workflow inputs
        inputs = {
            "model": model,
            "rollout_id": rollout_id,
            "messages_b64": base64.b64encode(json.dumps(self._clean_messages(row.messages)).encode()).decode(),
        }
        if row.tools:
            inputs["tools_b64"] = base64.b64encode(json.dumps(row.tools).encode()).decode()

        # Remember which runs already exist so the fallback selection below cannot
        # latch onto a run that completed before this dispatch.
        known_run_ids = {run.get("id") for run in await asyncio.to_thread(self._list_runs)}

        await asyncio.to_thread(self._dispatch, inputs)

        # Poll for completion
        deadline = time.monotonic() + self._timeout_seconds
        preferred_name = f"rollout-{rollout_id}"
        run_id = None
        conclusion = None

        while time.monotonic() < deadline:
            runs = await asyncio.to_thread(self._list_runs)
            candidate_run = self._select_run(runs, preferred_name, known_run_ids)

            if candidate_run and candidate_run.get("status") == "completed":
                run_id = candidate_run.get("id")
                conclusion = candidate_run.get("conclusion")
                row.rollout_status = self._map_conclusion_to_status(conclusion)
                break

            await asyncio.sleep(self._poll_interval)
        else:
            # Loop ended without a break: the deadline passed before the run completed.
            row.rollout_status = Status.rollout_error(
                f"GitHub Actions run timed out after {self._timeout_seconds} seconds"
            )
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time
            return row

        # Fetch trace from artifacts
        if run_id:
            artifact_name = f"rollout-trace-{rollout_id}"
            trace_artifact = await asyncio.to_thread(self._find_trace_artifact, run_id, artifact_name)

            if trace_artifact:
                trace_data = await asyncio.to_thread(
                    self._download_trace, trace_artifact, f"rollout_trace_{rollout_id}.json"
                )
                self._apply_trace(row, trace_data)
            else:
                # Without the artifact there is no rollout output; do not report the row as finished.
                row.rollout_status = Status.rollout_error(
                    f"Rollout failed: trace artifact '{artifact_name}' not found "
                    f"(workflow conclusion: '{conclusion}')"
                )

        row.execution_metadata.duration_seconds = time.perf_counter() - start_time
        return row

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Schedule one task per row; each task runs the rollout while holding ``config.semaphore``.

        Returns:
            The created tasks, in the same order as ``rows``.
        """
        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await self._process_row(r, config)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    @staticmethod
    def _map_conclusion_to_status(conclusion: Optional[str]) -> Status:
        """Translate a workflow run ``conclusion`` into a rollout ``Status``."""
        if conclusion == "success":
            return Status.finished("GitHub Actions workflow succeeded")
        if conclusion in {"failure", "timed_out", "cancelled", "stale"}:
            return Status.rollout_error(f"GitHub Actions workflow concluded with '{conclusion}'")
        return Status(code=Status.Code.UNKNOWN, message=f"GitHub Actions workflow concluded with '{conclusion}'")

    def cleanup(self) -> None:
        """Nothing to release; the processor holds no persistent resources."""
        return None

tests/github_actions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# GitHub Actions rollout processor tests and scripts

0 commit comments

Comments
 (0)