Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Key design: `handle_task()` fires-and-forgets the user coroutine as an asyncio t
A single function that accepts a `ContextRepo` and returns `True` if the handler is present and `is_healthy` (i.e. `_is_running` AND committer task alive). Intended for readiness/liveness probes.

**`batch_committer.py` — `KafkaBatchCommitter`**
Runs as a background asyncio task (`spawn()`). Pulls `KafkaCommitTask`s off a queue, batches by `(timeout OR batch_size)`, awaits each task's asyncio future, groups by `(consumer_id, partition)`, takes the max offset per partition (stopping at the first cancelled task), and commits via `consumer.commit({TopicPartition: offset+1})`. Transient `KafkaError` re-queues the batch; `CommitFailedError`/`IllegalStateError` (rebalance/revocation) discards it. `CommitterIsDeadError` is raised to callers when the committer's main task has died, which triggers `handler.stop()`.
Runs as a background asyncio task (`spawn()`). Streaming loop: continuously absorbs `KafkaCommitTask`s from the queue into per-partition pending state, watches each partition's first not-yet-done task, and commits each partition's contiguous-done prefix when total pending ≥ `commit_batch_size`, when `commit_batch_timeout_sec` fires, or when `commit_all`/`close` sets the flush event. Per partition, `_extract_ready_prefixes` sorts by offset (tolerates re-queued tasks landing out of order) and stops at the first not-done task; a cancelled task is a hard boundary — the cancelled task and everything after it are dropped from pending while `_map_offsets_per_partition` stops the offset advance at the cancelled task (so uncommitted offsets get redelivered on restart, at-least-once). Per consumer-id group, commits via `consumer.commit({TopicPartition: max_offset+1})`. Transient `KafkaError` re-queues the batch; `CommitFailedError`/`IllegalStateError` (rebalance/revocation) discards it. `CommitterIsDeadError` is raised to callers when the committer's main task has died, which triggers `handler.stop()`.

**`rebalance.py` — `ConsumerRebalanceListener`**
Returned by `handler.create_rebalance_listener()`. On `on_partitions_revoked`, calls `committer.commit_all()` so offsets are flushed before the partition is reassigned, preventing duplicate processing after rebalance.
Expand Down
273 changes: 166 additions & 107 deletions faststream_concurrent_aiokafka/batch_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ class KafkaCommitTask:
consumer: typing.Any


@dataclasses.dataclass(kw_only=True, slots=True)
class _StreamingState:
    """Mutable per-loop state owned by the streaming committer."""

    queue_get_task: asyncio.Task[KafkaCommitTask]
    flush_wait_task: asyncio.Task[bool]
    timeout_task: asyncio.Task[None] | None = None
    pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict)
    should_shutdown: bool = False
    # An active commit_all (flush event observed while _stop_requested is False):
    # keep committing on every iteration until pending drains, so that
    # messages_queue.join() can return.
    flush_in_progress: bool = False

    def cancel_outstanding(self) -> None:
        """Cancel every helper task that has not yet completed."""
        outstanding = [self.queue_get_task, self.flush_wait_task]
        if self.timeout_task is not None:
            outstanding.append(self.timeout_task)
        for task in outstanding:
            if not task.done():
                task.cancel()


class KafkaBatchCommitter:
def __init__(
self,
Expand All @@ -50,57 +69,6 @@ def _check_is_commit_task_running(self) -> None:
msg: typing.Final = "Committer main task is not running"
raise CommitterIsDeadError(msg)

def _flush_tasks_queue(self) -> list[KafkaCommitTask]:
    """Drain every commit task currently buffered in the messages queue."""
    queue = self._messages_queue
    drained: list[KafkaCommitTask] = []
    while not queue.empty():
        drained.append(queue.get_nowait())
    return drained

async def _populate_commit_batch(self) -> tuple[list[KafkaCommitTask], bool]:
    """Collect one batch of commit tasks from the messages queue.

    Waits until either the batch reaches ``self._commit_batch_size``, the
    ``self._commit_batch_timeout_sec`` timer fires, the flush event is set
    (commit_all/close), or this coroutine is cancelled.

    Returns:
        (batch, should_shutdown) — ``should_shutdown`` is True when close()
        requested a stop (``self._stop_requested``) or this task was cancelled.
    """
    uncommited_tasks: typing.Final[list[KafkaCommitTask]] = []
    should_shutdown = False
    queue_get_task: asyncio.Task[typing.Any] | None = None
    # Create timeout and flush-wait tasks once; reused across queue-get iterations.
    timeout_task: asyncio.Task[None] = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))
    flush_wait_task: asyncio.Task[bool] = asyncio.create_task(self._flush_batch_event.wait())
    try:
        while len(uncommited_tasks) < self._commit_batch_size:
            # A fresh queue-get task each iteration; the timeout/flush tasks persist.
            queue_get_task = asyncio.create_task(self._messages_queue.get())
            done, _ = await asyncio.wait(
                [queue_get_task, flush_wait_task, timeout_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if queue_get_task in done:
                uncommited_tasks.append(queue_get_task.result())
            else:
                # Not consumed this round — cancel so the queued item isn't lost
                # to an abandoned getter.
                queue_get_task.cancel()

            # flush event — drain remaining queue items; stop only if close() was called
            if flush_wait_task in done:
                uncommited_tasks.extend(self._flush_tasks_queue())
                should_shutdown = self._stop_requested
                break

            if timeout_task in done:
                logger.debug("Timeout exceeded, batch contains %s elements", len(uncommited_tasks))
                break

        logger.debug("Batch condition reached with %s elements", len(uncommited_tasks))
    except asyncio.CancelledError:
        # Committer task itself was cancelled: salvage whatever is buffered and shut down.
        should_shutdown = True
        uncommited_tasks.extend(self._flush_tasks_queue())

    for task in (queue_get_task, flush_wait_task, timeout_task):
        if task:
            task.cancel()
    # Reset on every exit (size, timeout, flush, cancelled). If commit_all() set the
    # event but the loop exited via size or timeout first, leaving it set would cost
    # one wasted populate cycle on the next iteration.
    self._flush_batch_event.clear()

    return uncommited_tasks, should_shutdown

async def _call_committer(
self, tasks_batch: list[KafkaCommitTask], partitions_to_offsets: dict[TopicPartition, int]
) -> bool:
Expand Down Expand Up @@ -141,43 +109,45 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
return partitions_to_offsets

@staticmethod
def _extract_ready_prefixes(
    pending: dict[TopicPartition, list[KafkaCommitTask]],
) -> dict[TopicPartition, list[KafkaCommitTask]]:
    """Pop and return each partition's contiguous-done prefix from ``pending``.

    Per partition (sorted by offset), find the first task that is not-done. Tasks
    before it form the contiguous-done prefix and become "ready". A cancelled task
    is treated as a hard boundary: cancelled + everything after is dropped from
    pending and added to ready (so task_done() balances messages_queue.join), while
    _map_offsets_per_partition stops the offset advance at the cancelled task so the
    uncommitted offsets get redelivered on restart (at-least-once).

    NOTE(review): this span of the scrape fused the deleted ``_partition_ready``
    with the added ``_extract_ready_prefixes``; this body is the committed (new)
    implementation with the stale deleted lines removed.
    """
    ready: dict[TopicPartition, list[KafkaCommitTask]] = {}
    empty_partitions: list[TopicPartition] = []
    for partition, partition_pending in pending.items():
        # Re-queued-on-transient-error tasks land at the queue tail and may arrive
        # out of offset order with respect to newer same-partition tasks. Sort here.
        partition_pending.sort(key=lambda t: t.offset)

        prefix_end = 0
        for index, task in enumerate(partition_pending):
            if task.asyncio_task.cancelled():
                # Hard boundary: take the whole tail so task_done() is balanced.
                prefix_end = len(partition_pending)
                break
            if not task.asyncio_task.done():
                prefix_end = index
                break
            prefix_end = index + 1

        if prefix_end > 0:
            ready[partition] = partition_pending[:prefix_end]
            del partition_pending[:prefix_end]
            if not partition_pending:
                empty_partitions.append(partition)

    for k in empty_partitions:
        del pending[k]
    return ready

async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitTask]]) -> bool:
flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks]
for task in flat:
if task.asyncio_task.cancelled():
continue
exc = task.asyncio_task.exception()
Expand All @@ -186,7 +156,7 @@ async def _commit_ready_slice(self, ready: list[KafkaCommitTask]) -> bool:

# Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions
consumers_tasks: dict[int, list[KafkaCommitTask]] = {}
for task in ready:
for task in flat:
consumers_tasks.setdefault(id(task.consumer), []).append(task)

all_succeeded = True
Expand All @@ -195,34 +165,123 @@ async def _commit_ready_slice(self, ready: list[KafkaCommitTask]) -> bool:
if not await self._call_committer(consumer_tasks, partitions_to_offsets):
all_succeeded = False

for _ in ready:
for _ in flat:
self._messages_queue.task_done()
return all_succeeded

async def _commit_tasks_batch(self, tasks_batch: list[KafkaCommitTask]) -> bool:
pending: list[KafkaCommitTask] = list(tasks_batch)
all_succeeded = True

while pending:
ready, still_blocked = self._partition_ready(pending)
if ready:
if not await self._commit_ready_slice(ready):
all_succeeded = False
pending = still_blocked
continue

# _partition_ready places every done/cancelled task in ready, so an empty
# ready implies every pending task is still in-flight.
await asyncio.wait([t.asyncio_task for t in pending], return_when=asyncio.FIRST_COMPLETED)
@staticmethod
def _pending_head_tasks(
    pending: dict[TopicPartition, list[KafkaCommitTask]],
) -> list[asyncio.Task[typing.Any]]:
    """Collect each partition's first in-flight asyncio task to wait on.

    Only the first not-done task per partition is watched. A cancelled head is
    treated as done by _extract_ready_prefixes, so it is intentionally skipped
    here — watching it would busy-loop.
    """
    watched: list[asyncio.Task[typing.Any]] = []
    for partition_tasks in pending.values():
        for commit_task in partition_tasks:
            aio_task = commit_task.asyncio_task
            if aio_task.cancelled():
                break
            if not aio_task.done():
                watched.append(aio_task)
                break
    return watched

return all_succeeded
def _reset_timeout(
    self,
    timeout_task: asyncio.Task[None] | None,
    pending_non_empty: bool,
) -> asyncio.Task[None] | None:
    """Cancel a live batch-timeout timer and arm a fresh one iff work is pending."""
    if timeout_task is not None and not timeout_task.done():
        timeout_task.cancel()
    if not pending_non_empty:
        # Invariant upheld for callers: empty pending ⇒ no timer running.
        return None
    return asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))

async def _run_commit_process(self) -> None:
    """Main committer loop; runs until shutdown is requested AND pending drains.

    NOTE(review): the scrape fused the deleted pre-refactor loop
    (populate-batch/commit-batch) with the new streaming loop; this body keeps
    only the committed (new) implementation.

    Streaming committer: one loop continuously absorbs queue items into
    per-partition pending state and commits each partition's contiguous-done
    prefix when total pending crosses commit_batch_size, when the timeout fires,
    or when commit_all/close sets the flush event. Queue depth no longer
    correlates with stuck-batch wait time.
    """
    state: typing.Final = _StreamingState(
        queue_get_task=asyncio.create_task(self._messages_queue.get()),
        flush_wait_task=asyncio.create_task(self._flush_batch_event.wait()),
    )

    try:
        while not (state.should_shutdown and not state.pending):
            await self._streaming_iteration(state)
    finally:
        # Always reap the helper tasks, even if an iteration raised.
        state.cancel_outstanding()

async def _streaming_iteration(self, state: "_StreamingState") -> None:
    """Run one wait/absorb/commit cycle of the streaming committer loop.

    Wakes on any of: flush event, a new queue item (unless shutting down), the
    batch timeout, or a watched per-partition head task completing; then absorbs
    at most one queue item, reacts to flush, and commits via _maybe_commit.
    """
    wait_targets: list[asyncio.Future[typing.Any]] = [state.flush_wait_task]
    if not state.should_shutdown:
        # After shutdown the queue was drained by _handle_flush_fired; stop watching it.
        wait_targets.append(state.queue_get_task)
    if state.timeout_task is not None:
        wait_targets.append(state.timeout_task)
    wait_targets.extend(self._pending_head_tasks(state.pending))

    await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED)

    # Absorb one completed queue-get BEFORE flush handling, so an already-fetched
    # item is never dropped when shutdown fires in the same iteration.
    if not state.should_shutdown and state.queue_get_task.done():
        new_ct = state.queue_get_task.result()
        state.pending.setdefault(new_ct.topic_partition, []).append(new_ct)
        state.queue_get_task = asyncio.create_task(self._messages_queue.get())
        if state.timeout_task is None:
            # First item after an empty period starts the batch timer.
            state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec))

    timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done()
    flush_fired: typing.Final = state.flush_wait_task.done()

    if flush_fired:
        self._handle_flush_fired(state)

    ready: typing.Final = await self._maybe_commit(state, timeout_fired)
    if state.flush_in_progress and not state.pending:
        state.flush_in_progress = False

    # Reset the timer after any commit OR on timeout firing. Let it tick otherwise.
    # Invariant: pending empty ⇒ timeout_task is None (guaranteed by _reset_timeout
    # always being called when pending is mutated to empty), so no separate cleanup is needed.
    if ready or timeout_fired:
        state.timeout_task = self._reset_timeout(state.timeout_task, bool(state.pending))

def _handle_flush_fired(self, state: "_StreamingState") -> None:
    """React to the flush event: shutdown-drain on close(), or start a commit_all flush."""
    if not self._stop_requested:
        # commit_all(): flush until pending drains, then re-arm the event waiter.
        state.flush_in_progress = True
        self._flush_batch_event.clear()
        state.flush_wait_task = asyncio.create_task(self._flush_batch_event.wait())
        return

    state.should_shutdown = True
    # Drain anything still buffered in messages_queue into pending so close()
    # can commit it. Without this, items put before close() but not yet absorbed
    # by queue_get would be silently dropped (offsets stay uncommitted; redelivered
    # on restart, but commit_all/close() callers expect everything enqueued to be
    # processed).
    while not self._messages_queue.empty():
        buffered = self._messages_queue.get_nowait()
        state.pending.setdefault(buffered.topic_partition, []).append(buffered)
    if not state.queue_get_task.done():
        state.queue_get_task.cancel()

async def _maybe_commit(
    self, state: "_StreamingState", timeout_fired: bool
) -> dict[TopicPartition, list[KafkaCommitTask]]:
    """Commit contiguous-done prefixes when any trigger holds; return what was committed.

    Triggers: pending count reached the batch size, the batch timeout fired,
    a commit_all flush is in progress, or shutdown was requested.
    """
    should_commit = (
        timeout_fired
        or state.flush_in_progress
        or state.should_shutdown
        or sum(map(len, state.pending.values())) >= self._commit_batch_size
    )
    if not should_commit:
        return {}
    ready = self._extract_ready_prefixes(state.pending)
    if ready:
        await self._commit_partitions(ready)
    return ready

async def commit_all(self) -> None:
"""Flush and commit all pending tasks without stopping the committer loop.
Expand Down
Loading
Loading