@@ -19,9 +19,9 @@ class GithubActionRolloutProcessor(RolloutProcessor):
1919 Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.
2020
2121 Expected GitHub Actions workflow:
22- - Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
23- - Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
24- - Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
22+ - Workflow dispatch with inputs: model, metadata (JSON), model_base_url
23+ - Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
24+ - Traces are fetched later via output_data_loader using rollout_id tags
2525
2626 NOTE: GHA has a rate limit of 5000 requests per hour.
2727 """
@@ -34,9 +34,10 @@ def __init__(
3434 workflow_id : str ,
3535 ref : str = "main" ,
3636 model_base_url : str = "https://tracing.fireworks.ai" ,
37- poll_interval : float = 3 .0 ,
37+ poll_interval : float = 10 .0 ,
3838 timeout_seconds : float = 1800.0 ,
39- max_retry_attempts : int = 5 ,
39+ max_find_workflow_retries : int = 5 ,
40+ github_token : Optional [str ] = None ,
4041 output_data_loader : Optional [Callable [[DataLoaderConfig ], DynamicDataLoader ]] = None ,
4142 ):
4243 self .owner = owner
@@ -49,14 +50,17 @@ def __init__(
4950 self .model_base_url = _ep_model_base_url
5051 self .poll_interval = poll_interval
5152 self .timeout_seconds = timeout_seconds
52- self .max_retry_attempts = max_retry_attempts
53+ self .max_find_workflow_retries = max_find_workflow_retries
54+ self .github_token = github_token
5355 self ._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5456
5557 def _headers (self ) -> Dict [str , str ]:
5658 headers = {"Accept" : "application/vnd.github+json" }
57- token = os .getenv ("GITHUB_TOKEN" )
59+ token = self . github_token or os .getenv ("GITHUB_TOKEN" )
5860 if not token :
59- raise ValueError ("GITHUB_TOKEN environment variable is required" )
61+ raise ValueError (
62+ "GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
63+ )
6064 headers ["Authorization" ] = f"Bearer { token } "
6165 return headers
6266
@@ -103,7 +107,7 @@ def _dispatch_workflow():
103107 cutoff_time = datetime .now (timezone .utc ) - timedelta (minutes = 15 )
104108 cutoff_iso = cutoff_time .isoformat ()
105109
106- for attempt in range (self .max_retry_attempts ):
110+ for attempt in range (self .max_find_workflow_retries ):
107111 try :
108112 page = 1
109113 while page <= max_pages :
@@ -113,7 +117,7 @@ def _list_runs():
113117 params = {
114118 "event" : "workflow_dispatch" ,
115119 "branch" : self .ref ,
116- "per_page" : 100 ,
120+ "per_page" : 100 , # Max per_page is 100, minimize total number of pages
117121 "page" : page ,
118122 "created" : f">={ cutoff_iso } " , # Only look at recent runs
119123 }
@@ -129,21 +133,21 @@ def _list_runs():
129133 if candidate_run .get ("name" ) == target_name :
130134 run = candidate_run
131135
132- # If we got fewer results than per_page , we've reached the end
136+ # If we got fewer results than 100 , we've reached the end, since we paginate in chunks of 100
133137 if len (runs_data .get ("workflow_runs" , [])) < 100 :
134138 break
135139
136140 page += 1
137141
138142 # If no run found, GHA might still be populating it, retry
139- if attempt < self .max_retry_attempts - 1 :
143+ if attempt < self .max_find_workflow_retries - 1 :
140144 delay = 2 ** attempt # Exponential backoff
141145 await asyncio .sleep (delay )
142146
143147 except requests .exceptions .HTTPError as e :
144148 # Retry on rate limits (HTTP 429)
145149 if e .response and e .response .status_code == 429 :
146- if attempt < self .max_retry_attempts - 1 :
150+ if attempt < self .max_find_workflow_retries - 1 :
147151 delay = 2 ** attempt # Exponential backoff
148152 await asyncio .sleep (delay )
149153 else :
0 commit comments