Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions faststream_concurrent_aiokafka/batch_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ class _StreamingState:
queue_get_task: asyncio.Task[KafkaCommitTask]
flush_wait_task: asyncio.Task[bool]
task_completed_wait_task: asyncio.Task[bool]
timeout_task: asyncio.Task[None] | None = None
# Absolute loop-time deadline for the next commit_batch_timeout firing. None when pending
# is empty (no timer needed). Each iteration passes the time remaining until this deadline
# (clamped at 0) as `timeout=` to asyncio.wait, so no timer Task is allocated.
timeout_deadline: float | None = None
pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict)
# Cached count of all tasks in `pending` across partitions. Incremented at every
# _insert_sorted call site and decremented by the number of tasks extracted as ready,
# so _maybe_commit can avoid an O(P) sum over partitions on every loop.
pending_count: int = 0
should_shutdown: bool = False
# Active commit_all (flush event seen, _stop_requested is False): keep committing every
# iteration until pending drains, so messages_queue.join() can return.
Expand All @@ -48,8 +53,6 @@ def cancel_outstanding(self) -> None:
for task in (self.queue_get_task, self.flush_wait_task, self.task_completed_wait_task):
if not task.done():
task.cancel()
if self.timeout_task is not None and not self.timeout_task.done():
self.timeout_task.cancel()


def _insert_sorted(partition_pending: list[KafkaCommitTask], new_ct: KafkaCommitTask) -> None:
Expand Down Expand Up @@ -175,31 +178,20 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT
# transient KafkaError re-queues a task, and a per-commit log would emit duplicates.
flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks]

# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions
# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions.
# With more than one consumer (router with multiple subscribers sharing the handler),
# each commit is an independent network round-trip and can run concurrently.
consumers_tasks: dict[int, list[KafkaCommitTask]] = {}
for task in flat:
consumers_tasks.setdefault(id(task.consumer), []).append(task)

all_succeeded = True
for consumer_tasks in consumers_tasks.values():
partitions_to_offsets = self._map_offsets_per_partition(consumer_tasks)
if not await self._call_committer(consumer_tasks, partitions_to_offsets):
all_succeeded = False
results: typing.Final = await asyncio.gather(
*(self._call_committer(ct, self._map_offsets_per_partition(ct)) for ct in consumers_tasks.values())
)

for _ in flat:
self._messages_queue.task_done()
return all_succeeded

def _reset_timeout(
self,
timeout_task: asyncio.Task[None] | None,
pending_non_empty: bool,
) -> asyncio.Task[None] | None:
if timeout_task is not None and not timeout_task.done():
timeout_task.cancel()
if pending_non_empty:
return asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))
return None
return all(results)

async def _run_commit_process(self) -> None:
# Streaming committer: one loop continuously absorbs queue items into per-partition
Expand All @@ -225,26 +217,34 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None:
]
if not state.should_shutdown:
wait_targets.append(state.queue_get_task)
if state.timeout_task is not None:
wait_targets.append(state.timeout_task)

await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED)
loop: typing.Final = asyncio.get_running_loop()
remaining: float | None = None
if state.timeout_deadline is not None:
remaining = max(state.timeout_deadline - loop.time(), 0.0)

await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED, timeout=remaining)

# Capture once after the wait — clock may have advanced past the deadline even if no
# future fired (the asyncio.wait timeout is what made us return).
now: typing.Final = loop.time()

if not state.should_shutdown and state.queue_get_task.done():
new_ct = state.queue_get_task.result()
self._track_user_task(new_ct)
_insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct)
state.pending_count += 1
state.queue_get_task = asyncio.create_task(self._messages_queue.get())
if state.timeout_task is None:
state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))
if state.timeout_deadline is None:
state.timeout_deadline = now + self._commit_batch_timeout_sec

# Re-arm completion event before extract, so any task finishing during extract is
# captured by the next iteration instead of being lost between clear and re-wait.
if state.task_completed_wait_task.done():
self._task_completed_event.clear()
state.task_completed_wait_task = asyncio.create_task(self._task_completed_event.wait())

timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done()
timeout_fired: typing.Final = state.timeout_deadline is not None and now >= state.timeout_deadline
flush_fired: typing.Final = state.flush_wait_task.done()

if flush_fired:
Expand All @@ -254,11 +254,10 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None:
if state.flush_in_progress and not state.pending:
state.flush_in_progress = False

# Reset the timer after any commit OR on timeout firing. Let it tick otherwise.
# Invariant: pending empty ⇒ timeout_task is None (guaranteed by _reset_timeout
# always being called when pending is mutated to empty), so no separate cleanup is needed.
# Reset the deadline after any commit OR on timeout firing. Let it tick otherwise.
# Invariant: pending empty ⇒ timeout_deadline is None.
if ready or timeout_fired:
state.timeout_task = self._reset_timeout(state.timeout_task, bool(state.pending))
state.timeout_deadline = (loop.time() + self._commit_batch_timeout_sec) if state.pending else None

def _handle_flush_fired(self, state: "_StreamingState") -> None:
if self._stop_requested:
Expand All @@ -275,6 +274,7 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None:
break
self._track_user_task(ct)
_insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct)
state.pending_count += 1
if not state.queue_get_task.done():
state.queue_get_task.cancel()
else:
Expand All @@ -285,9 +285,8 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None:
async def _maybe_commit(
self, state: "_StreamingState", timeout_fired: bool
) -> dict[TopicPartition, list[KafkaCommitTask]]:
total_pending: typing.Final = sum(len(p) for p in state.pending.values())
commit_triggered: typing.Final = (
total_pending >= self._commit_batch_size
state.pending_count >= self._commit_batch_size
or timeout_fired
or state.flush_in_progress
or state.should_shutdown
Expand All @@ -296,6 +295,7 @@ async def _maybe_commit(
return {}
ready: typing.Final = self._extract_ready_prefixes(state.pending)
if ready:
state.pending_count -= sum(len(v) for v in ready.values())
await self._commit_partitions(ready)
return ready

Expand Down
3 changes: 1 addition & 2 deletions faststream_concurrent_aiokafka/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ async def handle_task(
record: ConsumerRecord,
kafka_message: KafkaAckableMessage,
) -> None:
if self._limiter:
await self._limiter.acquire()
await self._limiter.acquire()
task: typing.Final = asyncio.ensure_future(coroutine)
self._current_tasks.add(task)
task.add_done_callback(self._finish_task)
Expand Down
Loading