diff --git a/faststream_concurrent_aiokafka/processing.py b/faststream_concurrent_aiokafka/processing.py index 571f0a5..3bd6f52 100644 --- a/faststream_concurrent_aiokafka/processing.py +++ b/faststream_concurrent_aiokafka/processing.py @@ -2,7 +2,6 @@ import functools import logging import signal -import time import typing from faststream.kafka import ConsumerRecord, TopicPartition @@ -33,7 +32,11 @@ def __init__( raise ValueError(msg) self._limiter = asyncio.Semaphore(concurrency_limit) - self._current_tasks: set[asyncio.Task[typing.Any]] = set() + # Counter + Event replace the old _current_tasks set: shutdown waits on the event, + # which is set once every tracked task has fired its done-callback. + self._tracked_count: int = 0 + self._all_done_event: asyncio.Event = asyncio.Event() + self._all_done_event.set() # 0 tasks ⇒ "all done" is True self._is_running: bool = False self._committer: KafkaBatchCommitter = committer self._stop_task: asyncio.Task[typing.Any] | None = None @@ -41,16 +44,11 @@ def __init__( async def wait_for_subtasks(self) -> None: logger.info("Kafka middleware. Gracefully waiting for tasks to end...") - deadline = time.monotonic() + self._shutdown_timeout_sec try: - pending = [t for t in self._current_tasks if not t.done()] - while pending: - remaining = max(deadline - time.monotonic(), 0) - await asyncio.wait_for( - asyncio.gather(*pending, return_exceptions=True), - timeout=remaining, - ) - pending = [t for t in self._current_tasks if not t.done()] + await asyncio.wait_for( + self._all_done_event.wait(), + timeout=self._shutdown_timeout_sec, + ) except TimeoutError: logger.exception("Kafka middleware. Whoops, some tasks haven't finished in graceful time, sorry") @@ -60,7 +58,9 @@ def _finish_task(self, task: asyncio.Task[typing.Any]) -> None: exc: typing.Final[BaseException | None] = task.exception() if exc: logger.error("Kafka middleware. Task has failed with the exception", exc_info=exc) - self._current_tasks.discard(task) + self._tracked_count -= 1 + if self._tracked_count == 0: + self._all_done_event.set() async def handle_task( self, @@ -70,7 +70,12 @@ async def handle_task( ) -> None: await self._limiter.acquire() task: typing.Final = asyncio.ensure_future(coroutine) - self._current_tasks.add(task) + # Increment + clear before add_done_callback. add_done_callback fires synchronously + # if the task is already done; that path then immediately decrements back to a + # consistent state. Reverse order would skew the count for a synchronously-finished + # task. + self._tracked_count += 1 + self._all_done_event.clear() task.add_done_callback(self._finish_task) try: await self._committer.send_task( diff --git a/tests/test_concurrent_processing.py b/tests/test_concurrent_processing.py index a3ef101..016585b 100644 --- a/tests/test_concurrent_processing.py +++ b/tests/test_concurrent_processing.py @@ -47,6 +47,16 @@ def sample_record() -> MockConsumerRecord: return MockConsumerRecord() +def _track_external(handler: KafkaConcurrentHandler, task: asyncio.Task[typing.Any]) -> None: + """Register an externally-created task with the handler's count+event tracking. + + Mirrors the bookkeeping that handle_task does so wait_for_subtasks waits for it. + """ + handler._tracked_count += 1 + handler._all_done_event.clear() + task.add_done_callback(handler._finish_task) + + def test_concurrent_init_zero_concurrency_limit_raises() -> None: with pytest.raises(ValueError, match="concurrency_limit must be >= 1"): KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=0) # ty: ignore[invalid-argument-type] @@ -82,13 +92,17 @@ async def test_concurrent_failed_task_exception( assert "Task has failed with the exception" in caplog.text -async def test_concurrent_removes_task_from_set(handler: KafkaConcurrentHandler) -> None: +async def test_concurrent_finish_task_decrements_and_sets_done_event(handler: KafkaConcurrentHandler) -> None: mock_task: typing.Final = MagicMock() mock_task.cancelled.return_value = False mock_task.exception.return_value = None - handler._current_tasks.add(mock_task) + handler._tracked_count = 1 + handler._all_done_event.clear() + handler._finish_task(mock_task) - assert mock_task not in handler._current_tasks + + assert handler._tracked_count == 0 + assert handler._all_done_event.is_set() async def test_concurrent_creates_task( @@ -98,10 +112,11 @@ async def coro() -> str: return "result" await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type] - assert len(handler._current_tasks) == 1 + assert handler._tracked_count == 1 + assert not handler._all_done_event.is_set() -async def test_concurrent_task_added_to_set( +async def test_concurrent_task_passed_to_committer( handler: KafkaConcurrentHandler, sample_message: MockKafkaMessage, sample_record: MockConsumerRecord ) -> None: async def coro() -> str: @@ -109,8 +124,8 @@ async def coro() -> str: await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type] - task: typing.Final = next(iter(handler._current_tasks)) - assert isinstance(task, asyncio.Task) + sent_commit_task: typing.Final = handler._committer.send_task.call_args[0][0] # ty: ignore[unresolved-attribute] + assert isinstance(sent_commit_task.asyncio_task, asyncio.Task) async def test_concurrent_done_callback_added( @@ -121,8 +136,8 @@ async def coro() -> str: await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type] - task: typing.Final = next(iter(handler._current_tasks)) - assert len(task._callbacks) > 0 + sent_commit_task: typing.Final = handler._committer.send_task.call_args[0][0] # ty: ignore[unresolved-attribute] + assert len(sent_commit_task.asyncio_task._callbacks) > 0 async def test_concurrent_acquires_limiter_when_limited( @@ -263,10 +278,11 @@ async def task2() -> str: results.append(2) return "task2" - handler._current_tasks.add(asyncio.create_task(task1())) - handler._current_tasks.add(asyncio.create_task(task2())) + _track_external(handler, asyncio.create_task(task1())) + _track_external(handler, asyncio.create_task(task2())) await handler.wait_for_subtasks() assert len(results) == expected_tasks_len + assert handler._tracked_count == 0 async def test_concurrent_handles_task_exceptions( @@ -278,9 +294,10 @@ async def failing_task() -> typing.Never: msg: typing.Final = "Task failed" raise ValueError(msg) - handler._current_tasks.add(asyncio.create_task(failing_task())) + failing: typing.Final = asyncio.create_task(failing_task()) + _track_external(handler, failing) await handler.wait_for_subtasks() - assert handler._current_tasks.pop().done() + assert failing.done() async def test_concurrent_wait_for_subtasks_drains_tasks_added_during_wait( @@ -300,13 +317,9 @@ async def late() -> None: async def inject_during_wait() -> None: await asyncio.sleep(0.01) - late_task: typing.Final = asyncio.create_task(late()) - handler._current_tasks.add(late_task) - late_task.add_done_callback(handler._finish_task) + _track_external(handler, asyncio.create_task(late())) - initial_task: typing.Final = asyncio.create_task(initial()) - handler._current_tasks.add(initial_task) - initial_task.add_done_callback(handler._finish_task) + _track_external(handler, asyncio.create_task(initial())) injector: typing.Final = asyncio.create_task(inject_during_wait()) await handler.wait_for_subtasks() @@ -314,7 +327,7 @@ async def inject_during_wait() -> None: assert initial_done.is_set() assert late_done.is_set() - assert len(handler._current_tasks) == 0 + assert handler._tracked_count == 0 async def test_concurrent_logs_timeout(caplog: pytest.LogCaptureFixture) -> None: @@ -327,9 +340,13 @@ async def test_concurrent_logs_timeout(caplog: pytest.LogCaptureFixture) -> None async def slow_task() -> None: await asyncio.sleep(100) - handler._current_tasks.add(asyncio.create_task(slow_task())) + slow: typing.Final = asyncio.create_task(slow_task()) + _track_external(handler, slow) await handler.wait_for_subtasks() assert "haven't finished in graceful time" in caplog.text + slow.cancel() + with contextlib.suppress(asyncio.CancelledError): + await slow async def test_handler_uses_shutdown_timeout_kwarg() -> None: @@ -344,12 +361,12 @@ async def test_concurrent_finish_task_does_not_crash_on_cancelled_task( handler_with_limit: KafkaConcurrentHandler, ) -> None: task: typing.Final = asyncio.create_task(asyncio.sleep(10)) - handler_with_limit._current_tasks.add(task) - task.add_done_callback(handler_with_limit._finish_task) + _track_external(handler_with_limit, task) task.cancel() with contextlib.suppress(asyncio.CancelledError): await task - assert task not in handler_with_limit._current_tasks + assert handler_with_limit._tracked_count == 0 + assert handler_with_limit._all_done_event.is_set() async def test_concurrent_full_lifecycle() -> None: