diff --git a/CHANGES.rst b/CHANGES.rst index 6e2a7cc..e03cf1a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,10 @@ CHANGES .. towncrier release notes start +0.8.3 (2026-03-10) +================== +- Fix exception handling in acquire context manager to properly handle asyncio.CancelledError. + 0.8.2 (2024-05-07) ================== - Fix a static typing error with ``Client.get()``. diff --git a/aiomcache/__init__.py b/aiomcache/__init__.py index aee4b68..3bba337 100644 --- a/aiomcache/__init__.py +++ b/aiomcache/__init__.py @@ -14,4 +14,4 @@ __all__ = ("Client", "ClientException", "FlagClient", "ValidationException") -__version__ = "0.8.2" +__version__ = "0.8.3" diff --git a/aiomcache/client.py b/aiomcache/client.py index 78896f4..c47b401 100644 --- a/aiomcache/client.py +++ b/aiomcache/client.py @@ -1,3 +1,4 @@ +import asyncio import functools import re import sys @@ -35,7 +36,7 @@ async def wrapper(self: _Client, *args: _P.args, # type: ignore[misc] conn = await self._pool.acquire() try: return await func(self, conn, *args, **kwargs) - except Exception as exc: + except (Exception, asyncio.CancelledError) as exc: conn[0].set_exception(exc) raise finally: diff --git a/tests/pool_test.py b/tests/pool_test.py index bc57106..6677f8e 100644 --- a/tests/pool_test.py +++ b/tests/pool_test.py @@ -1,6 +1,8 @@ import asyncio import random import socket +from typing import NoReturn +from unittest.mock import create_autospec, call import pytest @@ -149,3 +151,43 @@ async def test_bad_connection(mcache_params: McacheParams) -> None: assert isinstance(conn.writer, asyncio.StreamWriter) pool.release(conn) assert pool.size() == 0 + + +@pytest.mark.parametrize( + "exc_type,should_catch", + ( + (Exception, True), + (asyncio.CancelledError, True), + (BaseException, False), + (KeyboardInterrupt, False), + ), +) +async def test_acquire_catch_exc_from_task( + mcache_params: McacheParams, exc_type: type[BaseException], should_catch: bool +) -> None: + mock_conn = create_autospec(Connection, spec_set=True, instance=True) + + mock_pool = create_autospec(MemcachePool, spec_set=True, instance=True) + mock_pool.acquire.return_value = mock_conn + + exception_message = f"{exc_type.__name__} from acquire" + exception_instance = exc_type(exception_message) + + class TestClient(Client): + def __init__(self, pool_size: int = 4): + self._pool = mock_pool + + @acquire + async def acquire_wait_release(self, conn: Connection) -> NoReturn: + raise exception_instance + + pool_size = 4 + client = TestClient(pool_size=pool_size) + with pytest.raises(exc_type) as exc_info: + await client.acquire_wait_release() + + assert str(exc_info.value) == exception_message + + expected = [call(exception_instance)] if should_catch else [] + assert mock_conn[0].set_exception.call_args_list == expected + assert mock_pool.release.call_args_list == [call(mock_conn)]