|
import asyncio
import base64
import io
import json
import os
import tempfile
import time
import zipfile
from typing import Any, Dict, List, Optional, Tuple

import requests

from eval_protocol.models import EvaluationRow, Message, Status

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
| 16 | + |
| 17 | + |
class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
    - Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
    - Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
    """

    # Message fields forwarded to the workflow; everything else is stripped.
    _ALLOWED_MESSAGE_FIELDS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        github_token: Optional[str] = None,
        poll_interval: float = 3.0,
        timeout_seconds: float = 1800.0,
    ):
        """Configure the target repository, workflow, and polling behavior.

        Args:
            owner: GitHub organization or user that owns the repository.
            repo: Repository name.
            workflow_id: Workflow file name or numeric id to dispatch.
            ref: Git ref (branch/tag) to run the workflow on.
            github_token: Token for the GitHub REST API; falls back to the
                GITHUB_TOKEN then GH_TOKEN environment variables.
            poll_interval: Seconds between run-status polls.
            timeout_seconds: Total budget to wait for a run to complete.
        """
        self._owner = owner
        self._repo = repo
        self._workflow_id = workflow_id
        self._ref = ref
        self._poll_interval = poll_interval
        self._timeout_seconds = timeout_seconds
        self._token = github_token or os.getenv("GITHUB_TOKEN") or os.getenv("GH_TOKEN")

    def _headers(self) -> Dict[str, str]:
        """Build GitHub REST API headers; include auth only when a token is configured."""
        headers = {"Accept": "application/vnd.github+json"}
        if self._token:
            headers["Authorization"] = f"Bearer {self._token}"
        return headers

    def _api_base(self) -> str:
        """Base URL for all repo-scoped GitHub REST API calls."""
        return f"https://api.github.com/repos/{self._owner}/{self._repo}"

    @staticmethod
    def _resolve_model(row: EvaluationRow, config: RolloutProcessorConfig) -> str:
        """Pick the model from row metadata, falling back to the processor config.

        Raises:
            ValueError: if neither the row nor the config names a model.
        """
        model: Optional[str] = None
        if row.input_metadata and row.input_metadata.completion_params:
            model = row.input_metadata.completion_params.get("model")
        if model is None and config.completion_params:
            model = config.completion_params.get("model")
        if model is None:
            raise ValueError("Model must be provided")
        return model

    @classmethod
    def _clean_messages(cls, messages: List[Any]) -> List[Dict[str, Any]]:
        """Normalize heterogeneous message objects to plain dicts with only the allowed fields."""
        cleaned: List[Dict[str, Any]] = []
        for message in messages:
            if hasattr(message, "model_dump"):
                raw = message.model_dump()
            elif isinstance(message, dict):
                raw = message
            else:
                # Fallback for plain objects: pull the allowed attributes directly.
                raw = {field: getattr(message, field, None) for field in cls._ALLOWED_MESSAGE_FIELDS}
            cleaned.append(
                {k: v for k, v in raw.items() if k in cls._ALLOWED_MESSAGE_FIELDS and v is not None}
            )
        return cleaned

    def _build_inputs(self, row: EvaluationRow, model: str) -> Dict[str, Any]:
        """Assemble workflow_dispatch inputs (messages/tools base64-encoded JSON).

        NOTE(review): the GitHub API requires dispatch input values to be strings —
        confirm rollout_id is always a string upstream.
        """
        inputs: Dict[str, Any] = {
            "model": model,
            "rollout_id": row.execution_metadata.rollout_id,
            "messages_b64": base64.b64encode(json.dumps(self._clean_messages(row.messages)).encode()).decode(),
        }
        if row.tools:
            inputs["tools_b64"] = base64.b64encode(json.dumps(row.tools).encode()).decode()
        return inputs

    def _dispatch_workflow(self, inputs: Dict[str, Any]) -> None:
        """Blocking: POST a workflow_dispatch event. Raises on HTTP error."""
        url = f"{self._api_base()}/actions/workflows/{self._workflow_id}/dispatches"
        response = requests.post(
            url, json={"ref": self._ref, "inputs": inputs}, headers=self._headers(), timeout=30
        )
        response.raise_for_status()

    def _list_recent_runs(self) -> List[Dict[str, Any]]:
        """Blocking: fetch the 10 most recent workflow_dispatch runs on the configured ref."""
        url = f"{self._api_base()}/actions/workflows/{self._workflow_id}/runs"
        params = {"event": "workflow_dispatch", "branch": self._ref, "per_page": 10}
        response = requests.get(url, params=params, headers=self._headers(), timeout=30)
        response.raise_for_status()
        return response.json().get("workflow_runs", [])

    @staticmethod
    def _pick_candidate_run(runs: List[Dict[str, Any]], rollout_id: Any) -> Optional[Dict[str, Any]]:
        """Find our run by name, falling back to the newest run in the list.

        NOTE(review): the newest-run fallback can attribute another row's run to
        this rollout when several rollouts dispatch concurrently — confirm the
        workflow sets its run name to "rollout-{rollout_id}".
        """
        preferred_name = f"rollout-{rollout_id}"
        for run in runs:
            if run.get("name") == preferred_name:
                return run
        if runs:
            return max(runs, key=lambda run: run.get("id", 0))
        return None

    def _download_trace(self, run_id: int, rollout_id: Any) -> Tuple[Optional[Dict[str, Any]], bool]:
        """Blocking: fetch the run's trace artifact and parse its JSON payload.

        Returns:
            (trace_data, artifact_found). artifact_found is False when no artifact
            named "rollout-trace-{rollout_id}" exists (the caller then leaves the
            conclusion-derived status untouched, matching prior behavior).
            trace_data is None when the artifact exists but the trace file is absent.
        """
        url = f"{self._api_base()}/actions/runs/{run_id}/artifacts"
        response = requests.get(url, headers=self._headers(), timeout=30)
        response.raise_for_status()
        artifacts = response.json().get("artifacts", [])

        wanted_name = f"rollout-trace-{rollout_id}"
        artifact = next((a for a in artifacts if a.get("name") == wanted_name), None)
        if artifact is None:
            return None, False

        download = requests.get(artifact["archive_download_url"], headers=self._headers(), timeout=60)
        download.raise_for_status()

        # Read the zip entirely in memory. This replaces the previous
        # NamedTemporaryFile write-then-reopen-by-name pattern, which fails on
        # Windows (the file cannot be reopened while still open) and wrote to
        # disk needlessly.
        with zipfile.ZipFile(io.BytesIO(download.content)) as zip_file:
            trace_filename = f"rollout_trace_{rollout_id}.json"
            if trace_filename in zip_file.namelist():
                with zip_file.open(trace_filename) as trace_file:
                    return json.loads(trace_file.read().decode("utf-8")), True
        return None, True

    @staticmethod
    def _apply_trace(row: EvaluationRow, trace_data: Optional[Dict[str, Any]]) -> None:
        """Merge a successful trace into the row, or record a rollout error."""
        if trace_data and trace_data.get("status") == "success":
            trace_messages = trace_data.get("messages", [])
            if len(trace_messages) > len(row.messages):
                row.messages = [Message(**msg) if isinstance(msg, dict) else msg for msg in trace_messages]
                if trace_data.get("tools"):
                    row.tools = trace_data["tools"]
            else:
                # No new assistant/tool messages were produced by the workflow.
                row.rollout_status = Status.rollout_error("Rollout finished with same number of messages")
        else:
            error_msg = trace_data.get("error", "Unknown error") if trace_data else "No trace data found"
            row.rollout_status = Status.rollout_error(f"Rollout failed: {error_msg}")

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row and return a task per row.

        Each task: dispatch -> poll until the run completes (or the timeout
        elapses) -> download the trace artifact -> merge it into the row.
        Concurrency is bounded by config.semaphore.
        """

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()
            rollout_id = row.execution_metadata.rollout_id

            model = self._resolve_model(row, config)
            inputs = self._build_inputs(row, model)

            # Blocking HTTP calls run in threads so the event loop stays responsive.
            await asyncio.to_thread(self._dispatch_workflow, inputs)

            # Poll until our run completes or we exhaust the time budget.
            deadline = time.time() + self._timeout_seconds
            run_id: Optional[int] = None
            while time.time() < deadline:
                runs = await asyncio.to_thread(self._list_recent_runs)
                candidate_run = self._pick_candidate_run(runs, rollout_id)
                if candidate_run and candidate_run.get("status") == "completed":
                    run_id = candidate_run.get("id")
                    row.rollout_status = self._map_conclusion_to_status(candidate_run.get("conclusion"))
                    break
                await asyncio.sleep(self._poll_interval)
            else:
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self._timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            if run_id:
                trace_data, artifact_found = await asyncio.to_thread(self._download_trace, run_id, rollout_id)
                if artifact_found:
                    self._apply_trace(row, trace_data)

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    @staticmethod
    def _map_conclusion_to_status(conclusion: Optional[str]) -> Status:
        """Translate a GitHub run conclusion string into a rollout Status."""
        if conclusion == "success":
            return Status.finished("GitHub Actions workflow succeeded")
        if conclusion in {"failure", "timed_out", "cancelled", "stale"}:
            return Status.rollout_error(f"GitHub Actions workflow concluded with '{conclusion}'")
        return Status(code=Status.Code.UNKNOWN, message=f"GitHub Actions workflow concluded with '{conclusion}'")

    def cleanup(self) -> None:
        """No persistent resources are held; nothing to clean up."""
        return None
0 commit comments