Skip to content

Commit 86a52a4

Browse files
feat: two-phase status polling via lightweight /status endpoint (#446)
* initial commit * quick fix * Rebase * fix: poll rollout status asynchronously Use the lightweight status endpoint from RemoteRolloutProcessor via the shared aiohttp session and avoid the logs backfill after terminal status is observed. Made-with: Cursor * fix: keep polling running rollout statuses Continue polling when the lightweight status endpoint returns RUNNING so the remote rollout processor only exits the poll loop for terminal statuses. Made-with: Cursor * fix: preserve rollout status extras Read rollout status extras from the top-level status response so RemoteRolloutProcessor preserves metadata that previously came from log entries. Made-with: Cursor --------- Co-authored-by: Derek Xu <xzrderek@gmail.com>
1 parent 6b9bea9 commit 86a52a4

2 files changed

Lines changed: 51 additions & 37 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,37 @@ async def async_search_logs(
375375
)
376376
return results
377377

378+
async def async_get_status(self, session: aiohttp.ClientSession, rollout_id: str) -> Optional[Dict[str, Any]]:
379+
"""Fetch rollout status from the lightweight /status endpoint.
380+
381+
Returns the parsed JSON response or None if the status is not yet available.
382+
Response shape: {"rollout_id": "...", "status": {"code": ...} | null, "extras": {...} | null}
383+
"""
384+
headers = {
385+
"Authorization": f"Bearer {self._get_api_key()}",
386+
"User-Agent": get_user_agent(),
387+
}
388+
params: Dict[str, Any] = {"rollout_id": rollout_id}
389+
timeout = aiohttp.ClientTimeout(total=self.timeout)
390+
391+
urls_to_try = [f"{self.base_url}/v1/status", f"{self.base_url}/status"]
392+
last_error: Optional[str] = None
393+
for url in urls_to_try:
394+
try:
395+
async with session.get(url, params=params, headers=headers, timeout=timeout) as resp:
396+
if resp.status == 404:
397+
last_error = f"404 for {url}"
398+
continue
399+
resp.raise_for_status()
400+
return (await resp.json(content_type=None)) or {}
401+
except (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError) as e:
402+
last_error = str(e)
403+
continue
404+
405+
if last_error:
406+
logger.error("Failed to fetch status from Fireworks (tried %s): %s", urls_to_try, last_error)
407+
return None
408+
378409
def get_evaluation_rows(
379410
self,
380411
tags: List[str],

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -122,45 +122,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
122122

123123
while time.time() < deadline:
124124
session = self._get_or_create_session()
125-
completed_logs = await self._tracing_adapter.async_search_logs(
126-
session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
125+
status_result = await self._tracing_adapter.async_get_status(
126+
session,
127+
rollout_id=row.execution_metadata.rollout_id,
127128
)
128-
# Filter for logs that actually have status information
129-
status_logs = []
130-
for log in completed_logs:
131-
status_dict = log.get("status")
132-
if status_dict and isinstance(status_dict, dict) and "code" in status_dict:
133-
status_logs.append(log)
134-
135-
if status_logs:
136-
if len(status_logs) > 1:
137-
logger.warning(
138-
"Found %s status logs for rollout %s; expected at most 1. Using the first one: %s",
139-
len(status_logs),
140-
row.execution_metadata.rollout_id,
141-
status_logs[0],
142-
)
143-
# Use the first log with status information
144-
status_log = status_logs[0]
145-
status_dict = status_log.get("status")
146-
raw_extras = status_log.get("extras") or {}
147-
status_extras = {
148-
k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp")
149-
}
129+
status = (status_result or {}).get("status")
130+
if isinstance(status, dict) and "code" in status:
131+
status_code = status["code"]
132+
if status_code == Status.Code.RUNNING:
133+
await asyncio.sleep(poll_interval)
134+
continue
150135

151136
logger.info(
152-
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
137+
"Found status for rollout %s with code %s",
138+
row.execution_metadata.rollout_id,
139+
status_code,
153140
)
154141

155-
status_code = status_dict.get("code")
156-
status_message = status_dict.get("message", "")
157-
status_details = status_dict.get("details", [])
142+
status_message = status.get("message", "") or ""
143+
status_details = status.get("details", []) or []
158144

159-
logger.info(
160-
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
161-
)
162-
163-
# Create and raise exception if appropriate, preserving original message
164145
exception = exception_for_status_code(status_code, status_message)
165146
if exception is not None:
166147
raise exception
@@ -171,10 +152,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
171152
details=status_details,
172153
)
173154

174-
if row.execution_metadata.extra:
175-
row.execution_metadata.extra.update(status_extras)
176-
else:
177-
row.execution_metadata.extra = status_extras
155+
status_extras = (status_result or {}).get("extras")
156+
if isinstance(status_extras, dict):
157+
if row.execution_metadata.extra:
158+
row.execution_metadata.extra.update(status_extras)
159+
else:
160+
row.execution_metadata.extra = status_extras
178161

179162
logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
180163
break

0 commit comments

Comments
 (0)