Skip to content

Commit cf4fc4e

Browse files
authored
use aiohttp in remote rollout processor to fix the open files issue (#401)
* use aiohttp in remote rollout processor to fix the open files issue
* update
* remove comments
* fix test
* fixed
* update
* rename
* update
1 parent 5e31311 commit cf4fc4e

File tree

5 files changed

+61
-58
lines changed

5 files changed

+61
-58
lines changed

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
371371
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
372372
retry_tasks = rollout_processor([row], retry_config)
373373
result = await retry_tasks[0]
374-
374+
375375
# Apply post-processing quality checks if configured
376376
# This must be inside the retry function so ResponseQualityError can trigger retries
377377
if config.post_processor is not None:
@@ -380,7 +380,7 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
380380
except ResponseQualityError as quality_error:
381381
# Re-raise ResponseQualityError to trigger retry logic
382382
raise quality_error
383-
383+
384384
return result
385385

386386
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
@@ -464,6 +464,7 @@ async def execute_row_with_backoff_and_log(
464464
yield result
465465

466466
finally:
467+
await rollout_processor.acleanup()
467468
rollout_processor.cleanup()
468469

469470

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import asyncio
22
import time
3-
from typing import Any, Dict, List, Optional
3+
from typing import List, Optional
44

5-
import requests
5+
import aiohttp
66

77
from eval_protocol.models import EvaluationRow, Status
8-
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9-
from eval_protocol.types.remote_rollout_processor import (
10-
DataLoaderConfig,
11-
)
128
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
139
from eval_protocol.exceptions import exception_for_status_code
1410

@@ -51,6 +47,12 @@ def __init__(
5147
self._poll_interval = poll_interval
5248
self._timeout_seconds = timeout_seconds
5349
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)
50+
self._session: Optional[aiohttp.ClientSession] = None
51+
52+
def _get_or_create_session(self) -> aiohttp.ClientSession:
53+
if self._session is None or self._session.closed:
54+
self._session = aiohttp.ClientSession()
55+
return self._session
5456

5557
def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
5658
tasks: List[asyncio.Task[EvaluationRow]] = []
@@ -88,48 +90,26 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
8890
init_payload = build_init_request(row, config, model_base_url)
8991

9092
# Fire-and-poll
91-
def _post_init() -> None:
92-
url = f"{remote_base_url}/init"
93-
try:
94-
r = requests.post(url, json=init_payload.model_dump(), timeout=300)
95-
r.raise_for_status()
96-
except requests.exceptions.Timeout:
97-
raise TimeoutError(
98-
f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds."
99-
)
100-
101-
await asyncio.to_thread(_post_init)
93+
init_url = f"{remote_base_url}/init"
94+
95+
timeout_init = aiohttp.ClientTimeout(total=300)
96+
97+
try:
98+
session = self._get_or_create_session()
99+
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
100+
if resp.status >= 400:
101+
body = await resp.text()
102+
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
103+
resp.raise_for_status()
104+
await resp.read() # Drain the response body and release the connection back to the pool
105+
except asyncio.TimeoutError:
106+
raise TimeoutError(
107+
f"The /init endpoint tried {init_url} with {init_payload.model_dump()} but timed out after 300 seconds."
108+
)
102109

103-
terminated = False
104110
deadline = time.time() + timeout_seconds
105111

106-
def _get_status() -> Dict[str, Any]:
107-
url = f"{remote_base_url}/status"
108-
r = requests.get(url, params={"rollout_id": row.execution_metadata.rollout_id}, timeout=15)
109-
r.raise_for_status()
110-
return r.json()
111-
112-
continue_polling_status = True
113112
while time.time() < deadline:
114-
try:
115-
if continue_polling_status:
116-
status = await asyncio.to_thread(_get_status)
117-
terminated = bool(status.get("terminated", False))
118-
if terminated:
119-
break
120-
except requests.exceptions.HTTPError as e:
121-
if e.response is not None and e.response.status_code == 404:
122-
# 404 means server doesn't implement /status endpoint, stop polling
123-
logger.debug(
124-
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
125-
)
126-
continue_polling_status = False
127-
else:
128-
raise
129-
except Exception:
130-
# For all other exceptions, raise them
131-
raise
132-
133113
# Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop)
134114
completed_logs = await asyncio.to_thread(
135115
self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
@@ -200,5 +180,21 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
200180
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
201181
return tasks
202182

183+
async def acleanup(self) -> None:
184+
"""Async cleanup - preferred when you can await."""
185+
if self._session and not self._session.closed:
186+
await self._session.close()
187+
203188
def cleanup(self) -> None:
204-
return None
189+
"""Sync cleanup - best-effort, schedules close if event loop is running."""
190+
if self._session and not self._session.closed:
191+
try:
192+
loop = asyncio.get_running_loop()
193+
loop.create_task(self._session.close())
194+
except RuntimeError:
195+
# No running event loop - can't safely close the session.
196+
# The session will be garbage collected eventually, but warn about it.
197+
logger.warning(
198+
"RemoteRolloutProcessor.cleanup() called outside of async context. "
199+
"Session may not be properly closed. Use `await processor.acleanup()` when possible."
200+
)

eval_protocol/pytest/rollout_processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def __call__(self, rows: list[EvaluationRow], config: RolloutProcessorConfig) ->
1919
"""Process evaluation rows and return async tasks. Must be implemented by subclasses."""
2020
pass
2121

22+
async def acleanup(self) -> None:
23+
"""Async cleanup - preferred when you can await."""
24+
pass
25+
2226
def cleanup(self) -> None:
2327
"""Cleanup resources. Override in subclasses if cleanup is needed."""
2428
pass

eval_protocol/training/gepa_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ async def evaluate_with_ep(
503503
}
504504

505505
finally:
506+
await rollout_processor.acleanup()
506507
rollout_processor.cleanup()
507508

508509
def run_ep_evaluation(

tests/pytest/test_utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from unittest.mock import MagicMock
2+
from unittest.mock import AsyncMock, MagicMock
33
import pytest
44

55
from eval_protocol.pytest.evaluation_test_utils import rollout_processor_with_retry
@@ -16,6 +16,7 @@ def mock_rollout_processor(self):
1616
"""Create a mock rollout processor that returns async tasks."""
1717
processor = MagicMock()
1818
processor.cleanup = MagicMock()
19+
processor.acleanup = AsyncMock() # async cleanup method
1920
return processor
2021

2122
@pytest.fixture
@@ -71,8 +72,8 @@ async def mock_task():
7172
assert mock_config.logger.log.call_count == 1
7273
mock_config.logger.log.assert_called_once_with(results[0])
7374

74-
# Verify cleanup was called
75-
mock_rollout_processor.cleanup.assert_called_once()
75+
# Verify async cleanup was called (acleanup is preferred over cleanup)
76+
mock_rollout_processor.acleanup.assert_awaited_once()
7677

7778
@pytest.mark.asyncio
7879
async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset):
@@ -97,8 +98,8 @@ async def failing_task():
9798
assert results[0].rollout_status.code == 13 # INTERNAL error code
9899
assert "Test error" in results[0].rollout_status.message
99100

100-
# Verify cleanup was called
101-
mock_rollout_processor.cleanup.assert_called_once()
101+
# Verify async cleanup was called (acleanup is preferred over cleanup)
102+
mock_rollout_processor.acleanup.assert_awaited_once()
102103

103104
@pytest.mark.asyncio
104105
async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset):
@@ -134,8 +135,8 @@ async def flaky_task():
134135
assert mock_config.logger.log.call_count == 1
135136
mock_config.logger.log.assert_called_once_with(results[0])
136137

137-
# Verify cleanup was called
138-
mock_rollout_processor.cleanup.assert_called_once()
138+
# Verify async cleanup was called (acleanup is preferred over cleanup)
139+
mock_rollout_processor.acleanup.assert_awaited_once()
139140

140141
@pytest.mark.asyncio
141142
async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config):
@@ -182,8 +183,8 @@ async def mock_task():
182183
assert mock_config.logger.log.call_count == 2
183184
assert len(results) == 2
184185

185-
# Verify cleanup was called
186-
mock_rollout_processor.cleanup.assert_called_once()
186+
# Verify async cleanup was called (acleanup is preferred over cleanup)
187+
mock_rollout_processor.acleanup.assert_awaited_once()
187188

188189
@pytest.mark.asyncio
189190
async def test_logger_called_even_when_processor_fails_to_initialize(
@@ -198,5 +199,5 @@ async def test_logger_called_even_when_processor_fails_to_initialize(
198199
async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
199200
pass
200201

201-
# Verify cleanup was called even though the function failed
202-
mock_rollout_processor.cleanup.assert_called_once()
202+
# Verify async cleanup was called even though the function failed
203+
mock_rollout_processor.acleanup.assert_awaited_once()

0 commit comments

Comments
 (0)