44from typing import Any , Callable , Dict , List , Optional
55
66import requests
7-
7+ from datetime import datetime , timezone , timedelta
88from eval_protocol .models import EvaluationRow , Status
99from eval_protocol .data_loader .dynamic_data_loader import DynamicDataLoader
1010from eval_protocol .types .remote_rollout_processor import DataLoaderConfig
@@ -22,6 +22,8 @@ class GithubActionRolloutProcessor(RolloutProcessor):
2222 - Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
2323 - Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
2424 - Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
25+
26+ NOTE: The GitHub API allows 5,000 requests per hour for authenticated requests; polling many concurrent rollouts can exhaust this limit, so dispatch/poll volume should be kept in mind.
2527 """
2628
2729 def __init__ (
@@ -34,6 +36,7 @@ def __init__(
3436 model_base_url : str = "https://tracing.fireworks.ai" ,
3537 poll_interval : float = 3.0 ,
3638 timeout_seconds : float = 1800.0 ,
39+ max_retry_attempts : int = 5 ,
3740 output_data_loader : Optional [Callable [[DataLoaderConfig ], DynamicDataLoader ]] = None ,
3841 ):
3942 self .owner = owner
@@ -46,6 +49,7 @@ def __init__(
4649 self .model_base_url = _ep_model_base_url
4750 self .poll_interval = poll_interval
4851 self .timeout_seconds = timeout_seconds
52+ self .max_retry_attempts = max_retry_attempts
4953 self ._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5054
5155 def _headers (self ) -> Dict [str , str ]:
@@ -57,6 +61,10 @@ def _headers(self) -> Dict[str, str]:
5761 return headers
5862
5963 def __call__ (self , rows : List [EvaluationRow ], config : RolloutProcessorConfig ) -> List [asyncio .Task [EvaluationRow ]]:
64+ # Calculate max_pages based on number of rows we're processing
65+ num_rows = len (rows )
66+ max_pages = (num_rows + 99 ) // 100 # Round up pages
67+
6068 async def _process_row (row : EvaluationRow ) -> EvaluationRow :
6169 start_time = time .perf_counter ()
6270
@@ -88,29 +96,70 @@ def _dispatch_workflow():
8896
8997 await asyncio .to_thread (_dispatch_workflow )
9098
91- # Need to wait a bit for GitHub to create the run. Is this problematic when we have a lot of workflows to start?
92- await asyncio . sleep ( 5 )
99+ run = None
100+ target_name = f"rollout: { row . execution_metadata . rollout_id } "
93101
94- def _get_workflow_runs () -> Dict [str , Any ]:
95- """Get recent workflow runs for this workflow."""
96- url = (
97- f"https://api.github.com/repos/{ self .owner } /{ self .repo } /actions/workflows/{ self .workflow_id } /runs"
102+ # Look for runs created in the last 15 minutes (we just dispatched it)
103+ cutoff_time = datetime .now (timezone .utc ) - timedelta (minutes = 15 )
104+ cutoff_iso = cutoff_time .isoformat ()
105+
106+ for attempt in range (self .max_retry_attempts ):
107+ try :
108+ page = 1
109+ while page <= max_pages :
110+
111+ def _list_runs ():
112+ url = f"https://api.github.com/repos/{ self .owner } /{ self .repo } /actions/workflows/{ self .workflow_id } /runs"
113+ params = {
114+ "event" : "workflow_dispatch" ,
115+ "branch" : self .ref ,
116+ "per_page" : 100 ,
117+ "page" : page ,
118+ "created" : f">={ cutoff_iso } " , # Only look at recent runs
119+ }
120+
121+ r = requests .get (url , params = params , headers = self ._headers (), timeout = 30 )
122+ r .raise_for_status ()
123+ return r .json ()
124+
125+ runs_data = await asyncio .to_thread (_list_runs )
126+
127+ # Search for our target run in this page
128+ for candidate_run in runs_data .get ("workflow_runs" , []):
129+ if candidate_run .get ("name" ) == target_name :
130+ run = candidate_run
131+
132+ # If we got fewer results than per_page, we've reached the end
133+ if len (runs_data .get ("workflow_runs" , [])) < 100 :
134+ break
135+
136+ page += 1
137+
138+ # If no run found, GHA might still be populating it, retry
139+ if attempt < self .max_retry_attempts - 1 :
140+ delay = 2 ** attempt # Exponential backoff
141+ await asyncio .sleep (delay )
142+
143+ except requests .exceptions .HTTPError as e :
144+ # Retry on rate limits (HTTP 429)
145+ if e .response and e .response .status_code == 429 :
146+ if attempt < self .max_retry_attempts - 1 :
147+ delay = 2 ** attempt # Exponential backoff
148+ await asyncio .sleep (delay )
149+ else :
150+ # Give up after max attempts
151+ raise e
152+ else :
153+ raise e
154+
155+ if not run :
156+ row .rollout_status = Status .rollout_error (
157+ f"Failed to find workflow run in GHA with rollout_id { row .execution_metadata .rollout_id } "
98158 )
99- params = {"event" : "workflow_dispatch" , "branch" : self .ref , "per_page" : 20 }
100- r = requests .get (url , params = params , headers = self ._headers (), timeout = 30 )
101- r .raise_for_status ()
102- return r .json ()
103-
104- runs_data = await asyncio .to_thread (_get_workflow_runs )
105-
106- # Find our specific run by name
107- target_name = f"rollout:{ row .execution_metadata .rollout_id } "
108- run_id = None
109- for run in runs_data .get ("workflow_runs" , []):
110- if run .get ("name" ) == target_name :
111- run_id = run .get ("id" )
112- break
159+ row .execution_metadata .duration_seconds = time .perf_counter () - start_time
160+ return row
113161
162+ run_id = run .get ("id" )
114163 if not run_id :
115164 row .rollout_status = Status .rollout_error (
116165 f"Failed to find workflow run in GHA with rollout_id { row .execution_metadata .rollout_id } "
0 commit comments