From f07a9f51cfd486fea8122f173a7f2a414a02e825 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sun, 3 May 2026 17:13:37 +0300 Subject: [PATCH 1/3] hot-path optimizations: event-based wakeups, sort-on-insert, cached middleware reflection Replace per-iteration asyncio.wait fan-out across O(P) head tasks with a single asyncio.Event set from each user task's done-callback (registered when the committer absorbs the task). Pending lists are now maintained in offset order via bisect.insort on the rare out-of-order arrival path, so _extract_ready_prefixes no longer sorts on every commit. Drop redundant per-task exception logging in _commit_partitions (handler's _finish_task already logs once with the right context). Cache type(consumer).__name__ and _enable_auto_commit per-consumer in a WeakKeyDictionary so the middleware doesn't re-evaluate on every message. Co-Authored-By: Claude Opus 4.7 --- .../batch_committer.py | 92 ++++++++------ faststream_concurrent_aiokafka/middleware.py | 35 +++++- tests/mocks.py | 1 + tests/test_kafka_committer.py | 114 ++++++------------ 4 files changed, 126 insertions(+), 116 deletions(-) diff --git a/faststream_concurrent_aiokafka/batch_committer.py b/faststream_concurrent_aiokafka/batch_committer.py index 004ca2e..bff93e0 100644 --- a/faststream_concurrent_aiokafka/batch_committer.py +++ b/faststream_concurrent_aiokafka/batch_committer.py @@ -1,7 +1,9 @@ import asyncio +import bisect import contextlib import dataclasses import logging +import operator import typing from aiokafka.errors import CommitFailedError, IllegalStateError, KafkaError @@ -16,6 +18,7 @@ DEFAULT_SHUTDOWN_TIMEOUT_SEC: typing.Final = 20.0 +_OFFSET_KEY: typing.Final = operator.attrgetter("offset") class CommitterIsDeadError(Exception): ... @@ -33,6 +36,7 @@ class KafkaCommitTask: class _StreamingState: queue_get_task: asyncio.Task[KafkaCommitTask] flush_wait_task: asyncio.Task[bool] + task_completed_wait_task: asyncio.Task[bool] timeout_task: asyncio.Task[None] | None = None pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict) should_shutdown: bool = False @@ -41,13 +45,23 @@ class _StreamingState: flush_in_progress: bool = False def cancel_outstanding(self) -> None: - for task in (self.queue_get_task, self.flush_wait_task): + for task in (self.queue_get_task, self.flush_wait_task, self.task_completed_wait_task): if not task.done(): task.cancel() if self.timeout_task is not None and not self.timeout_task.done(): self.timeout_task.cancel() +def _insert_sorted(partition_pending: list[KafkaCommitTask], new_ct: KafkaCommitTask) -> None: + # Common case: tasks arrive from the broker in offset order, so append is correct and + # the list stays sorted. Out-of-order arrivals only happen when _call_committer + # re-queues a batch on transient KafkaError; bisect handles the rare case in O(log N). + if not partition_pending or partition_pending[-1].offset <= new_ct.offset: + partition_pending.append(new_ct) + else: + bisect.insort(partition_pending, new_ct, key=_OFFSET_KEY) + + class KafkaBatchCommitter: def __init__( self, @@ -58,12 +72,25 @@ def __init__( self._messages_queue: asyncio.Queue[KafkaCommitTask] = asyncio.Queue() self._commit_task: asyncio.Task[typing.Any] | None = None self._flush_batch_event = asyncio.Event() + # Set from each user task's done-callback (registered in handle_task). Wakes the + # streaming loop without us having to add per-task callbacks via asyncio.wait every + # iteration. 
+        # Fan-in cost is O(1) regardless of partition count or pending depth.
+        self._task_completed_event = asyncio.Event()
         self._stop_requested: bool = False
         self._commit_batch_timeout_sec = commit_batch_timeout_sec
         self._commit_batch_size = commit_batch_size
         self._shutdown_timeout = shutdown_timeout_sec
 
+    def _on_user_task_done(self, _task: asyncio.Future[typing.Any]) -> None:
+        """Done-callback target for user tasks; wakes the streaming loop."""
+        self._task_completed_event.set()
+
+    def _track_user_task(self, ct: KafkaCommitTask) -> None:
+        # add_done_callback schedules the callback via loop.call_soon if the future is already
+        # done, so a task that completed between create_task and absorb still triggers the wakeup.
+        ct.asyncio_task.add_done_callback(self._on_user_task_done)
+
     def _check_is_commit_task_running(self) -> None:
         if not self._commit_task or self._commit_task.done():
             msg: typing.Final = "Committer main task is not running"
@@ -99,7 +126,7 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
         partitions_to_offsets: dict[TopicPartition, int] = {}
         for partition, tasks in by_partition.items():
             max_offset: int | None = None
-            for task in sorted(tasks, key=lambda x: x.offset):
+            for task in sorted(tasks, key=_OFFSET_KEY):
                 if task.asyncio_task.cancelled():
                     break  # stop committing at first cancelled task — message was not processed
                 max_offset = task.offset
@@ -112,19 +139,16 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
     def _extract_ready_prefixes(
         pending: dict[TopicPartition, list[KafkaCommitTask]],
     ) -> dict[TopicPartition, list[KafkaCommitTask]]:
-        # Per partition (sorted by offset), find the first task that is not-done. Tasks before
-        # it form the contiguous-done prefix and become "ready". A cancelled task is treated
-        # as a hard boundary: cancelled + everything after is dropped from pending and added
-        # to ready (so task_done() balances messages_queue.join), while
-        # _map_offsets_per_partition stops the offset advance at the cancelled task so the
-        # uncommitted offsets get redelivered on restart (at-least-once).
+        # Pending lists are maintained in offset order by _insert_sorted. Per partition, find
+        # the first not-done task; tasks before it form the contiguous-done prefix and become
+        # "ready". A cancelled task is treated as a hard boundary: cancelled + everything after
+        # is dropped from pending and added to ready (so task_done() balances
+        # messages_queue.join), while _map_offsets_per_partition stops the offset advance at
+        # the cancelled task so the uncommitted offsets get redelivered on restart
+        # (at-least-once).
         ready: dict[TopicPartition, list[KafkaCommitTask]] = {}
         empty_partitions: list[TopicPartition] = []
         for partition, partition_pending in pending.items():
-            # Re-queued-on-transient-error tasks land at the queue tail and may arrive
-            # out of offset order with respect to newer same-partition tasks. Sort here.
-            partition_pending.sort(key=lambda t: t.offset)
-
             prefix_end = 0
             for index, task in enumerate(partition_pending):
                 if task.asyncio_task.cancelled():
@@ -146,13 +170,10 @@ def _extract_ready_prefixes(
         return ready
 
     async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitTask]]) -> bool:
+        # Task exception logging is handled by the handler's _finish_task done-callback so
+        # it fires once per task at completion time. We intentionally do NOT log here:
+        # transient KafkaError re-queues a task, and a per-commit log would emit duplicates.
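+        # (The removed loop below also needed its cancelled() guard: Task.exception()
+        # raises CancelledError for a cancelled task. Nothing in this method touches
+        # task results anymore.)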
flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks] - for task in flat: - if task.asyncio_task.cancelled(): - continue - exc = task.asyncio_task.exception() - if exc is not None: - logger.error("Task has finished with an exception", exc_info=exc) # Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions consumers_tasks: dict[int, list[KafkaCommitTask]] = {} @@ -169,22 +190,6 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT self._messages_queue.task_done() return all_succeeded - @staticmethod - def _pending_head_tasks( - pending: dict[TopicPartition, list[KafkaCommitTask]], - ) -> list[asyncio.Task[typing.Any]]: - # Watch only the first not-done task per partition. A cancelled head is treated as - # done by _extract_ready_prefixes, so it is intentionally not watched (would busy-loop). - heads: list[asyncio.Task[typing.Any]] = [] - for partition_pending in pending.values(): - for ct in partition_pending: - if ct.asyncio_task.cancelled(): - break - if not ct.asyncio_task.done(): - heads.append(ct.asyncio_task) - break - return heads - def _reset_timeout( self, timeout_task: asyncio.Task[None] | None, @@ -204,6 +209,7 @@ async def _run_commit_process(self) -> None: state: typing.Final = _StreamingState( queue_get_task=asyncio.create_task(self._messages_queue.get()), flush_wait_task=asyncio.create_task(self._flush_batch_event.wait()), + task_completed_wait_task=asyncio.create_task(self._task_completed_event.wait()), ) try: @@ -213,22 +219,31 @@ async def _run_commit_process(self) -> None: state.cancel_outstanding() async def _streaming_iteration(self, state: "_StreamingState") -> None: - wait_targets: list[asyncio.Future[typing.Any]] = [state.flush_wait_task] + wait_targets: list[asyncio.Future[typing.Any]] = [ + state.flush_wait_task, + state.task_completed_wait_task, + ] if not state.should_shutdown: wait_targets.append(state.queue_get_task) if state.timeout_task is not None: wait_targets.append(state.timeout_task) - wait_targets.extend(self._pending_head_tasks(state.pending)) await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED) if not state.should_shutdown and state.queue_get_task.done(): new_ct = state.queue_get_task.result() - state.pending.setdefault(new_ct.topic_partition, []).append(new_ct) + self._track_user_task(new_ct) + _insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct) state.queue_get_task = asyncio.create_task(self._messages_queue.get()) if state.timeout_task is None: state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) + # Re-arm completion event before extract, so any task finishing during extract is + # captured by the next iteration instead of being lost between clear and re-wait. 
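+        # (Also safe against a set() racing in right after clear(): Event.wait() returns
+        # immediately when the flag is already set, so a wakeup landing between clear()
+        # and the new wait() task is never dropped.)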
+ if state.task_completed_wait_task.done(): + self._task_completed_event.clear() + state.task_completed_wait_task = asyncio.create_task(self._task_completed_event.wait()) + timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done() flush_fired: typing.Final = state.flush_wait_task.done() @@ -258,7 +273,8 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None: ct = self._messages_queue.get_nowait() except asyncio.QueueEmpty: break - state.pending.setdefault(ct.topic_partition, []).append(ct) + self._track_user_task(ct) + _insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct) if not state.queue_get_task.done(): state.queue_get_task.cancel() else: diff --git a/faststream_concurrent_aiokafka/middleware.py b/faststream_concurrent_aiokafka/middleware.py index 0e317ad..ddcefa0 100644 --- a/faststream_concurrent_aiokafka/middleware.py +++ b/faststream_concurrent_aiokafka/middleware.py @@ -1,5 +1,8 @@ +import contextlib +import dataclasses import logging import typing +import weakref from faststream import BaseMiddleware, ContextRepo from faststream.kafka.message import KafkaAckableMessage @@ -20,6 +23,33 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass(frozen=True, slots=True) +class _ConsumerAttrs: + is_fake: bool + auto_commit: bool + + +# Static, per-consumer flags that drive the per-message branch in consume_scope. Reading +# them on every message via type().__name__ and getattr was visible in profiles. WeakKey +# keeps the cache empty when consumers are GC'd; tests that build many MagicMock consumers +# don't leak. +_consumer_attrs_cache: typing.Final[weakref.WeakKeyDictionary[typing.Any, _ConsumerAttrs]] = weakref.WeakKeyDictionary() + + +def _consumer_attrs(consumer: typing.Any) -> _ConsumerAttrs: # noqa: ANN401 + cached: typing.Final = _consumer_attrs_cache.get(consumer) + if cached is not None: + return cached + attrs: typing.Final = _ConsumerAttrs( + is_fake=type(consumer).__name__ == "FakeConsumer", + auto_commit=bool(getattr(consumer, "_enable_auto_commit", False)), + ) + # Consumer may not be weakreferable (rare, e.g. exotic mock subclasses); fall through. + with contextlib.suppress(TypeError): + _consumer_attrs_cache[consumer] = attrs + return attrs + + class KafkaConcurrentProcessingMiddleware(BaseMiddleware): async def consume_scope( # ty: ignore[invalid-method-override] self, @@ -32,7 +62,8 @@ async def consume_scope( # ty: ignore[invalid-method-override] err = "No Kafka message found in context. Ensure the middleware is used with a Kafka subscriber." raise RuntimeError(err) - if type(kafka_message.consumer).__name__ == "FakeConsumer": + attrs: typing.Final = _consumer_attrs(kafka_message.consumer) + if attrs.is_fake: return await call_next(msg) # KafkaAckableMessage (AckPolicy.MANUAL) starts with committed=None. @@ -54,7 +85,7 @@ async def consume_scope( # ty: ignore[invalid-method-override] logger.warning("Kafka middleware. Handler is shutting down, skipping message") return None - if getattr(kafka_message.consumer, "_enable_auto_commit", False): + if attrs.auto_commit: err = ( "KafkaConcurrentProcessingMiddleware requires ack_policy=AckPolicy.MANUAL on all subscribers. 
" "Auto-commit is enabled on this consumer, which commits offsets before processing tasks " diff --git a/tests/mocks.py b/tests/mocks.py index daffde9..2c17e9e 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -62,6 +62,7 @@ def __init__(self, *_args: object, **_kwargs: object) -> None: self.close = AsyncMock() self.spawn = Mock() self.commit_all = AsyncMock() + self.notify_task_completed = Mock() self._healthy = True @property diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index 2801950..a9291ba 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -10,7 +10,12 @@ from aiokafka.errors import CommitFailedError, KafkaError from faststream.kafka import TopicPartition -from faststream_concurrent_aiokafka.batch_committer import CommitterIsDeadError, KafkaBatchCommitter, KafkaCommitTask +from faststream_concurrent_aiokafka.batch_committer import ( + CommitterIsDeadError, + KafkaBatchCommitter, + KafkaCommitTask, + _insert_sorted, +) from tests.mocks import MockAIOKafkaConsumer, MockAsyncioTask @@ -443,71 +448,47 @@ def test_extract_ready_prefixes_cancelled_drops_partition(mock_consumer: MockAIO assert pending == {} # partition emptied -def test_pending_head_tasks_skips_cancelled_head(mock_consumer: MockAIOKafkaConsumer) -> None: - """A cancelled head must not be added to the wait set (would busy-loop).""" - cancelled: typing.Final = MockAsyncioTask(cancelled=True) +def test_insert_sorted_appends_in_order(mock_consumer: MockAIOKafkaConsumer) -> None: + """In-order arrivals (the common case) just append — no bisect cost.""" tp: typing.Final = TopicPartition(topic="t", partition=0) - ct: typing.Final = KafkaCommitTask( - asyncio_task=cancelled, # ty: ignore[invalid-argument-type] - offset=10, - consumer=mock_consumer, - topic_partition=tp, - ) - heads: typing.Final = KafkaBatchCommitter._pending_head_tasks({tp: [ct]}) - assert heads == [] - - -def test_pending_head_tasks_returns_first_not_done(mock_consumer: MockAIOKafkaConsumer) -> None: - not_done_head: typing.Final = MockAsyncioTask(done=False) - tp: typing.Final = TopicPartition(topic="t", partition=0) - pending: typing.Final = { - tp: [ + pending: list[KafkaCommitTask] = [] + for offset in (10, 11, 12): + _insert_sorted( + pending, KafkaCommitTask( asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] - offset=10, + offset=offset, consumer=mock_consumer, topic_partition=tp, ), + ) + assert [t.offset for t in pending] == [10, 11, 12] + + +def test_insert_sorted_bisects_out_of_order(mock_consumer: MockAIOKafkaConsumer) -> None: + """A re-queued task with a lower offset slides into the right position.""" + tp: typing.Final = TopicPartition(topic="t", partition=0) + pending: list[KafkaCommitTask] = [] + for offset in (10, 11): + _insert_sorted( + pending, KafkaCommitTask( - asyncio_task=not_done_head, # ty: ignore[invalid-argument-type] - offset=11, + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=offset, consumer=mock_consumer, topic_partition=tp, ), - ], - } - heads: typing.Final = KafkaBatchCommitter._pending_head_tasks(pending) - assert heads == [not_done_head] - - -def test_extract_ready_prefixes_sorts_by_offset(mock_consumer: MockAIOKafkaConsumer) -> None: - """Tasks appended out of offset order are sorted before extraction. - - Re-queued-after-transient-KafkaError tasks land at the queue tail and may arrive - after newer same-partition tasks; the lazy sort tolerates that. 
- """ - tp: typing.Final = TopicPartition(topic="t", partition=0) - out_of_order_task: typing.Final = KafkaCommitTask( - asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] - offset=5, # earlier offset, but appended after later ones - consumer=mock_consumer, - topic_partition=tp, - ) - later_tasks: typing.Final = [ + ) + _insert_sorted( + pending, KafkaCommitTask( asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] - offset=offset, + offset=5, consumer=mock_consumer, topic_partition=tp, - ) - for offset in (10, 11) - ] - pending: dict[TopicPartition, list[KafkaCommitTask]] = {tp: [*later_tasks, out_of_order_task]} - - ready: typing.Final = KafkaBatchCommitter._extract_ready_prefixes(pending) - - assert ready[tp][0] == out_of_order_task # sorted prefix starts at offset 5 - assert [t.offset for t in ready[tp]] == [5, 10, 11] + ), + ) + assert [t.offset for t in pending] == [5, 10, 11] # ---------- _commit_partitions ---------- @@ -539,32 +520,14 @@ async def test_commit_partitions_calls_commit_per_partition_max( mock_consumer.commit.assert_called_once_with({tp: expected_offset + 2}) -async def test_commit_partitions_logs_task_exceptions( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.ERROR) - failing: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) - - commit_task: typing.Final = KafkaCommitTask( - asyncio_task=failing, # ty: ignore[invalid-argument-type] - offset=100, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t", partition=0), - ) - await committer._messages_queue.put(commit_task) - await committer._messages_queue.get() - - await committer._commit_partitions({commit_task.topic_partition: [commit_task]}) - - assert "Task has finished with an exception" in caplog.text - - async def test_commit_partitions_partial_failure_still_commits_offset( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture + committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer ) -> None: - """One task raising still commits the partition's max offset, with the error logged.""" - caplog.set_level(logging.ERROR) + """One task raising still commits the partition's max offset. + Per-task exception logging is owned by the handler's _finish_task callback, + not the committer. + """ failing: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) succeeding: typing.Final = MockAsyncioTask(result="ok", done=True) tp: typing.Final = TopicPartition(topic="t1", partition=0) @@ -590,7 +553,6 @@ async def test_commit_partitions_partial_failure_still_commits_offset( await committer._commit_partitions({tp: tasks}) mock_consumer.commit.assert_called_once_with({tp: 102}) - assert "Task has finished with an exception" in caplog.text async def test_commit_partitions_handles_multiple_consumers(committer: KafkaBatchCommitter) -> None: From 5b41f5daf3eaaa13a746911f7de8a01deed11674 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sun, 3 May 2026 17:16:27 +0300 Subject: [PATCH 2/3] drop dead exception field on MockAsyncioTask MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed in the same change that dropped per-task exception logging in _commit_partitions — nothing reads MockAsyncioTask.exception() anymore. 
Co-Authored-By: Claude Opus 4.7 --- tests/mocks.py | 5 ----- tests/test_kafka_committer.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/mocks.py b/tests/mocks.py index 2c17e9e..0251d93 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -13,12 +13,10 @@ class MockAsyncioTask: def __init__( self, result: str | None = None, - exception: Exception | None = None, done: bool = True, cancelled: bool = False, ) -> None: self._result: str | None = result - self._exception: Exception | None = exception self._done: bool = done self._cancelled: bool = cancelled @@ -28,9 +26,6 @@ def cancelled(self) -> bool: def done(self) -> bool: return self._done or self._cancelled - def exception(self) -> Exception | None: - return self._exception - class MockKafkaMessage: def __init__( diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index a9291ba..ffab7bc 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -528,7 +528,7 @@ async def test_commit_partitions_partial_failure_still_commits_offset( Per-task exception logging is owned by the handler's _finish_task callback, not the committer. """ - failing: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) + failing: typing.Final = MockAsyncioTask(done=True) succeeding: typing.Final = MockAsyncioTask(result="ok", done=True) tp: typing.Final = TopicPartition(topic="t1", partition=0) From 73978b5cd255965aa7157f3aa6bfe4f33410e380 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sun, 3 May 2026 17:33:18 +0300 Subject: [PATCH 3/3] fix flaky test_committer_absorbs_queue_during_slow_handler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The test exited the drive loop on the first commit and asserted on the last call, but depending on scheduling the streaming committer can issue a partial commit at offset 11 (only fast0 done) before fast1 lands in pending. That partial commit is correct behavior — across-batch pipelining streams whatever's done — but the test was racing it against the assertion. Wait for the final outcome (offset 12) to actually appear in the call history and use assert_any_call. Surfaced as a Python 3.14 CI failure; the same race exists in 3.13 but lost the coin flip less often. Co-Authored-By: Claude Opus 4.7 --- tests/test_kafka_committer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index ffab7bc..b0ebccf 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -930,9 +930,15 @@ async def fast_handler() -> None: ) ) - await _drive_until(lambda: consumer.commit.called, deadline_sec=1.0) - # The fast partition committed without waiting on the slow one. - consumer.commit.assert_called_with({fast_partition: 12}) + # The fast partition committed without waiting on the slow one. Wait for the fast + # partition to actually reach offset 12 — depending on scheduling, the loop may + # issue a partial commit at 11 first, which is still correct (across-batch pipelining + # streams whatever's done) but the test asserts the final outcome. + await _drive_until( + lambda: any(c.args[0] == {fast_partition: 12} for c in consumer.commit.call_args_list), + deadline_sec=1.0, + ) + consumer.commit.assert_any_call({fast_partition: 12}) slow_event.set() await committer.close()
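
Appendix: the wake-up pattern patch 1 adopts, reduced to a minimal self-contained
sketch. Illustrative only; "process" and the polling loop are hypothetical stand-ins,
not the committer's API. The real absorb/extract machinery lives in
_streaming_iteration in patch 1.

    import asyncio


    async def main() -> None:
        task_done = asyncio.Event()

        def wake(_t: "asyncio.Future[int]") -> None:
            # One shared Event no matter how many tasks are in flight:
            # fan-in stays O(1), versus asyncio.wait over every head task.
            task_done.set()

        async def process(i: int) -> int:  # hypothetical user task
            await asyncio.sleep(0.01 * i)
            return i

        tasks = [asyncio.create_task(process(i)) for i in range(5)]
        for t in tasks:
            # For an already-finished task, asyncio schedules the callback
            # via loop.call_soon, so the wakeup cannot be missed.
            t.add_done_callback(wake)

        while any(not t.done() for t in tasks):
            await task_done.wait()
            # Re-arm before inspecting: a completion that lands during the
            # inspection sets the flag again, and the next wait() returns
            # immediately.
            task_done.clear()
            print("done:", sorted(t.result() for t in tasks if t.done()))


    asyncio.run(main())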