Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def _log_eval_error(status: Status, rows: list[EvaluationRow] | None, passed: bo
finally:
if output_buffer:
await output_buffer.close()
await rollout_processor.acleanup()
rollout_processor.cleanup()

for res in priority_results:
run_idx = (res.execution_metadata.extra or {}).get("run_index", 0)
Expand Down Expand Up @@ -697,15 +699,19 @@ async def _collect_result(config, lst):
# Lazy import (cached after first import above)
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor

if isinstance(rollout_processor, MCPGymRolloutProcessor):
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
for run_idx in range(num_runs):
task = asyncio.create_task(execute_run(run_idx, config))
await task
else:
# For other processors, create all tasks at once and run in parallel
# Concurrency is now controlled by the shared semaphore in each rollout processor
await run_tasks_with_run_progress(execute_run, num_runs, config)
try:
if isinstance(rollout_processor, MCPGymRolloutProcessor):
# For MCPGymRolloutProcessor, create and execute tasks one at a time to avoid port conflicts
for run_idx in range(num_runs):
task = asyncio.create_task(execute_run(run_idx, config))
await task
else:
# For other processors, create all tasks at once and run in parallel
# Concurrency is now controlled by the shared semaphore in each rollout processor
await run_tasks_with_run_progress(execute_run, num_runs, config)
finally:
await rollout_processor.acleanup()
rollout_processor.cleanup()

experiment_duration_seconds = time.perf_counter() - experiment_start_time

Expand Down
8 changes: 6 additions & 2 deletions eval_protocol/pytest/evaluation_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,12 @@ async def execute_row_with_backoff_and_log(
yield result

finally:
await rollout_processor.acleanup()
rollout_processor.cleanup()
# Cleanup is intentionally NOT called here. rollout_processor_with_retry
# is invoked per-run, but the processor (and its session) is shared
# across parallel runs. Closing per-run would kill in-flight requests
# in other runs. Cleanup is called once after all runs complete in
# evaluation_test.py.
pass


def sanitize_filename(text: str) -> str:
Expand Down
4 changes: 4 additions & 0 deletions eval_protocol/pytest/exception_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_default_retryable_exceptions() -> Set[Type[Exception]]:
return _default_retryable_exceptions

# Lazy imports (these are expensive)
import aiohttp
import httpx
import litellm
import requests
Expand All @@ -32,6 +33,9 @@ def get_default_retryable_exceptions() -> Set[Type[Exception]]:
ConnectionError, # type: ignore[assignment]
TimeoutError, # type: ignore[assignment]
OSError, # type: ignore[assignment] # Covers network-related OS errors
# aiohttp library exceptions
aiohttp.ClientConnectionError,
aiohttp.ServerDisconnectedError,
# Requests library exceptions
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
Expand Down
5 changes: 3 additions & 2 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
try:
session = self._get_or_create_session()
async with session.post(init_url, json=init_payload.model_dump(), timeout=timeout_init) as resp:
if resp.status >= 500:
body = await resp.text()
raise ConnectionError(f"Remote /init returned server error (HTTP {resp.status}): {body}")
if resp.status >= 400:
body = await resp.text()
raise RuntimeError(f"Remote /init failed (HTTP {resp.status}): {body}")
Expand Down Expand Up @@ -215,8 +218,6 @@ def cleanup(self) -> None:
loop = asyncio.get_running_loop()
loop.create_task(self._session.close())
except RuntimeError:
# No running event loop - can't safely close the session.
# The session will be garbage collected eventually, but warn about it.
logger.warning(
"RemoteRolloutProcessor.cleanup() called outside of async context. "
"Session may not be properly closed. Use `await processor.acleanup()` when possible."
Expand Down
15 changes: 0 additions & 15 deletions tests/pytest/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ async def mock_task():
assert mock_config.logger.log.call_count == 1
mock_config.logger.log.assert_called_once_with(results[0])

# Verify async cleanup was called (acleanup is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_on_failed_execution(self, mock_rollout_processor, mock_config, sample_dataset):
"""Test that the logger is called when execution fails."""
Expand All @@ -98,9 +95,6 @@ async def failing_task():
assert results[0].rollout_status.code == 13 # INTERNAL error code
assert "Test error" in results[0].rollout_status.message

# Verify async cleanup was called (acleanup is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_on_retry_execution(self, mock_rollout_processor, mock_config, sample_dataset):
"""Test that the logger is called when execution succeeds after retry."""
Expand Down Expand Up @@ -135,9 +129,6 @@ async def flaky_task():
assert mock_config.logger.log.call_count == 1
mock_config.logger.log.assert_called_once_with(results[0])

# Verify async cleanup was called (acleanup is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_for_multiple_rows(self, mock_rollout_processor, mock_config):
"""Test that the logger is called for each row in a multi-row dataset."""
Expand Down Expand Up @@ -183,9 +174,6 @@ async def mock_task():
assert mock_config.logger.log.call_count == 2
assert len(results) == 2

# Verify async cleanup was called (acleanup is preferred over cleanup)
mock_rollout_processor.acleanup.assert_awaited_once()

@pytest.mark.asyncio
async def test_logger_called_even_when_processor_fails_to_initialize(
self, mock_rollout_processor, mock_config, sample_dataset
Expand All @@ -198,6 +186,3 @@ async def test_logger_called_even_when_processor_fails_to_initialize(
with pytest.raises(RuntimeError, match="Processor failed to initialize"):
async for result in rollout_processor_with_retry(mock_rollout_processor, sample_dataset, mock_config):
pass

# Verify async cleanup was called even though the function failed
mock_rollout_processor.acleanup.assert_awaited_once()
Loading