-
Notifications
You must be signed in to change notification settings - Fork 16
GithubActionRolloutProcessor #273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
0ca752b
gha test
xzrderek 0cbb564
switch to prompt
xzrderek 650a349
fix rollout worker
xzrderek 269f0ae
test naming
xzrderek f7755d5
test
xzrderek c31771c
update test
xzrderek 932a88e
working example now
xzrderek 6482c64
remove unused code
xzrderek 5690f98
Merge branch 'main' into derekx/gha-rollout-processor
xzrderek 2b5887f
fix test
xzrderek f72295c
merged wrong rollout yml
xzrderek 1f9435b
remove some unneeded imports
xzrderek aee8fdc
make better comment
xzrderek 003b9f1
add rate limit retry logic
xzrderek 82dcc0e
addressed comments
xzrderek File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
223 changes: 223 additions & 0 deletions
223
eval_protocol/pytest/github_action_rollout_processor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,223 @@ | ||
| import asyncio | ||
| import os | ||
| import time | ||
| from typing import Any, Callable, Dict, List, Optional | ||
|
|
||
| import requests | ||
| from datetime import datetime, timezone, timedelta | ||
| from eval_protocol.models import EvaluationRow, Status | ||
| from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader | ||
| from eval_protocol.types.remote_rollout_processor import DataLoaderConfig | ||
|
|
||
| from .rollout_processor import RolloutProcessor | ||
| from .types import RolloutProcessorConfig | ||
| from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace | ||
|
|
||
|
|
||
class GithubActionRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.

    Expected GitHub Actions workflow:
    - Workflow dispatch with inputs: model, metadata (JSON), model_base_url
    - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
    - Traces are fetched later via output_data_loader using rollout_id tags

    NOTE: GHA has a rate limit of 5000 requests per hour.
    """

    def __init__(
        self,
        *,
        owner: str,
        repo: str,
        workflow_id: str,
        ref: str = "main",
        model_base_url: str = "https://tracing.fireworks.ai",
        poll_interval: float = 10.0,
        timeout_seconds: float = 1800.0,
        max_find_workflow_retries: int = 5,
        github_token: Optional[str] = None,
        output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
    ):
        """
        Args:
            owner: GitHub repository owner (user or org).
            repo: GitHub repository name.
            workflow_id: Workflow file name or numeric id to dispatch.
            ref: Git ref (branch/tag) the workflow runs against.
            model_base_url: Base URL passed to the workflow as the model endpoint;
                overridden by the EP_MODEL_BASE_URL environment variable when set.
            poll_interval: Seconds between status polls of the dispatched run.
            timeout_seconds: Max seconds to wait for the run to complete.
            max_find_workflow_retries: Attempts (with exponential backoff) to
                locate the dispatched run in the runs listing.
            github_token: GitHub API token; falls back to the GITHUB_TOKEN env var.
            output_data_loader: Factory for loading traces after the run finishes;
                defaults to the Fireworks trace loader.
        """
        self.owner = owner
        self.repo = repo
        self.workflow_id = workflow_id
        self.ref = ref
        self.model_base_url = model_base_url
        # Environment variable takes precedence over the constructor argument.
        _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL")
        if _ep_model_base_url:
            self.model_base_url = _ep_model_base_url
        self.poll_interval = poll_interval
        self.timeout_seconds = timeout_seconds
        self.max_find_workflow_retries = max_find_workflow_retries
        self.github_token = github_token
        self._output_data_loader = output_data_loader or default_fireworks_output_data_loader

    def _headers(self) -> Dict[str, str]:
        """Build GitHub API headers.

        Raises:
            ValueError: If no token was provided via the constructor or the
                GITHUB_TOKEN environment variable.
        """
        headers = {"Accept": "application/vnd.github+json"}
        token = self.github_token or os.getenv("GITHUB_TOKEN")
        if not token:
            raise ValueError(
                "GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
            )
        headers["Authorization"] = f"Bearer {token}"
        return headers

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Dispatch one workflow run per row and return tasks that resolve to updated rows.

        Each task: validates row ids, dispatches the workflow, locates the
        resulting run by its `rollout:<id>` name, polls until completion (or
        timeout), then attaches the remote trace and the run URL to the row.
        Concurrency is bounded by config.semaphore.
        """
        # Calculate max_pages based on number of rows we're processing
        num_rows = len(rows)
        max_pages = (num_rows + 99) // 100  # Round up pages

        async def _process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            # All of these ids are required to correlate the GHA run and its traces.
            if row.execution_metadata.invocation_id is None:
                raise ValueError("Invocation ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.experiment_id is None:
                raise ValueError("Experiment ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.rollout_id is None:
                raise ValueError("Rollout ID is required in GithubActionRolloutProcessor")
            if row.execution_metadata.run_id is None:
                raise ValueError("Run ID is required in GithubActionRolloutProcessor")
            if row.input_metadata.row_id is None:
                raise ValueError("Row ID is required in GithubActionRolloutProcessor")

            init_request = build_init_request(row, config, self.model_base_url)

            def _dispatch_workflow():
                """POST a workflow_dispatch event for this row (blocking; run in a thread)."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/dispatches"
                payload = {
                    "ref": self.ref,
                    "inputs": {
                        "model": init_request.model,
                        "metadata": init_request.metadata.model_dump_json(),
                        "model_base_url": init_request.model_base_url,
                    },
                }
                r = requests.post(url, json=payload, headers=self._headers(), timeout=30)
                r.raise_for_status()

            await asyncio.to_thread(_dispatch_workflow)

            run = None
            # The workflow is expected to name its run "rollout:<rollout_id>".
            target_name = f"rollout:{row.execution_metadata.rollout_id}"

            # Look for runs created in the last 15 minutes (we just dispatched it)
            cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
            cutoff_iso = cutoff_time.isoformat()

            for attempt in range(self.max_find_workflow_retries):
                try:
                    page = 1
                    while page <= max_pages:

                        def _list_runs():
                            """List recent workflow_dispatch runs for the current page (blocking)."""
                            url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
                            params = {
                                "event": "workflow_dispatch",
                                "branch": self.ref,
                                "per_page": 100,  # Max per_page is 100, minimize total number of pages
                                "page": page,
                                "created": f">={cutoff_iso}",  # Only look at recent runs
                            }

                            r = requests.get(url, params=params, headers=self._headers(), timeout=30)
                            r.raise_for_status()
                            return r.json()

                        runs_data = await asyncio.to_thread(_list_runs)

                        # Search for our target run in this page
                        for candidate_run in runs_data.get("workflow_runs", []):
                            if candidate_run.get("name") == target_name:
                                run = candidate_run
                                break

                        # BUG FIX: stop paginating once the run is found.
                        if run is not None:
                            break

                        # If we got fewer results than 100, we've reached the end, since we paginate in chunks of 100
                        if len(runs_data.get("workflow_runs", [])) < 100:
                            break

                        page += 1

                    # BUG FIX: previously the retry loop continued (and slept through
                    # every remaining backoff) even after the run was found.
                    if run is not None:
                        break

                    # If no run found, GHA might still be populating it, retry
                    if attempt < self.max_find_workflow_retries - 1:
                        delay = 2**attempt  # Exponential backoff
                        await asyncio.sleep(delay)

                except requests.exceptions.HTTPError as e:
                    # Retry on rate limits (HTTP 429).
                    # BUG FIX: requests.Response is falsy for 4xx/5xx statuses
                    # (Response.__bool__ returns .ok), so `if e.response and ...`
                    # never matched a 429; compare against None instead.
                    if e.response is not None and e.response.status_code == 429:
                        if attempt < self.max_find_workflow_retries - 1:
                            delay = 2**attempt  # Exponential backoff
                            await asyncio.sleep(delay)
                        else:
                            # Give up after max attempts
                            raise
                    else:
                        raise

            if not run:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            run_id = run.get("id")
            if not run_id:
                row.rollout_status = Status.rollout_error(
                    f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            # Poll the specific run until completion
            deadline = time.time() + self.timeout_seconds

            def _get_run() -> Dict[str, Any]:
                """Get status of a specific workflow run."""
                url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/runs/{run_id}"
                r = requests.get(url, headers=self._headers(), timeout=30)
                r.raise_for_status()
                return r.json()

            while time.time() < deadline:
                run_data = await asyncio.to_thread(_get_run)

                if run_data.get("status") == "completed":
                    break

                await asyncio.sleep(self.poll_interval)
            else:
                # while/else: the loop exhausted the deadline without breaking.
                row.rollout_status = Status.rollout_error(
                    f"GitHub Actions run timed out after {self.timeout_seconds} seconds"
                )
                row.execution_metadata.duration_seconds = time.perf_counter() - start_time
                return row

            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            def _update_with_trace() -> None:
                """Pull the remote trace for this rollout into the row (blocking)."""
                return update_row_with_remote_trace(row, self._output_data_loader, self.model_base_url)

            await asyncio.to_thread(_update_with_trace)

            # Add GitHub Actions run URL to session data
            if run_id:
                github_run_url = f"https://github.com/{self.owner}/{self.repo}/actions/runs/{run_id}"
                if not row.input_metadata.session_data:
                    row.input_metadata.session_data = {}
                row.input_metadata.session_data["github_actions_run_url"] = github_run_url

            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            # Bound concurrent GHA dispatch/poll loops with the shared semaphore.
            async with semaphore:
                return await _process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

    def cleanup(self) -> None:
        """No persistent resources to release."""
        return None
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.