Skip to content

Commit 003b9f1

Browse files
committed
add rate limit retry logic
1 parent: aee8fdc · commit: 003b9f1

File tree

2 files changed

+70
-23
lines changed

2 files changed

+70
-23
lines changed

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 70 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Callable, Dict, List, Optional
55

66
import requests
7-
7+
from datetime import datetime, timezone, timedelta
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
1010
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
@@ -22,6 +22,8 @@ class GithubActionRolloutProcessor(RolloutProcessor):
2222
- Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
2323
- Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
2424
- Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
25+
26+
NOTE: GHA has a rate limit of 5000 requests per hour.
2527
"""
2628

2729
def __init__(
@@ -34,6 +36,7 @@ def __init__(
3436
model_base_url: str = "https://tracing.fireworks.ai",
3537
poll_interval: float = 3.0,
3638
timeout_seconds: float = 1800.0,
39+
max_retry_attempts: int = 5,
3740
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
3841
):
3942
self.owner = owner
@@ -46,6 +49,7 @@ def __init__(
4649
self.model_base_url = _ep_model_base_url
4750
self.poll_interval = poll_interval
4851
self.timeout_seconds = timeout_seconds
52+
self.max_retry_attempts = max_retry_attempts
4953
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5054

5155
def _headers(self) -> Dict[str, str]:
@@ -57,6 +61,10 @@ def _headers(self) -> Dict[str, str]:
5761
return headers
5862

5963
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
64+
# Calculate max_pages based on number of rows we're processing
65+
num_rows = len(rows)
66+
max_pages = (num_rows + 99) // 100 # Round up pages
67+
6068
async def _process_row(row: EvaluationRow) -> EvaluationRow:
6169
start_time = time.perf_counter()
6270

@@ -88,29 +96,70 @@ def _dispatch_workflow():
8896

8997
await asyncio.to_thread(_dispatch_workflow)
9098

91-
# Need to wait a bit for GitHub to create the run. Is this problematic when we have a lot of workflows to start?
92-
await asyncio.sleep(5)
99+
run = None
100+
target_name = f"rollout:{row.execution_metadata.rollout_id}"
93101

94-
def _get_workflow_runs() -> Dict[str, Any]:
95-
"""Get recent workflow runs for this workflow."""
96-
url = (
97-
f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
102+
# Look for runs created in the last 15 minutes (we just dispatched it)
103+
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
104+
cutoff_iso = cutoff_time.isoformat()
105+
106+
for attempt in range(self.max_retry_attempts):
107+
try:
108+
page = 1
109+
while page <= max_pages:
110+
111+
def _list_runs():
112+
url = f"https://api.github.com/repos/{self.owner}/{self.repo}/actions/workflows/{self.workflow_id}/runs"
113+
params = {
114+
"event": "workflow_dispatch",
115+
"branch": self.ref,
116+
"per_page": 100,
117+
"page": page,
118+
"created": f">={cutoff_iso}", # Only look at recent runs
119+
}
120+
121+
r = requests.get(url, params=params, headers=self._headers(), timeout=30)
122+
r.raise_for_status()
123+
return r.json()
124+
125+
runs_data = await asyncio.to_thread(_list_runs)
126+
127+
# Search for our target run in this page
128+
for candidate_run in runs_data.get("workflow_runs", []):
129+
if candidate_run.get("name") == target_name:
130+
run = candidate_run
131+
132+
# If we got fewer results than per_page, we've reached the end
133+
if len(runs_data.get("workflow_runs", [])) < 100:
134+
break
135+
136+
page += 1
137+
138+
# If no run found, GHA might still be populating it, retry
139+
if attempt < self.max_retry_attempts - 1:
140+
delay = 2**attempt # Exponential backoff
141+
await asyncio.sleep(delay)
142+
143+
except requests.exceptions.HTTPError as e:
144+
# Retry on rate limits (HTTP 429)
145+
if e.response and e.response.status_code == 429:
146+
if attempt < self.max_retry_attempts - 1:
147+
delay = 2**attempt # Exponential backoff
148+
await asyncio.sleep(delay)
149+
else:
150+
# Give up after max attempts
151+
raise e
152+
else:
153+
raise e
154+
155+
if not run:
156+
row.rollout_status = Status.rollout_error(
157+
f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"
98158
)
99-
params = {"event": "workflow_dispatch", "branch": self.ref, "per_page": 20}
100-
r = requests.get(url, params=params, headers=self._headers(), timeout=30)
101-
r.raise_for_status()
102-
return r.json()
103-
104-
runs_data = await asyncio.to_thread(_get_workflow_runs)
105-
106-
# Find our specific run by name
107-
target_name = f"rollout:{row.execution_metadata.rollout_id}"
108-
run_id = None
109-
for run in runs_data.get("workflow_runs", []):
110-
if run.get("name") == target_name:
111-
run_id = run.get("id")
112-
break
159+
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
160+
return row
113161

162+
run_id = run.get("id")
114163
if not run_id:
115164
row.rollout_status = Status.rollout_error(
116165
f"Failed to find workflow run in GHA with rollout_id {row.execution_metadata.rollout_id}"

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 0 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -11,8 +11,6 @@
1111
from eval_protocol.types.remote_rollout_processor import (
1212
DataLoaderConfig,
1313
ElasticsearchConfig,
14-
InitRequest,
15-
RolloutMetadata,
1614
)
1715
from .rollout_processor import RolloutProcessor
1816
from .types import RolloutProcessorConfig

0 commit comments

Comments (0)