Skip to content

Commit d162501

Browse files
initial commit
1 parent f77b26f commit d162501

2 files changed

Lines changed: 58 additions & 33 deletions

File tree

eval_protocol/adapters/fireworks_tracing.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,38 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
327327
)
328328
return results
329329

330+
def get_status(self, rollout_id: str) -> Optional[Dict[str, Any]]:
331+
"""Fetch rollout status from the lightweight /status endpoint.
332+
333+
Returns the parsed JSON response or None if the status is not yet available.
334+
Response shape: {"rollout_id": "...", "status": {"code": ...} | null}
335+
"""
336+
from ..common_utils import get_user_agent
337+
338+
headers = {
339+
"Authorization": f"Bearer {self._get_api_key()}",
340+
"User-Agent": get_user_agent(),
341+
}
342+
params: Dict[str, Any] = {"rollout_id": rollout_id}
343+
344+
urls_to_try = [f"{self.base_url}/status", f"{self.base_url}/v1/status"]
345+
last_error: Optional[str] = None
346+
for url in urls_to_try:
347+
try:
348+
response = requests.get(url, params=params, timeout=self.timeout, headers=headers)
349+
if response.status_code == 404:
350+
last_error = f"404 for {url}"
351+
continue
352+
response.raise_for_status()
353+
return response.json()
354+
except requests.exceptions.RequestException as e:
355+
last_error = str(e)
356+
continue
357+
358+
if last_error:
359+
logger.error("Failed to fetch status from Fireworks (tried %s): %s", urls_to_try, last_error)
360+
return None
361+
330362
def get_evaluation_rows(
331363
self,
332364
tags: List[str],

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -120,46 +120,39 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
120120
deadline = time.time() + timeout_seconds
121121

122122
while time.time() < deadline:
123-
# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
124-
completed_logs = await asyncio.to_thread(
125-
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
123+
# Poll status (run in thread to avoid blocking event loop)
124+
status_result = await asyncio.to_thread(
125+
self._tracing_adapter.get_status, rollout_id=row.execution_metadata.rollout_id
126126
)
127-
# Filter for logs that actually have status information
128-
status_logs = []
129-
for log in completed_logs:
130-
status_dict = log.get("status")
131-
if status_dict and isinstance(status_dict, dict) and "code" in status_dict:
132-
status_logs.append(log)
133-
134-
if status_logs:
135-
if len(status_logs) > 1:
136-
logger.warning(
137-
"Found %s status logs for rollout %s; expected at most 1. Using the first one: %s",
138-
len(status_logs),
139-
row.execution_metadata.rollout_id,
140-
status_logs[0],
141-
)
142-
# Use the first log with status information
143-
status_log = status_logs[0]
144-
status_dict = status_log.get("status")
145-
raw_extras = status_log.get("extras") or {}
146-
status_extras = {
147-
k: v for k, v in raw_extras.items() if k not in ("logger_name", "level", "timestamp")
148-
}
127+
if status_result and status_result.get("status"):
128+
status_code = status_result["status"]["code"]
149129

150130
logger.info(
151-
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
131+
"Found status for rollout %s with code %s",
132+
row.execution_metadata.rollout_id,
133+
status_code,
152134
)
153135

154-
status_code = status_dict.get("code")
155-
status_message = status_dict.get("message", "")
156-
status_details = status_dict.get("details", [])
157-
158-
logger.info(
159-
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
136+
# Backfill message/details/extras from the full Logs table (one-shot)
137+
completed_logs = await asyncio.to_thread(
138+
self._tracing_adapter.search_logs,
139+
tags=[f"rollout_id:{row.execution_metadata.rollout_id}"],
160140
)
141+
status_message = ""
142+
status_details: list = []
143+
status_extras: dict = {}
144+
for log in completed_logs:
145+
sd = log.get("status")
146+
if sd and isinstance(sd, dict) and "code" in sd:
147+
status_message = sd.get("message", "")
148+
status_details = sd.get("details", [])
149+
raw_extras = log.get("extras") or {}
150+
status_extras = {
151+
k: v for k, v in raw_extras.items()
152+
if k not in ("logger_name", "level", "timestamp")
153+
}
154+
break
161155

162-
# Create and raise exception if appropriate, preserving original message
163156
exception = exception_for_status_code(status_code, status_message)
164157
if exception is not None:
165158
raise exception

0 commit comments

Comments
 (0)