Skip to content

Commit 2061e71

Browse files
initial commit
1 parent b3b02c8 commit 2061e71

2 files changed

Lines changed: 58 additions & 33 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,38 @@ async def async_search_logs(
375375
)
376376
return results
377377

378+
def get_status(self, rollout_id: str) -> Optional[Dict[str, Any]]:
379+
"""Fetch rollout status from the lightweight /status endpoint.
380+
381+
Returns the parsed JSON response or None if the status is not yet available.
382+
Response shape: {"rollout_id": "...", "status": {"code": ...} | null}
383+
"""
384+
from ..common_utils import get_user_agent
385+
386+
headers = {
387+
"Authorization": f"Bearer {self._get_api_key()}",
388+
"User-Agent": get_user_agent(),
389+
}
390+
params: Dict[str, Any] = {"rollout_id": rollout_id}
391+
392+
urls_to_try = [f"{self.base_url}/status", f"{self.base_url}/v1/status"]
393+
last_error: Optional[str] = None
394+
for url in urls_to_try:
395+
try:
396+
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
397+
if response.status_code == 404:
398+
last_error = f"404 for {url}"
399+
continue
400+
response.raise_for_status()
401+
return response.json()
402+
except requests.exceptions.RequestException as e:
403+
last_error = str(e)
404+
continue
405+
406+
if last_error:
407+
logger.error("Failed to fetch status from Fireworks (tried %s): %s", urls_to_try, last_error)
408+
return None
409+
378410
def get_evaluation_rows(
379411
self,
380412
tags: List[str],

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -121,46 +121,39 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
121121
deadline = time.time() + timeout_seconds
122122

123123
while time.time() < deadline:
124-
session = self._get_or_create_session()
125-
completed_logs = await self._tracing_adapter.async_search_logs(
126-
session, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
124+
# Poll status (run in thread to avoid blocking event loop)
125+
status_result = await asyncio.to_thread(
126+
self._tracing_adapter.get_status, rollout_id=row.execution_metadata.rollout_id
127127
)
128-
# Filter for logs that actually have status information
129-
status_logs = []
130-
for log in completed_logs:
131-
status_dict = log.get("status")
132-
if status_dict and isinstance(status_dict, dict) and "code" in status_dict:
133-
status_logs.append(log)
134-
135-
if status_logs:
136-
if len(status_logs) > 1:
137-
logger.warning(
138-
"Found %s status logs for rollout %s; expected at most 1. Using the first one: %s",
139-
len(status_logs),
140-
row.execution_metadata.rollout_id,
141-
status_logs[0],
142-
)
143-
# Use the first log with status information
144-
status_log = status_logs[0]
145-
status_dict = status_log.get("status")
146-
raw_extras = status_log.get("extras") or {}
147-
status_extras = {
148-
k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp")
149-
}
128+
if status_result and status_result.get("status"):
129+
status_code = status_result["status"]["code"]
150130

151131
logger.info(
152-
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
132+
"Found status for rollout %s with code %s",
133+
row.execution_metadata.rollout_id,
134+
status_code,
153135
)
154136

155-
status_code = status_dict.get("code")
156-
status_message = status_dict.get("message", "")
157-
status_details = status_dict.get("details", [])
158-
159-
logger.info(
160-
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
137+
# Backfill message/details/extras from the full Logs table (one-shot)
138+
completed_logs = await asyncio.to_thread(
139+
self._tracing_adapter.search_logs,
140+
tags=[f"rollout_id:{row.execution_metadata.rollout_id}"],
161141
)
142+
status_message = ""
143+
status_details: list = []
144+
status_extras: dict = {}
145+
for log in completed_logs:
146+
sd = log.get("status")
147+
if sd and isinstance(sd, dict) and "code" in sd:
148+
status_message = sd.get("message", "")
149+
status_details = sd.get("details", [])
150+
raw_extras = log.get("extras") or {}
151+
status_extras = {
152+
k: v for k, v in raw_extras.items()
153+
if k not in ("logger_name", "level", "timestamp")
154+
}
155+
break
162156

163-
# Create and raise exception if appropriate, preserving original message
164157
exception = exception_for_status_code(status_code, status_message)
165158
if exception is not None:
166159
raise exception

0 commit comments

Comments
 (0)