Skip to content

Commit 20b0f23

Browse files
committed
fix: poll rollout status asynchronously
Use the lightweight status endpoint from RemoteRolloutProcessor via the shared aiohttp session and avoid the logs backfill after terminal status is observed.

Made-with: Cursor
1 parent a17c9b8 commit 20b0f23

2 files changed

Lines changed: 23 additions & 42 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -375,31 +375,30 @@ async def async_search_logs(
375375
)
376376
return results
377377

378-
def get_status(self, rollout_id: str) -> Optional[Dict[str, Any]]:
378+
async def async_get_status(self, session: aiohttp.ClientSession, rollout_id: str) -> Optional[Dict[str, Any]]:
379379
"""Fetch rollout status from the lightweight /status endpoint.
380380
381381
Returns the parsed JSON response or None if the status is not yet available.
382382
Response shape: {"rollout_id": "...", "status": {"code": ...} | null}
383383
"""
384-
from ..common_utils import get_user_agent
385-
386384
headers = {
387385
"Authorization": f"Bearer {self._get_api_key()}",
388386
"User-Agent": get_user_agent(),
389387
}
390388
params: Dict[str, Any] = {"rollout_id": rollout_id}
389+
timeout = aiohttp.ClientTimeout(total=self.timeout)
391390

392-
urls_to_try = [f"{self.base_url}/status", f"{self.base_url}/v1/status"]
391+
urls_to_try = [f"{self.base_url}/v1/status", f"{self.base_url}/status"]
393392
last_error: Optional[str] = None
394393
for url in urls_to_try:
395394
try:
396-
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
397-
if response.status_code == 404:
398-
last_error = f"404 for {url}"
399-
continue
400-
response.raise_for_status()
401-
return response.json()
402-
except requests.exceptions.RequestException as e:
395+
async with session.get(url, params=params, headers=headers, timeout=timeout) as resp:
396+
if resp.status == 404:
397+
last_error = f"404 for {url}"
398+
continue
399+
resp.raise_for_status()
400+
return (await resp.json(content_type=None)) or {}
401+
except (aiohttp.ClientError, asyncio.TimeoutError, json.JSONDecodeError) as e:
403402
last_error = str(e)
404403
continue
405404

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -121,43 +121,23 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
121121
deadline = time.time() + timeout_seconds
122122

123123
while time.time() < deadline:
124-
# Poll status (run in thread to avoid blocking event loop)
125-
status_result = await asyncio.to_thread(
126-
self._tracing_adapter.get_status, rollout_id=row.execution_metadata.rollout_id
124+
session = self._get_or_create_session()
125+
status_result = await self._tracing_adapter.async_get_status(
126+
session,
127+
rollout_id=row.execution_metadata.rollout_id,
127128
)
128129
status = (status_result or {}).get("status")
129-
if status and "code" in status:
130+
if isinstance(status, dict) and "code" in status:
130131
status_code = status["code"]
131132

132-
if status_code == Status.Code.RUNNING:
133-
await asyncio.sleep(poll_interval)
134-
continue
135-
136133
logger.info(
137134
"Found status for rollout %s with code %s",
138135
row.execution_metadata.rollout_id,
139136
status_code,
140137
)
141138

142-
# Backfill message/details/extras from the full Logs table (one-shot)
143-
session = self._get_or_create_session()
144-
completed_logs = await self._tracing_adapter.async_search_logs(
145-
session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"],
146-
)
147-
status_message = ""
148-
status_details: list = []
149-
status_extras: dict = {}
150-
for log in completed_logs:
151-
sd = log.get("status")
152-
if sd and isinstance(sd, dict) and "code" in sd:
153-
status_message = sd.get("message", "")
154-
status_details = sd.get("details", [])
155-
raw_extras = log.get("extras") or {}
156-
status_extras = {
157-
k: v for k, v in raw_extras.items()
158-
if k not in ("logger_name", "level", "timestamp")
159-
}
160-
break
139+
status_message = status.get("message", "") or ""
140+
status_details = status.get("details", []) or []
161141

162142
exception = exception_for_status_code(status_code, status_message)
163143
if exception is not None:
@@ -169,10 +149,12 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
169149
details=status_details,
170150
)
171151

172-
if row.execution_metadata.extra:
173-
row.execution_metadata.extra.update(status_extras)
174-
else:
175-
row.execution_metadata.extra = status_extras
152+
status_extras = status.get("extras")
153+
if isinstance(status_extras, dict):
154+
if row.execution_metadata.extra:
155+
row.execution_metadata.extra.update(status_extras)
156+
else:
157+
row.execution_metadata.extra = status_extras
176158

177159
logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
178160
break

0 commit comments

Comments
 (0)