|
1 | 1 | import asyncio |
2 | 2 | import time |
3 | | -from typing import Any, Dict, List, Optional |
| 3 | +from typing import List, Optional |
4 | 4 |
|
5 | | -import requests |
| 5 | +import aiohttp |
6 | 6 |
|
7 | 7 | from eval_protocol.models import EvaluationRow, Status |
8 | | -from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader |
9 | | -from eval_protocol.types.remote_rollout_processor import ( |
10 | | - DataLoaderConfig, |
11 | | -) |
12 | 8 | from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter |
13 | 9 | from eval_protocol.exceptions import exception_for_status_code |
14 | 10 |
|
@@ -51,6 +47,12 @@ def __init__( |
51 | 47 | self._poll_interval = poll_interval |
52 | 48 | self._timeout_seconds = timeout_seconds |
53 | 49 | self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url) |
| 50 | + self._session: Optional[aiohttp.ClientSession] = None |
| 51 | + |
| 52 | + def _get_or_create_session(self) -> aiohttp.ClientSession: |
| 53 | + if self._session is None or self._session.closed: |
| 54 | + self._session = aiohttp.ClientSession() |
| 55 | + return self._session |
54 | 56 |
|
55 | 57 | def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]: |
56 | 58 | tasks: List[asyncio.Task[EvaluationRow]] = [] |
@@ -88,48 +90,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: |
88 | 90 | init_payload = build_init_request(row, config, model_base_url) |
89 | 91 |
|
90 | 92 | # Fire-and-poll |
91 | | - def _post_init() -> None: |
92 | | - url = f"{remote_base_url}/init" |
93 | | - try: |
94 | | - r = requests.post(url, json=init_payload.model_dump(), timeout=300) |
95 | | - r.raise_for_status() |
96 | | - except requests.exceptions.Timeout: |
97 | | - raise TimeoutError( |
98 | | - f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds." |
99 | | - ) |
100 | | - |
101 | | - await asyncio.to_thread(_post_init) |
| 93 | + init_url = f"{remote_base_url}/init" |
| 94 | + |
| 95 | + timeout_init = aiohttp.ClientTimeout(total=300) |
| 96 | + |
| 97 | + try: |
| 98 | + session = self._get_or_create_session() |
| 99 | + async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp: |
| 100 | + if resp.status >= 400: |
| 101 | + body = await resp.text() |
| 102 | + raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}") |
| 103 | + resp.raise_for_status() |
| 104 | + await resp.read() # Drain the response body and release the connection back to the pool |
| 105 | + except asyncio.TimeoutError: |
| 106 | + raise TimeoutError( |
| 107 | + f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds." |
| 108 | + ) |
102 | 109 |
|
103 | | - terminated = False |
104 | 110 | deadline = time.time() + timeout_seconds |
105 | 111 |
|
106 | | - def _get_status() -> Dict[str, Any]: |
107 | | - url = f"{remote_base_url}/status" |
108 | | - r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15) |
109 | | - r.raise_for_status() |
110 | | - return r.json() |
111 | | - |
112 | | - continue_polling_status = True |
113 | 112 | while time.time() < deadline: |
114 | | - try: |
115 | | - if continue_polling_status: |
116 | | - status = await asyncio.to_thread(_get_status) |
117 | | - terminated = bool(status.get("terminated", False)) |
118 | | - if terminated: |
119 | | - break |
120 | | - except requests.exceptions.HTTPError as e: |
121 | | - if e.response is not None and e.response.status_code == 404: |
122 | | - # 404 means server doesn't implement /status endpoint, stop polling |
123 | | - logger.debug( |
124 | | - f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}" |
125 | | - ) |
126 | | - continue_polling_status = False |
127 | | - else: |
128 | | - raise |
129 | | - except Exception: |
130 | | - # For all other exceptions, raise them |
131 | | - raise |
132 | | - |
133 | 113 | # Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop) |
134 | 114 | completed_logs = await asyncio.to_thread( |
135 | 115 | self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"] |
@@ -200,5 +180,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: |
200 | 180 | tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] |
201 | 181 | return tasks |
202 | 182 |
|
| 183 | + async def acleanup(self) -> None: |
| 184 | + """Async cleanup - preferred when you can await.""" |
| 185 | + if self._session and not self._session.closed: |
| 186 | + await self._session.close() |
| 187 | + |
203 | 188 | def cleanup(self) -> None: |
204 | | - return None |
| 189 | + """Sync cleanup - best-effort, schedules close if event loop is running.""" |
| 190 | + if self._session and not self._session.closed: |
| 191 | + try: |
| 192 | + loop = asyncio.get_running_loop() |
| 193 | + loop.create_task(self._session.close()) |
| 194 | + except RuntimeError: |
| 195 | + # No running event loop - can't safely close the session. |
| 196 | + # The session will be garbage collected eventually, but warn about it. |
| 197 | + logger.warning( |
| 198 | + "RemoteRolloutProcessor.cleanup() called outside of async context. " |
| 199 | + "Session may not be properly closed. Use `await processor.acleanup()` when possible." |
| 200 | + ) |
0 commit comments