diff --git a/src/qasync/__init__.py b/src/qasync/__init__.py index d35a5ee..21f5a30 100644 --- a/src/qasync/__init__.py +++ b/src/qasync/__init__.py @@ -22,6 +22,7 @@ import time from concurrent.futures import Future from queue import Queue +from threading import Lock from typing import TYPE_CHECKING, Literal, Tuple, cast, get_args logger = logging.getLogger(__name__) @@ -172,51 +173,83 @@ def __init__(self, max_workers=10, stack_size=None): self.__workers = [ _QThreadWorker(self.__queue, i + 1, stack_size) for i in range(max_workers) ] + self.__shutdown_lock = Lock() self.__been_shutdown = False for w in self.__workers: w.start() def submit(self, callback, *args, **kwargs): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") + with self.__shutdown_lock: + if self.__been_shutdown: + raise RuntimeError("QThreadExecutor has been shutdown") - future = Future() - self._logger.debug( - "Submitting callback %s with args %s and kwargs %s to thread worker queue", - callback, - args, - kwargs, - ) - self.__queue.put((future, callback, args, kwargs)) - return future + future = Future() + self._logger.debug( + "Submitting callback %s with args %s and kwargs %s to thread worker queue", + callback, + args, + kwargs, + ) + self.__queue.put((future, callback, args, kwargs)) + return future def map(self, func, *iterables, timeout=None): - raise NotImplementedError("use as_completed on the event loop") - - def shutdown(self, wait=True): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") - - self.__been_shutdown = True + deadline = time.monotonic() + timeout if timeout is not None else None + futures = [self.submit(func, *args) for args in zip(*iterables)] - self._logger.debug("Shutting down") - for i in range(len(self.__workers)): - # Signal workers to stop - self.__queue.put(None) - if wait: - for w in self.__workers: - w.wait() + # must have generator as a closure so that the submit occurs before first iteration + def generator(): + try: + futures.reverse() + while futures: + if deadline is not None: + yield _result_or_cancel( + futures.pop(), timeout=deadline - time.monotonic() + ) + else: + yield _result_or_cancel(futures.pop()) + finally: + for future in futures: + future.cancel() + + return generator() + + def shutdown(self, wait=True, *, cancel_futures=False): + with self.__shutdown_lock: + self.__been_shutdown = True + self._logger.debug("Shutting down") + if cancel_futures: + # pop all the futures and cancel them + while not self.__queue.empty(): + item = self.__queue.get_nowait() + if item is not None: + future, _, _, _ = item + future.cancel() + for i in range(len(self.__workers)): + # Signal workers to stop + self.__queue.put(None) + if wait: + for w in self.__workers: + w.wait() def __enter__(self, *args): - if self.__been_shutdown: - raise RuntimeError("QThreadExecutor has been shutdown") return self def __exit__(self, *args): self.shutdown() +def _result_or_cancel(fut, timeout=None): + try: + try: + return fut.result(timeout) + finally: + fut.cancel() + finally: + del fut # break reference cycle in exceptions + + def _format_handle(handle: asyncio.Handle): cb = getattr(handle, "_callback", None) if isinstance(getattr(cb, "__self__", None), asyncio.tasks.Task): diff --git a/tests/test_qthreadexec.py b/tests/test_qthreadexec.py index 67c1833..973bddb 100644 --- a/tests/test_qthreadexec.py +++ b/tests/test_qthreadexec.py @@ -5,6 +5,8 @@ import logging import threading import weakref +from concurrent.futures import Future, TimeoutError +from unittest.mock import Mock, patch import pytest @@ -44,15 +46,28 @@ def shutdown_executor(): return exe -def test_shutdown_after_shutdown(shutdown_executor): - with pytest.raises(RuntimeError): - shutdown_executor.shutdown() +@pytest.fixture +def executor0(): + """ + Provides a QThreadExecutor with max_workers=0 for deterministic testing. + """ + executor = qasync.QThreadExecutor(max_workers=0) + try: + yield executor + finally: + executor.shutdown(wait=True, cancel_futures=False) + + +@pytest.mark.parametrize("wait", [True, False]) +def test_shutdown_after_shutdown(shutdown_executor, wait): + # it is safe to shutdown twice + shutdown_executor.shutdown(wait=wait) def test_ctx_after_shutdown(shutdown_executor): - with pytest.raises(RuntimeError): - with shutdown_executor: - pass + # it is safe to enter and exit the context after shutdown + with shutdown_executor: + pass def test_submit_after_shutdown(shutdown_executor): @@ -104,3 +119,143 @@ def test_no_stale_reference_as_result(executor, disable_executor_logging): assert collected is True, ( "Stale reference to executor result not collected within timeout." ) + + +def test_context(executor): + """Test that the context manager will shutdown executor""" + with executor: + f = executor.submit(lambda: 42) + assert f.result() == 42 + + # it can be entered again + with executor: + # but will fail when we submit + with pytest.raises(RuntimeError): + executor.submit(lambda: 42) + + +@pytest.mark.parametrize("cancel", [True, False]) +def test_shutdown_cancel_futures(executor0, cancel): + """Test that shutdown with cancel_futures=True cancels all remaining futures in the queue.""" + + futures = [executor0.submit(lambda: None) for _ in range(10)] + + # Shutdown with cancel_futures parameter + executor0.shutdown(wait=False, cancel_futures=cancel) + + if cancel: + # All futures should be cancelled since no workers consumed them + cancelled_count = sum(1 for f in futures if f.cancelled()) + assert cancelled_count == 10, ( + f"Expected all 10 futures to be cancelled, got {cancelled_count}" + ) + else: + # No futures should be cancelled, they should still be pending + cancelled_count = sum(1 for f in futures if f.cancelled()) + assert cancelled_count == 0, ( + f"Expected no futures to be cancelled, got {cancelled_count}" + ) + + +def test_map(executor): + """Basic test of executor map functionality""" + results = list(executor.map(lambda x: x + 1, range(10))) + assert results == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + results = list(executor.map(lambda x, y: x + y, range(10), range(9))) + assert results == [0, 2, 4, 6, 8, 10, 12, 14, 16] + + +def test_map_timeout(executor0): + """Test that map with timeout propagates the timeout parameter to future.result()""" + + f = Mock(spec=Future) + f.result = Mock(side_effect=TimeoutError("Timeout")) + f.cancel = Mock(return_value=True) + + with patch.object(executor0, "submit", return_value=f): + with pytest.raises(TimeoutError, match="Timeout"): + list(executor0.map(lambda x: x, [1], timeout=0.5)) + + # Verify the timeout parameter was passed to result() (not None) + # Note: The timeout is calculated as (deadline - time.monotonic()), so it will be + # slightly less than 0.5 due to the time taken to submit futures and start iteration + assert f.result.called + f_timeout = f.result.call_args[0][0] if f.result.call_args[0] else None + assert f_timeout is not None + assert f_timeout <= 0.5 + + +def test_map_error(executor0): + """Test that map with an exception will raise, and remaining tasks are cancelled""" + + # Create 3 futures: one success, one exception, one to be cancelled + mock_futures = [] + + # First future succeeds + f0 = Mock(spec=Future) + f0.result = Mock(return_value=0) + f0.cancel = Mock(return_value=True) + mock_futures.append(f0) + + # Second future raises an exception + f1 = Future() + f1.set_exception(ValueError("Test error")) + mock_futures.append(f1) + + # Third future should be cancelled + f2 = Mock(spec=Future) + f2.result = Mock(return_value=2) + f2.cancel = Mock(return_value=True) + mock_futures.append(f2) + + with patch.object(executor0, "submit", side_effect=mock_futures): + with pytest.raises(ValueError, match="Test error"): + list(executor0.map(lambda x: x, range(3))) + + # Verify the third future was cancelled when the exception occurred + assert f2.cancel.called, "Future after exception should have been cancelled" + + +def test_map_start(executor0): + """Test that map starts tasks immediately, before iterating""" + + # Mock future that returns immediately + mock_future = Mock(spec=Future) + mock_future.result = Mock(return_value=0) + mock_future.cancel = Mock(return_value=True) + + with patch.object(executor0, "submit", return_value=mock_future) as mock_submit: + # Create the map - submit should be called immediately + m = executor0.map(lambda x: x, range(1)) + + # Verify submit was called before we start iterating + mock_submit.assert_called_once() + + # Now iterate to verify the result + assert list(m) == [0] + + +def test_map_close(executor0): + """Test that closing a running map cancels all remaining tasks.""" + + # Create mock futures with proper result() method + mock_futures = [] + for i in range(10): + mock_future = Mock(spec=Future) + mock_future.cancel = Mock(return_value=True) + mock_future.result = Mock(return_value=i) + mock_futures.append(mock_future) + + # Mock submit to return our pre-created futures + with patch.object(executor0, "submit", side_effect=mock_futures): + m = executor0.map(lambda x: x, range(10)) + # must start the generator so that close() has any effect + assert next(m) == 0 + m.close() + + # All futures should have cancel() called: + # - The first one via _result_or_cancel after next() consumed it + # - The rest via the finally block when the generator is closed + for i, f in enumerate(mock_futures): + assert f.cancel.called, f"Future {i} should have been cancelled"