diff --git a/faststream_concurrent_aiokafka/batch_committer.py b/faststream_concurrent_aiokafka/batch_committer.py index bff93e0..538ce24 100644 --- a/faststream_concurrent_aiokafka/batch_committer.py +++ b/faststream_concurrent_aiokafka/batch_committer.py @@ -37,8 +37,13 @@ class _StreamingState: queue_get_task: asyncio.Task[KafkaCommitTask] flush_wait_task: asyncio.Task[bool] task_completed_wait_task: asyncio.Task[bool] - timeout_task: asyncio.Task[None] | None = None + # Absolute loop-time deadline for the next commit_batch_timeout firing. None when pending + # is empty (no timer needed). Passed as `timeout=` to asyncio.wait — no Task allocation. + timeout_deadline: float | None = None pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict) + # Cached count of all tasks in `pending` across partitions; kept in sync with + # _insert_sorted callers and post-extract. Lets _maybe_commit avoid an O(P) sum every loop. + pending_count: int = 0 should_shutdown: bool = False # Active commit_all (flush event seen, _stop_requested is False): keep committing every # iteration until pending drains, so messages_queue.join() can return. @@ -48,8 +53,6 @@ def cancel_outstanding(self) -> None: for task in (self.queue_get_task, self.flush_wait_task, self.task_completed_wait_task): if not task.done(): task.cancel() - if self.timeout_task is not None and not self.timeout_task.done(): - self.timeout_task.cancel() def _insert_sorted(partition_pending: list[KafkaCommitTask], new_ct: KafkaCommitTask) -> None: @@ -175,31 +178,20 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT # transient KafkaError re-queues a task, and a per-commit log would emit duplicates. 
flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks] - # Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions + # Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions. + # With more than one consumer (router with multiple subscribers sharing the handler), + # each commit is an independent network round-trip and can run concurrently. consumers_tasks: dict[int, list[KafkaCommitTask]] = {} for task in flat: consumers_tasks.setdefault(id(task.consumer), []).append(task) - all_succeeded = True - for consumer_tasks in consumers_tasks.values(): - partitions_to_offsets = self._map_offsets_per_partition(consumer_tasks) - if not await self._call_committer(consumer_tasks, partitions_to_offsets): - all_succeeded = False + results: typing.Final = await asyncio.gather( + *(self._call_committer(ct, self._map_offsets_per_partition(ct)) for ct in consumers_tasks.values()) + ) for _ in flat: self._messages_queue.task_done() - return all_succeeded - - def _reset_timeout( - self, - timeout_task: asyncio.Task[None] | None, - pending_non_empty: bool, - ) -> asyncio.Task[None] | None: - if timeout_task is not None and not timeout_task.done(): - timeout_task.cancel() - if pending_non_empty: - return asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) - return None + return all(results) async def _run_commit_process(self) -> None: # Streaming committer: one loop continuously absorbs queue items into per-partition @@ -225,18 +217,26 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None: ] if not state.should_shutdown: wait_targets.append(state.queue_get_task) - if state.timeout_task is not None: - wait_targets.append(state.timeout_task) - await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED) + loop: typing.Final = asyncio.get_running_loop() + remaining: float | None = None + if state.timeout_deadline is not None: + remaining = 
max(state.timeout_deadline - loop.time(), 0.0) + + await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED, timeout=remaining) + + # Capture once after the wait — clock may have advanced past the deadline even if no + # future fired (the asyncio.wait timeout is what made us return). + now: typing.Final = loop.time() if not state.should_shutdown and state.queue_get_task.done(): new_ct = state.queue_get_task.result() self._track_user_task(new_ct) _insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct) + state.pending_count += 1 state.queue_get_task = asyncio.create_task(self._messages_queue.get()) - if state.timeout_task is None: - state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) + if state.timeout_deadline is None: + state.timeout_deadline = now + self._commit_batch_timeout_sec # Re-arm completion event before extract, so any task finishing during extract is # captured by the next iteration instead of being lost between clear and re-wait. @@ -244,7 +244,7 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None: self._task_completed_event.clear() state.task_completed_wait_task = asyncio.create_task(self._task_completed_event.wait()) - timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done() + timeout_fired: typing.Final = state.timeout_deadline is not None and now >= state.timeout_deadline flush_fired: typing.Final = state.flush_wait_task.done() if flush_fired: @@ -254,11 +254,10 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None: if state.flush_in_progress and not state.pending: state.flush_in_progress = False - # Reset the timer after any commit OR on timeout firing. Let it tick otherwise. - # Invariant: pending empty ⇒ timeout_task is None (guaranteed by _reset_timeout - # always being called when pending is mutated to empty), so no separate cleanup is needed. 
+ # Reset the deadline after any commit OR on timeout firing. Let it tick otherwise. + # Invariant: pending empty ⇒ timeout_deadline is None. if ready or timeout_fired: - state.timeout_task = self._reset_timeout(state.timeout_task, bool(state.pending)) + state.timeout_deadline = (loop.time() + self._commit_batch_timeout_sec) if state.pending else None def _handle_flush_fired(self, state: "_StreamingState") -> None: if self._stop_requested: @@ -275,6 +274,7 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None: break self._track_user_task(ct) _insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct) + state.pending_count += 1 if not state.queue_get_task.done(): state.queue_get_task.cancel() else: @@ -285,9 +285,8 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None: async def _maybe_commit( self, state: "_StreamingState", timeout_fired: bool ) -> dict[TopicPartition, list[KafkaCommitTask]]: - total_pending: typing.Final = sum(len(p) for p in state.pending.values()) commit_triggered: typing.Final = ( - total_pending >= self._commit_batch_size + state.pending_count >= self._commit_batch_size or timeout_fired or state.flush_in_progress or state.should_shutdown @@ -296,6 +295,7 @@ async def _maybe_commit( return {} ready: typing.Final = self._extract_ready_prefixes(state.pending) if ready: + state.pending_count -= sum(len(v) for v in ready.values()) await self._commit_partitions(ready) return ready diff --git a/faststream_concurrent_aiokafka/processing.py b/faststream_concurrent_aiokafka/processing.py index d270a21..571f0a5 100644 --- a/faststream_concurrent_aiokafka/processing.py +++ b/faststream_concurrent_aiokafka/processing.py @@ -68,8 +68,7 @@ async def handle_task( record: ConsumerRecord, kafka_message: KafkaAckableMessage, ) -> None: - if self._limiter: - await self._limiter.acquire() + await self._limiter.acquire()  # NOTE(review): None-guard on _limiter removed — confirm _limiter is always initialized before handle_task runs task: typing.Final = asyncio.ensure_future(coroutine) self._current_tasks.add(task) 
task.add_done_callback(self._finish_task)