Skip to content

Commit d6644e3

Browse files
committed
use aiohttp in remote rollout processor to fix the open files issue
1 parent 77842b5 commit d6644e3

File tree

1 file changed

+26
-41
lines changed

1 file changed

+26
-41
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 26 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import List, Optional
44

5-
import requests
5+
import aiohttp
66

77
from eval_protocol.models import EvaluationRow, Status
8-
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9-
from eval_protocol.types.remote_rollout_processor import (
10-
DataLoaderConfig,
11-
)
128
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
139
from eval_protocol.exceptions import exception_for_status_code
1410

@@ -88,48 +84,24 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8884
init_payload = build_init_request(row, config, model_base_url)
8985

9086
# Fire-and-poll
91-
def _post_init() -> None:
92-
url = f"{remote_base_url}/init"
87+
init_url = f"{remote_base_url}/init"
88+
89+
timeout_init = aiohttp.ClientTimeout(total=300)
90+
91+
async with aiohttp.ClientSession() as session:
9392
try:
94-
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
95-
r.raise_for_status()
96-
except requests.exceptions.Timeout:
93+
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
94+
if resp.status >= 400:
95+
body = await resp.text()
96+
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
97+
except asyncio.TimeoutError:
9798
raise TimeoutError(
98-
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
99+
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
99100
)
100101

101-
await asyncio.to_thread(_post_init)
102-
103-
terminated = False
104102
deadline = time.time() + timeout_seconds
105103

106-
def _get_status() -> Dict[str, Any]:
107-
url = f"{remote_base_url}/status"
108-
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
109-
r.raise_for_status()
110-
return r.json()
111-
112-
continue_polling_status = True
113104
while time.time() < deadline:
114-
try:
115-
if continue_polling_status:
116-
status = await asyncio.to_thread(_get_status)
117-
terminated = bool(status.get("terminated", False))
118-
if terminated:
119-
break
120-
except requests.exceptions.HTTPError as e:
121-
if e.response is not None and e.response.status_code == 404:
122-
# 404 means server doesn't implement /status endpoint, stop polling
123-
logger.debug(
124-
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
125-
)
126-
continue_polling_status = False
127-
else:
128-
raise
129-
except Exception:
130-
# For all other exceptions, raise them
131-
raise
132-
133105
# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
134106
completed_logs = await asyncio.to_thread(
135107
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
@@ -142,6 +114,17 @@ def _get_status() -> Dict[str, Any]:
142114
status_logs.append(log)
143115

144116
if status_logs:
117+
# finished_logs = []
118+
# for log in status_logs:
119+
# sd = log.get("status") or {}
120+
# if isinstance(sd, dict) and sd.get("code") == Status.Code.FINISHED:
121+
# finished_logs.append(log)
122+
# if len(finished_logs) > 1:
123+
# logger.warning(
124+
# "Found %s FINISHED status logs for rollout %s; expected at most 1. Using the first one.",
125+
# len(finished_logs),
126+
# row.execution_metadata.rollout_id,
127+
# )
145128
# Use the first log with status information
146129
status_log = status_logs[0]
147130
status_dict = status_log.get("status")
@@ -169,6 +152,8 @@ def _get_status() -> Dict[str, Any]:
169152
details=status_details,
170153
)
171154

155+
# then add the log extras to be stuffed into row.artifacts or something
156+
172157
logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
173158
break
174159

0 commit comments

Comments
 (0)