Skip to content

Commit 3c8d8f2

Browse files
authored
retry on connection errors (#436)
* retry on connection errors
* try
* fix
* update
* attempted fix
* update
1 parent 69cb5dc commit 3c8d8f2

File tree

5 files changed

+28
-28
lines changed

5 files changed

+28
-28
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 15 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -449,6 +449,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
449449
finally:
450450
if output_buffer:
451451
await output_buffer.close()
452+
await rollout_processor.acleanup()
453+
rollout_processor.cleanup()
452454

453455
for res in priority_results:
454456
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
@@ -697,15 +699,19 @@ async def _collect_result(config, lst):
697699
# Lazy import (cached after first import above)
698700
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
699701

700-
if isinstance(rollout_processor, MCPGymRolloutProcessor):
701-
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
702-
for run_idx in range(num_runs):
703-
task = asyncio.create_task(execute_run(run_idx, config))
704-
await task
705-
else:
706-
# For other processors, create all tasks at once and run in parallel
707-
# Concurrency is now controlled by the shared semaphore in each rollout processor
708-
await run_tasks_with_run_progress(execute_run, num_runs, config)
702+
try:
703+
if isinstance(rollout_processor, MCPGymRolloutProcessor):
704+
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
705+
for run_idx in range(num_runs):
706+
task = asyncio.create_task(execute_run(run_idx, config))
707+
await task
708+
else:
709+
# For other processors, create all tasks at once and run in parallel
710+
# Concurrency is now controlled by the shared semaphore in each rollout processor
711+
await run_tasks_with_run_progress(execute_run, num_runs, config)
712+
finally:
713+
await rollout_processor.acleanup()
714+
rollout_processor.cleanup()
709715

710716
experiment_duration_seconds = time.perf_counter() - experiment_start_time
711717

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -476,8 +476,12 @@ async def execute_row_with_backoff_and_log(
476476
yield result
477477

478478
finally:
479-
await rollout_processor.acleanup()
480-
rollout_processor.cleanup()
479+
# Cleanup is intentionally NOT called here. rollout_processor_with_retry
480+
# is invoked per-run, but the processor (and its session) is shared
481+
# across parallel runs. Closing per-run would kill in-flight requests
482+
# in other runs. Cleanup is called once after all runs complete in
483+
# evaluation_test.py.
484+
pass
481485

482486

483487
def sanitize_filename(text: str) -> str:

eval_protocol/pytest/exception_config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ def get_default_retryable_exceptions() -> Set[Type[Exception]]:
2323
return _default_retryable_exceptions
2424

2525
# Lazy imports (these are expensive)
26+
import aiohttp
2627
import httpx
2728
import litellm
2829
import requests
@@ -32,6 +33,9 @@ def get_default_retryable_exceptions() -> Set[Type[Exception]]:
3233
ConnectionError, # type: ignore[assignment]
3334
TimeoutError, # type: ignore[assignment]
3435
OSError, # type: ignore[assignment] # Covers network-related OS errors
36+
# aiohttp library exceptions
37+
aiohttp.ClientConnectionError,
38+
aiohttp.ServerDisconnectedError,
3539
# Requests library exceptions
3640
requests.exceptions.ConnectionError,
3741
requests.exceptions.Timeout,

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,9 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
104104
try:
105105
session = self._get_or_create_session()
106106
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
107+
if resp.status >= 500:
108+
body = await resp.text()
109+
raise ConnectionError(f"Remote /init returned server error (HTTP {resp.status}): {body}")
107110
if resp.status >= 400:
108111
body = await resp.text()
109112
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
@@ -215,8 +218,6 @@ def cleanup(self) -> None:
215218
loop = asyncio.get_running_loop()
216219
loop.create_task(self._session.close())
217220
except RuntimeError:
218-
# No running event loop - can't safely close the session.
219-
# The session will be garbage collected eventually, but warn about it.
220221
logger.warning(
221222
"RemoteRolloutProcessor.cleanup() called outside of async context. "
222223
"Session may not be properly closed. Use `await processor.acleanup()` when possible."

tests/pytest/test_utils.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -72,9 +72,6 @@ async def mock_task():
7272
assert mock_config.logger.log.call_count == 1
7373
mock_config.logger.log.assert_called_once_with(results[0])
7474

75-
# Verify async cleanup was called (aclose is preferred over cleanup)
76-
mock_rollout_processor.acleanup.assert_awaited_once()
77-
7875
@pytest.mark.asyncio
7976
async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset):
8077
"""Test that the logger is called when execution fails."""
@@ -98,9 +95,6 @@ async def failing_task():
9895
assert results[0].rollout_status.code == 13 # INTERNAL error code
9996
assert "Test error" in results[0].rollout_status.message
10097

101-
# Verify async cleanup was called (aclose is preferred over cleanup)
102-
mock_rollout_processor.acleanup.assert_awaited_once()
103-
10498
@pytest.mark.asyncio
10599
async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset):
106100
"""Test that the logger is called when execution succeeds after retry."""
@@ -135,9 +129,6 @@ async def flaky_task():
135129
assert mock_config.logger.log.call_count == 1
136130
mock_config.logger.log.assert_called_once_with(results[0])
137131

138-
# Verify async cleanup was called (aclose is preferred over cleanup)
139-
mock_rollout_processor.acleanup.assert_awaited_once()
140-
141132
@pytest.mark.asyncio
142133
async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config):
143134
"""Test that the logger is called for each row in a multi-row dataset."""
@@ -183,9 +174,6 @@ async def mock_task():
183174
assert mock_config.logger.log.call_count == 2
184175
assert len(results) == 2
185176

186-
# Verify async cleanup was called (aclose is preferred over cleanup)
187-
mock_rollout_processor.acleanup.assert_awaited_once()
188-
189177
@pytest.mark.asyncio
190178
async def test_logger_called_even_when_processor_fails_to_initialize(
191179
self, mock_rollout_processor, mock_config, sample_dataset
@@ -198,6 +186,3 @@ async def test_logger_called_even_when_processor_fails_to_initialize(
198186
with pytest.raises(RuntimeError, match="Processor failed to initialize"):
199187
async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
200188
pass
201-
202-
# Verify async cleanup was called even though the function failed
203-
mock_rollout_processor.acleanup.assert_awaited_once()

0 commit comments

Comments
 (0)