Skip to content

Commit 82dcc0e

Browse files
committed
addressed comments
1 parent 003b9f1 commit 82dcc0e

File tree

4 files changed

+19
-15
lines changed

4 files changed

+19
-15
lines changed

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class GithubActionRolloutProcessor(RolloutProcessor):
1919
Rollout processor that dispatches and monitors a GitHub Actions workflow per evaluation row.
2020
2121
Expected GitHub Actions workflow:
22-
- Workflow dispatch with inputs: model, messages_b64, tools_b64, rollout_id, etc.
23-
- Workflow uploads artifact named "rollout-trace-{rollout_id}" containing trace JSON
24-
- Trace JSON format: {"status": "success"|"error", "messages": [...], "tools": [...], "error": str?}
22+
- Workflow dispatch with inputs: model, metadata (JSON), model_base_url
23+
- Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
24+
- Traces are fetched later via output_data_loader using rollout_id tags
2525
2626
NOTE: GHA has a rate limit of 5000 requests per hour.
2727
"""
@@ -34,9 +34,10 @@ def __init__(
3434
workflow_id: str,
3535
ref: str = "main",
3636
model_base_url: str = "https://tracing.fireworks.ai",
37-
poll_interval: float = 3.0,
37+
poll_interval: float = 10.0,
3838
timeout_seconds: float = 1800.0,
39-
max_retry_attempts: int = 5,
39+
max_find_workflow_retries: int = 5,
40+
github_token: Optional[str] = None,
4041
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
4142
):
4243
self.owner = owner
@@ -49,14 +50,17 @@ def __init__(
4950
self.model_base_url = _ep_model_base_url
5051
self.poll_interval = poll_interval
5152
self.timeout_seconds = timeout_seconds
52-
self.max_retry_attempts = max_retry_attempts
53+
self.max_find_workflow_retries = max_find_workflow_retries
54+
self.github_token = github_token
5355
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
5456

5557
def _headers(self) -> Dict[str, str]:
5658
headers = {"Accept": "application/vnd.github+json"}
57-
token = os.getenv("GITHUB_TOKEN")
59+
token = self.github_token or os.getenv("GITHUB_TOKEN")
5860
if not token:
59-
raise ValueError("GITHUB_TOKEN environment variable is required")
61+
raise ValueError(
62+
"GitHub token is required. Provide it via github_token parameter or GITHUB_TOKEN environment variable"
63+
)
6064
headers["Authorization"] = f"Bearer {token}"
6165
return headers
6266

@@ -103,7 +107,7 @@ def _dispatch_workflow():
103107
cutoff_time = datetime.now(timezone.utc) - timedelta(minutes=15)
104108
cutoff_iso = cutoff_time.isoformat()
105109

106-
for attempt in range(self.max_retry_attempts):
110+
for attempt in range(self.max_find_workflow_retries):
107111
try:
108112
page = 1
109113
while page <= max_pages:
@@ -113,7 +117,7 @@ def _list_runs():
113117
params = {
114118
"event": "workflow_dispatch",
115119
"branch": self.ref,
116-
"per_page": 100,
120+
"per_page": 100,  # Max per_page is 100; use it to minimize the total number of pages
117121
"page": page,
118122
"created": f">={cutoff_iso}", # Only look at recent runs
119123
}
@@ -129,21 +133,21 @@ def _list_runs():
129133
if candidate_run.get("name") == target_name:
130134
run = candidate_run
131135

132-
# If we got fewer results than per_page, we've reached the end
136+
# If we got fewer results than 100, we've reached the end, since we paginate in chunks of 100
133137
if len(runs_data.get("workflow_runs", [])) < 100:
134138
break
135139

136140
page += 1
137141

138142
# If no run found, GHA might still be populating it, retry
139-
if attempt < self.max_retry_attempts - 1:
143+
if attempt < self.max_find_workflow_retries - 1:
140144
delay = 2**attempt # Exponential backoff
141145
await asyncio.sleep(delay)
142146

143147
except requests.exceptions.HTTPError as e:
144148
# Retry on rate limits (HTTP 429)
145149
if e.response and e.response.status_code == 429:
146-
if attempt < self.max_retry_attempts - 1:
150+
if attempt < self.max_find_workflow_retries - 1:
147151
delay = 2**attempt # Exponential backoff
148152
await asyncio.sleep(delay)
149153
else:

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import base64
32
import time
43
from typing import Any, Dict, List, Optional, Callable
54

tests/github_actions/rollout_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
GitHub Actions rollout worker script.
44
55
This script is called by the GitHub Actions workflow to perform the actual rollout.
6-
It makes an OpenAI completion call and saves the full conversation trace as JSON.
6+
It makes an OpenAI completion call that gets automatically traced via the tracing proxy.
77
"""
88

99
import argparse

tests/github_actions/test_github_actions_rollout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def rows() -> List[EvaluationRow]:
6464
repo="python-sdk",
6565
workflow_id="rollout.yml", # or you can use numeric ID like "12345678"
6666
ref=os.getenv("GITHUB_REF", "main"),
67+
poll_interval=3.0,  # For multi-turn, you'll likely want a higher poll interval
6768
timeout_seconds=300,
6869
output_data_loader=fireworks_output_data_loader,
6970
),

0 commit comments

Comments
 (0)