Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 54 additions & 38 deletions faststream_concurrent_aiokafka/batch_committer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import bisect
import contextlib
import dataclasses
import logging
import operator
import typing

from aiokafka.errors import CommitFailedError, IllegalStateError, KafkaError
Expand All @@ -16,6 +18,7 @@


DEFAULT_SHUTDOWN_TIMEOUT_SEC: typing.Final = 20.0
_OFFSET_KEY: typing.Final = operator.attrgetter("offset")


class CommitterIsDeadError(Exception): ...
Expand All @@ -33,6 +36,7 @@ class KafkaCommitTask:
class _StreamingState:
queue_get_task: asyncio.Task[KafkaCommitTask]
flush_wait_task: asyncio.Task[bool]
task_completed_wait_task: asyncio.Task[bool]
timeout_task: asyncio.Task[None] | None = None
pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict)
should_shutdown: bool = False
Expand All @@ -41,13 +45,23 @@ class _StreamingState:
flush_in_progress: bool = False

def cancel_outstanding(self) -> None:
for task in (self.queue_get_task, self.flush_wait_task):
for task in (self.queue_get_task, self.flush_wait_task, self.task_completed_wait_task):
if not task.done():
task.cancel()
if self.timeout_task is not None and not self.timeout_task.done():
self.timeout_task.cancel()


def _insert_sorted(partition_pending: list[KafkaCommitTask], new_ct: KafkaCommitTask) -> None:
    """Insert *new_ct* into *partition_pending*, keeping the list sorted by offset.

    The broker delivers tasks in offset order in the common case, so a plain
    append both preserves ordering and costs O(1). Out-of-order arrivals occur
    only when a batch is re-queued after a transient KafkaError; that rare path
    falls back to a bisect insertion.
    """
    arrived_in_order = not partition_pending or partition_pending[-1].offset <= new_ct.offset
    if arrived_in_order:
        partition_pending.append(new_ct)
        return
    bisect.insort(partition_pending, new_ct, key=_OFFSET_KEY)


class KafkaBatchCommitter:
def __init__(
self,
Expand All @@ -58,12 +72,25 @@ def __init__(
self._messages_queue: asyncio.Queue[KafkaCommitTask] = asyncio.Queue()
self._commit_task: asyncio.Task[typing.Any] | None = None
self._flush_batch_event = asyncio.Event()
# Set from each user task's done-callback (registered in handle_task). Wakes the
# streaming loop without us having to add per-task callbacks via asyncio.wait every
# iteration. Fan-in cost is O(1) regardless of partition count or pending depth.
self._task_completed_event = asyncio.Event()
self._stop_requested: bool = False

self._commit_batch_timeout_sec = commit_batch_timeout_sec
self._commit_batch_size = commit_batch_size
self._shutdown_timeout = shutdown_timeout_sec

def _on_user_task_done(self, _task: "asyncio.Future[typing.Any]") -> None:
    """Wake the streaming loop whenever any user task finishes."""
    # The completed future itself is ignored: the streaming loop re-scans its
    # own pending state, so a single shared event is enough fan-in.
    completed_event = self._task_completed_event
    completed_event.set()

def _track_user_task(self, ct: "KafkaCommitTask") -> None:
    """Register the completion wake-up callback on *ct*'s user task."""
    # add_done_callback invokes the callback immediately when the future is
    # already done, so a task that finished between create_task and absorb
    # still triggers the wake-up — no completion can be missed.
    user_task = ct.asyncio_task
    user_task.add_done_callback(self._on_user_task_done)

def _check_is_commit_task_running(self) -> None:
if not self._commit_task or self._commit_task.done():
msg: typing.Final = "Committer main task is not running"
Expand Down Expand Up @@ -99,7 +126,7 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
partitions_to_offsets: dict[TopicPartition, int] = {}
for partition, tasks in by_partition.items():
max_offset: int | None = None
for task in sorted(tasks, key=lambda x: x.offset):
for task in sorted(tasks, key=_OFFSET_KEY):
if task.asyncio_task.cancelled():
break # stop committing at first cancelled task — message was not processed
max_offset = task.offset
Expand All @@ -112,19 +139,16 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
def _extract_ready_prefixes(
pending: dict[TopicPartition, list[KafkaCommitTask]],
) -> dict[TopicPartition, list[KafkaCommitTask]]:
# Per partition (sorted by offset), find the first task that is not-done. Tasks before
# it form the contiguous-done prefix and become "ready". A cancelled task is treated
# as a hard boundary: cancelled + everything after is dropped from pending and added
# to ready (so task_done() balances messages_queue.join), while
# _map_offsets_per_partition stops the offset advance at the cancelled task so the
# uncommitted offsets get redelivered on restart (at-least-once).
# Pending lists are maintained in offset order by _insert_sorted. Per partition, find
# the first not-done task; tasks before it form the contiguous-done prefix and become
# "ready". A cancelled task is treated as a hard boundary: cancelled + everything after
# is dropped from pending and added to ready (so task_done() balances
# messages_queue.join), while _map_offsets_per_partition stops the offset advance at
# the cancelled task so the uncommitted offsets get redelivered on restart
# (at-least-once).
ready: dict[TopicPartition, list[KafkaCommitTask]] = {}
empty_partitions: list[TopicPartition] = []
for partition, partition_pending in pending.items():
# Re-queued-on-transient-error tasks land at the queue tail and may arrive
# out of offset order with respect to newer same-partition tasks. Sort here.
partition_pending.sort(key=lambda t: t.offset)

prefix_end = 0
for index, task in enumerate(partition_pending):
if task.asyncio_task.cancelled():
Expand All @@ -146,13 +170,10 @@ def _extract_ready_prefixes(
return ready

async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitTask]]) -> bool:
# Task exception logging is handled by the handler's _finish_task done-callback so
# it fires once per task at completion time. We intentionally do NOT log here:
# transient KafkaError re-queues a task, and a per-commit log would emit duplicates.
flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks]
for task in flat:
if task.asyncio_task.cancelled():
continue
exc = task.asyncio_task.exception()
if exc is not None:
logger.error("Task has finished with an exception", exc_info=exc)

# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions
consumers_tasks: dict[int, list[KafkaCommitTask]] = {}
Expand All @@ -169,22 +190,6 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT
self._messages_queue.task_done()
return all_succeeded

@staticmethod
def _pending_head_tasks(
pending: dict[TopicPartition, list[KafkaCommitTask]],
) -> list[asyncio.Task[typing.Any]]:
# Watch only the first not-done task per partition. A cancelled head is treated as
# done by _extract_ready_prefixes, so it is intentionally not watched (would busy-loop).
heads: list[asyncio.Task[typing.Any]] = []
for partition_pending in pending.values():
for ct in partition_pending:
if ct.asyncio_task.cancelled():
break
if not ct.asyncio_task.done():
heads.append(ct.asyncio_task)
break
return heads

def _reset_timeout(
self,
timeout_task: asyncio.Task[None] | None,
Expand All @@ -204,6 +209,7 @@ async def _run_commit_process(self) -> None:
state: typing.Final = _StreamingState(
queue_get_task=asyncio.create_task(self._messages_queue.get()),
flush_wait_task=asyncio.create_task(self._flush_batch_event.wait()),
task_completed_wait_task=asyncio.create_task(self._task_completed_event.wait()),
)

try:
Expand All @@ -213,22 +219,31 @@ async def _run_commit_process(self) -> None:
state.cancel_outstanding()

async def _streaming_iteration(self, state: "_StreamingState") -> None:
wait_targets: list[asyncio.Future[typing.Any]] = [state.flush_wait_task]
wait_targets: list[asyncio.Future[typing.Any]] = [
state.flush_wait_task,
state.task_completed_wait_task,
]
if not state.should_shutdown:
wait_targets.append(state.queue_get_task)
if state.timeout_task is not None:
wait_targets.append(state.timeout_task)
wait_targets.extend(self._pending_head_tasks(state.pending))

await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED)

if not state.should_shutdown and state.queue_get_task.done():
new_ct = state.queue_get_task.result()
state.pending.setdefault(new_ct.topic_partition, []).append(new_ct)
self._track_user_task(new_ct)
_insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct)
state.queue_get_task = asyncio.create_task(self._messages_queue.get())
if state.timeout_task is None:
state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))

# Re-arm completion event before extract, so any task finishing during extract is
# captured by the next iteration instead of being lost between clear and re-wait.
if state.task_completed_wait_task.done():
self._task_completed_event.clear()
state.task_completed_wait_task = asyncio.create_task(self._task_completed_event.wait())

timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done()
flush_fired: typing.Final = state.flush_wait_task.done()

Expand Down Expand Up @@ -258,7 +273,8 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None:
ct = self._messages_queue.get_nowait()
except asyncio.QueueEmpty:
break
state.pending.setdefault(ct.topic_partition, []).append(ct)
self._track_user_task(ct)
_insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct)
if not state.queue_get_task.done():
state.queue_get_task.cancel()
else:
Expand Down
35 changes: 33 additions & 2 deletions faststream_concurrent_aiokafka/middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import contextlib
import dataclasses
import logging
import typing
import weakref

from faststream import BaseMiddleware, ContextRepo
from faststream.kafka.message import KafkaAckableMessage
Expand All @@ -20,6 +23,33 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True, slots=True)
class _ConsumerAttrs:
    """Static per-consumer flags, computed once and cached (see _consumer_attrs)."""

    # True when the consumer's class is literally named "FakeConsumer"
    # (FastStream's test double) — such messages bypass the middleware.
    is_fake: bool
    # Mirrors the consumer's _enable_auto_commit flag; auto-commit is
    # incompatible with this middleware and is rejected per message.
    auto_commit: bool


# Static, per-consumer flags that drive the per-message branch in consume_scope. Reading
# them on every message via type().__name__ and getattr was visible in profiles. WeakKey
# keeps the cache empty when consumers are GC'd; tests that build many MagicMock consumers
# don't leak.
_consumer_attrs_cache: typing.Final[weakref.WeakKeyDictionary[typing.Any, _ConsumerAttrs]] = weakref.WeakKeyDictionary()


def _consumer_attrs(consumer: typing.Any) -> "_ConsumerAttrs":  # noqa: ANN401
    """Return cached static flags for *consumer*, computing them on first use.

    The flags never change for a given consumer instance, so the per-message
    `type(...).__name__` / getattr probing is paid only once per consumer.
    """
    # BUG FIX: WeakKeyDictionary.get() builds weakref.ref(key) internally, so a
    # non-weakreferable consumer (rare, e.g. exotic mock subclasses) raises
    # TypeError at the LOOKUP too — not only on the store that the original
    # suppress() covered. Guard both ends so such consumers fall through to a
    # fresh (uncached) computation instead of crashing.
    try:
        cached = _consumer_attrs_cache.get(consumer)
    except TypeError:
        cached = None
    if cached is not None:
        return cached
    attrs = _ConsumerAttrs(
        is_fake=type(consumer).__name__ == "FakeConsumer",
        auto_commit=bool(getattr(consumer, "_enable_auto_commit", False)),
    )
    # Non-weakreferable consumers simply skip caching; correctness is unaffected.
    with contextlib.suppress(TypeError):
        _consumer_attrs_cache[consumer] = attrs
    return attrs


class KafkaConcurrentProcessingMiddleware(BaseMiddleware):
async def consume_scope( # ty: ignore[invalid-method-override]
self,
Expand All @@ -32,7 +62,8 @@ async def consume_scope( # ty: ignore[invalid-method-override]
err = "No Kafka message found in context. Ensure the middleware is used with a Kafka subscriber."
raise RuntimeError(err)

if type(kafka_message.consumer).__name__ == "FakeConsumer":
attrs: typing.Final = _consumer_attrs(kafka_message.consumer)
if attrs.is_fake:
return await call_next(msg)

# KafkaAckableMessage (AckPolicy.MANUAL) starts with committed=None.
Expand All @@ -54,7 +85,7 @@ async def consume_scope( # ty: ignore[invalid-method-override]
logger.warning("Kafka middleware. Handler is shutting down, skipping message")
return None

if getattr(kafka_message.consumer, "_enable_auto_commit", False):
if attrs.auto_commit:
err = (
"KafkaConcurrentProcessingMiddleware requires ack_policy=AckPolicy.MANUAL on all subscribers. "
"Auto-commit is enabled on this consumer, which commits offsets before processing tasks "
Expand Down
6 changes: 1 addition & 5 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ class MockAsyncioTask:
def __init__(
self,
result: str | None = None,
exception: Exception | None = None,
done: bool = True,
cancelled: bool = False,
) -> None:
self._result: str | None = result
self._exception: Exception | None = exception
self._done: bool = done
self._cancelled: bool = cancelled

Expand All @@ -28,9 +26,6 @@ def cancelled(self) -> bool:
def done(self) -> bool:
return self._done or self._cancelled

def exception(self) -> Exception | None:
return self._exception


class MockKafkaMessage:
def __init__(
Expand Down Expand Up @@ -62,6 +57,7 @@ def __init__(self, *_args: object, **_kwargs: object) -> None:
self.close = AsyncMock()
self.spawn = Mock()
self.commit_all = AsyncMock()
self.notify_task_completed = Mock()
self._healthy = True

@property
Expand Down
Loading
Loading