Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions faststream_concurrent_aiokafka/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import logging
import signal
import time
import typing

from faststream.kafka import ConsumerRecord, TopicPartition
Expand Down Expand Up @@ -33,24 +32,23 @@ def __init__(
raise ValueError(msg)

self._limiter = asyncio.Semaphore(concurrency_limit)
self._current_tasks: set[asyncio.Task[typing.Any]] = set()
# Counter + Event replace the old _current_tasks set: shutdown waits on the event,
# which is set once every tracked task has fired its done-callback.
self._tracked_count: int = 0
self._all_done_event: asyncio.Event = asyncio.Event()
self._all_done_event.set() # 0 tasks ⇒ "all done" is True
self._is_running: bool = False
self._committer: KafkaBatchCommitter = committer
self._stop_task: asyncio.Task[typing.Any] | None = None
self._shutdown_timeout_sec: float = shutdown_timeout_sec

async def wait_for_subtasks(self) -> None:
logger.info("Kafka middleware. Gracefully waiting for tasks to end...")
deadline = time.monotonic() + self._shutdown_timeout_sec
try:
pending = [t for t in self._current_tasks if not t.done()]
while pending:
remaining = max(deadline - time.monotonic(), 0)
await asyncio.wait_for(
asyncio.gather(*pending, return_exceptions=True),
timeout=remaining,
)
pending = [t for t in self._current_tasks if not t.done()]
await asyncio.wait_for(
self._all_done_event.wait(),
timeout=self._shutdown_timeout_sec,
)
except TimeoutError:
logger.exception("Kafka middleware. Whoops, some tasks haven't finished in graceful time, sorry")

Expand All @@ -60,7 +58,9 @@ def _finish_task(self, task: asyncio.Task[typing.Any]) -> None:
exc: typing.Final[BaseException | None] = task.exception()
if exc:
logger.error("Kafka middleware. Task has failed with the exception", exc_info=exc)
self._current_tasks.discard(task)
self._tracked_count -= 1
if self._tracked_count == 0:
self._all_done_event.set()

async def handle_task(
self,
Expand All @@ -70,7 +70,12 @@ async def handle_task(
) -> None:
await self._limiter.acquire()
task: typing.Final = asyncio.ensure_future(coroutine)
self._current_tasks.add(task)
# Increment + clear before add_done_callback. add_done_callback fires synchronously
# if the task is already done; that path then immediately decrements back to a
# consistent state. Reverse order would skew the count for a synchronously-finished
# task.
self._tracked_count += 1
self._all_done_event.clear()
task.add_done_callback(self._finish_task)
try:
await self._committer.send_task(
Expand Down
65 changes: 41 additions & 24 deletions tests/test_concurrent_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def sample_record() -> MockConsumerRecord:
return MockConsumerRecord()


def _track_external(handler: KafkaConcurrentHandler, task: asyncio.Task[typing.Any]) -> None:
"""Register an externally-created task with the handler's count+event tracking.

Mirrors the bookkeeping that handle_task does so wait_for_subtasks waits for it.
"""
handler._tracked_count += 1
handler._all_done_event.clear()
task.add_done_callback(handler._finish_task)


def test_concurrent_init_zero_concurrency_limit_raises() -> None:
with pytest.raises(ValueError, match="concurrency_limit must be >= 1"):
KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=0) # ty: ignore[invalid-argument-type]
Expand Down Expand Up @@ -82,13 +92,17 @@ async def test_concurrent_failed_task_exception(
assert "Task has failed with the exception" in caplog.text


async def test_concurrent_removes_task_from_set(handler: KafkaConcurrentHandler) -> None:
async def test_concurrent_finish_task_decrements_and_sets_done_event(handler: KafkaConcurrentHandler) -> None:
mock_task: typing.Final = MagicMock()
mock_task.cancelled.return_value = False
mock_task.exception.return_value = None
handler._current_tasks.add(mock_task)
handler._tracked_count = 1
handler._all_done_event.clear()

handler._finish_task(mock_task)
assert mock_task not in handler._current_tasks

assert handler._tracked_count == 0
assert handler._all_done_event.is_set()


async def test_concurrent_creates_task(
Expand All @@ -98,19 +112,20 @@ async def coro() -> str:
return "result"

await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type]
assert len(handler._current_tasks) == 1
assert handler._tracked_count == 1
assert not handler._all_done_event.is_set()


async def test_concurrent_task_added_to_set(
async def test_concurrent_task_passed_to_committer(
handler: KafkaConcurrentHandler, sample_message: MockKafkaMessage, sample_record: MockConsumerRecord
) -> None:
async def coro() -> str:
return "result"

await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type]

task: typing.Final = next(iter(handler._current_tasks))
assert isinstance(task, asyncio.Task)
sent_commit_task: typing.Final = handler._committer.send_task.call_args[0][0] # ty: ignore[unresolved-attribute]
assert isinstance(sent_commit_task.asyncio_task, asyncio.Task)


async def test_concurrent_done_callback_added(
Expand All @@ -121,8 +136,8 @@ async def coro() -> str:

await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type]

task: typing.Final = next(iter(handler._current_tasks))
assert len(task._callbacks) > 0
sent_commit_task: typing.Final = handler._committer.send_task.call_args[0][0] # ty: ignore[unresolved-attribute]
assert len(sent_commit_task.asyncio_task._callbacks) > 0


async def test_concurrent_acquires_limiter_when_limited(
Expand Down Expand Up @@ -263,10 +278,11 @@ async def task2() -> str:
results.append(2)
return "task2"

handler._current_tasks.add(asyncio.create_task(task1()))
handler._current_tasks.add(asyncio.create_task(task2()))
_track_external(handler, asyncio.create_task(task1()))
_track_external(handler, asyncio.create_task(task2()))
await handler.wait_for_subtasks()
assert len(results) == expected_tasks_len
assert handler._tracked_count == 0


async def test_concurrent_handles_task_exceptions(
Expand All @@ -278,9 +294,10 @@ async def failing_task() -> typing.Never:
msg: typing.Final = "Task failed"
raise ValueError(msg)

handler._current_tasks.add(asyncio.create_task(failing_task()))
failing: typing.Final = asyncio.create_task(failing_task())
_track_external(handler, failing)
await handler.wait_for_subtasks()
assert handler._current_tasks.pop().done()
assert failing.done()


async def test_concurrent_wait_for_subtasks_drains_tasks_added_during_wait(
Expand All @@ -300,21 +317,17 @@ async def late() -> None:

async def inject_during_wait() -> None:
await asyncio.sleep(0.01)
late_task: typing.Final = asyncio.create_task(late())
handler._current_tasks.add(late_task)
late_task.add_done_callback(handler._finish_task)
_track_external(handler, asyncio.create_task(late()))

initial_task: typing.Final = asyncio.create_task(initial())
handler._current_tasks.add(initial_task)
initial_task.add_done_callback(handler._finish_task)
_track_external(handler, asyncio.create_task(initial()))

injector: typing.Final = asyncio.create_task(inject_during_wait())
await handler.wait_for_subtasks()
await injector

assert initial_done.is_set()
assert late_done.is_set()
assert len(handler._current_tasks) == 0
assert handler._tracked_count == 0


async def test_concurrent_logs_timeout(caplog: pytest.LogCaptureFixture) -> None:
Expand All @@ -327,9 +340,13 @@ async def test_concurrent_logs_timeout(caplog: pytest.LogCaptureFixture) -> None
async def slow_task() -> None:
await asyncio.sleep(100)

handler._current_tasks.add(asyncio.create_task(slow_task()))
slow: typing.Final = asyncio.create_task(slow_task())
_track_external(handler, slow)
await handler.wait_for_subtasks()
assert "haven't finished in graceful time" in caplog.text
slow.cancel()
with contextlib.suppress(asyncio.CancelledError):
await slow


async def test_handler_uses_shutdown_timeout_kwarg() -> None:
Expand All @@ -344,12 +361,12 @@ async def test_concurrent_finish_task_does_not_crash_on_cancelled_task(
handler_with_limit: KafkaConcurrentHandler,
) -> None:
task: typing.Final = asyncio.create_task(asyncio.sleep(10))
handler_with_limit._current_tasks.add(task)
task.add_done_callback(handler_with_limit._finish_task)
_track_external(handler_with_limit, task)
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
assert task not in handler_with_limit._current_tasks
assert handler_with_limit._tracked_count == 0
assert handler_with_limit._all_done_event.is_set()


async def test_concurrent_full_lifecycle() -> None:
Expand Down
Loading