From 88858d98cf8ac7cbcbdd783d87b33af6fc431461 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sun, 3 May 2026 16:15:32 +0300 Subject: [PATCH] streaming committer for across-batch pipelining MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the populate-then-commit two-phase loop in `KafkaBatchCommitter` with a single streaming loop that continuously absorbs queue items into per-partition pending state and commits each partition's contiguous-done prefix. Eliminates the queue-grows-unbounded pathology when one batch stalls on a slow handler. Fast partitions now commit independently of slow ones across batch boundaries, not just within a single batch. `commit_batch_size` is still global; `commit_batch_timeout_sec` now anchors on first-task arrival rather than ticking through populate blocks. `commit_all`, `close`, and rebalance flushing are unchanged externally — flush events drive the loop the same way they always did. Co-Authored-By: Claude Opus 4.7 --- CLAUDE.md | 2 +- .../batch_committer.py | 273 ++-- tests/test_kafka_committer.py | 1259 +++++++++-------- 3 files changed, 856 insertions(+), 678 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 5a34b84..408b3de 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,7 +43,7 @@ Key design: `handle_task()` fires-and-forgets the user coroutine as an asyncio t A single function that accepts a `ContextRepo` and returns `True` if the handler is present and `is_healthy` (i.e. `_is_running` AND committer task alive). Intended for readiness/liveness probes. **`batch_committer.py` — `KafkaBatchCommitter`** -Runs as a background asyncio task (`spawn()`). Pulls `KafkaCommitTask`s off a queue, batches by `(timeout OR batch_size)`, awaits each task's asyncio future, groups by `(consumer_id, partition)`, takes the max offset per partition (stopping at the first cancelled task), and commits via `consumer.commit({TopicPartition: offset+1})`. 
Transient `KafkaError` re-queues the batch; `CommitFailedError`/`IllegalStateError` (rebalance/revocation) discards it. `CommitterIsDeadError` is raised to callers when the committer's main task has died, which triggers `handler.stop()`. +Runs as a background asyncio task (`spawn()`). Streaming loop: continuously absorbs `KafkaCommitTask`s from the queue into per-partition pending state, watches the head not-done task per partition, and commits each partition's contiguous-done prefix when total pending ≥ `commit_batch_size`, when `commit_batch_timeout_sec` fires, or when `commit_all`/`close` sets the flush event. Per partition, `_extract_ready_prefixes` sorts by offset (tolerates re-queued tasks landing out of order) and stops at the first not-done task; a cancelled task is a hard boundary — cancelled + everything after is dropped from pending while `_map_offsets_per_partition` stops the offset advance at the cancelled task (so uncommitted offsets get redelivered on restart, at-least-once). Per consumer-id group, commits via `consumer.commit({TopicPartition: max_offset+1})`. Transient `KafkaError` re-queues the batch; `CommitFailedError`/`IllegalStateError` (rebalance/revocation) discards it. `CommitterIsDeadError` is raised to callers when the committer's main task has died, which triggers `handler.stop()`. **`rebalance.py` — `ConsumerRebalanceListener`** Returned by `handler.create_rebalance_listener()`. On `on_partitions_revoked`, calls `committer.commit_all()` so offsets are flushed before the partition is reassigned, preventing duplicate processing after rebalance. 
diff --git a/faststream_concurrent_aiokafka/batch_committer.py b/faststream_concurrent_aiokafka/batch_committer.py index 2a1d57f..004ca2e 100644 --- a/faststream_concurrent_aiokafka/batch_committer.py +++ b/faststream_concurrent_aiokafka/batch_committer.py @@ -29,6 +29,25 @@ class KafkaCommitTask: consumer: typing.Any +@dataclasses.dataclass(kw_only=True, slots=True) +class _StreamingState: + queue_get_task: asyncio.Task[KafkaCommitTask] + flush_wait_task: asyncio.Task[bool] + timeout_task: asyncio.Task[None] | None = None + pending: dict[TopicPartition, list[KafkaCommitTask]] = dataclasses.field(default_factory=dict) + should_shutdown: bool = False + # Active commit_all (flush event seen, _stop_requested is False): keep committing every + # iteration until pending drains, so messages_queue.join() can return. + flush_in_progress: bool = False + + def cancel_outstanding(self) -> None: + for task in (self.queue_get_task, self.flush_wait_task): + if not task.done(): + task.cancel() + if self.timeout_task is not None and not self.timeout_task.done(): + self.timeout_task.cancel() + + class KafkaBatchCommitter: def __init__( self, @@ -50,57 +69,6 @@ def _check_is_commit_task_running(self) -> None: msg: typing.Final = "Committer main task is not running" raise CommitterIsDeadError(msg) - def _flush_tasks_queue(self) -> list[KafkaCommitTask]: - tasks_to_return: typing.Final[list[KafkaCommitTask]] = [] - while not self._messages_queue.empty(): - tasks_to_return.append(self._messages_queue.get_nowait()) - return tasks_to_return - - async def _populate_commit_batch(self) -> tuple[list[KafkaCommitTask], bool]: - uncommited_tasks: typing.Final[list[KafkaCommitTask]] = [] - should_shutdown = False - queue_get_task: asyncio.Task[typing.Any] | None = None - # Create timeout and flush-wait tasks once; reused across queue-get iterations. 
- timeout_task: asyncio.Task[None] = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) - flush_wait_task: asyncio.Task[bool] = asyncio.create_task(self._flush_batch_event.wait()) - try: - while len(uncommited_tasks) < self._commit_batch_size: - queue_get_task = asyncio.create_task(self._messages_queue.get()) - done, _ = await asyncio.wait( - [queue_get_task, flush_wait_task, timeout_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - if queue_get_task in done: - uncommited_tasks.append(queue_get_task.result()) - else: - queue_get_task.cancel() - - # flush event — drain remaining queue items; stop only if close() was called - if flush_wait_task in done: - uncommited_tasks.extend(self._flush_tasks_queue()) - should_shutdown = self._stop_requested - break - - if timeout_task in done: - logger.debug("Timeout exceeded, batch contains %s elements", len(uncommited_tasks)) - break - - logger.debug("Batch condition reached with %s elements", len(uncommited_tasks)) - except asyncio.CancelledError: - should_shutdown = True - uncommited_tasks.extend(self._flush_tasks_queue()) - - for task in (queue_get_task, flush_wait_task, timeout_task): - if task: - task.cancel() - # Reset on every exit (size, timeout, flush, cancelled). If commit_all() set the - # event but the loop exited via size or timeout first, leaving it set would cost - # one wasted populate cycle on the next iteration. - self._flush_batch_event.clear() - - return uncommited_tasks, should_shutdown - async def _call_committer( self, tasks_batch: list[KafkaCommitTask], partitions_to_offsets: dict[TopicPartition, int] ) -> bool: @@ -141,43 +109,45 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To return partitions_to_offsets @staticmethod - def _partition_ready( - pending: list[KafkaCommitTask], - ) -> tuple[list[KafkaCommitTask], list[KafkaCommitTask]]: - # Per partition (sorted by offset), find the first task that is either cancelled or - # not-done. 
Tasks before that boundary are ready. A cancelled boundary means - # graceful-shutdown is in progress: the cancelled task and all later same-partition - # tasks are added to ready too — _map_offsets_per_partition stops at the cancelled - # offset (so nothing past it commits) and task_done() is called on all of them. - # A not-done boundary keeps that task and everything after it on its partition blocked. - by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {} - for task in pending: - by_partition.setdefault(task.topic_partition, []).append(task) - - ready: list[KafkaCommitTask] = [] - still_blocked: list[KafkaCommitTask] = [] - for tasks in by_partition.values(): - tasks.sort(key=lambda t: t.offset) - cancelled_at: int | None = None - blocked_at: int | None = None - for index, task in enumerate(tasks): + def _extract_ready_prefixes( + pending: dict[TopicPartition, list[KafkaCommitTask]], + ) -> dict[TopicPartition, list[KafkaCommitTask]]: + # Per partition (sorted by offset), find the first task that is not-done. Tasks before + # it form the contiguous-done prefix and become "ready". A cancelled task is treated + # as a hard boundary: cancelled + everything after is dropped from pending and added + # to ready (so task_done() balances messages_queue.join), while + # _map_offsets_per_partition stops the offset advance at the cancelled task so the + # uncommitted offsets get redelivered on restart (at-least-once). + ready: dict[TopicPartition, list[KafkaCommitTask]] = {} + empty_partitions: list[TopicPartition] = [] + for partition, partition_pending in pending.items(): + # Re-queued-on-transient-error tasks land at the queue tail and may arrive + # out of offset order with respect to newer same-partition tasks. Sort here. 
+ partition_pending.sort(key=lambda t: t.offset) + + prefix_end = 0 + for index, task in enumerate(partition_pending): if task.asyncio_task.cancelled(): - cancelled_at = index + prefix_end = len(partition_pending) break if not task.asyncio_task.done(): - blocked_at = index + prefix_end = index break - if cancelled_at is not None: - ready.extend(tasks) - elif blocked_at is not None: - ready.extend(tasks[:blocked_at]) - still_blocked.extend(tasks[blocked_at:]) - else: - ready.extend(tasks) - return ready, still_blocked - - async def _commit_ready_slice(self, ready: list[KafkaCommitTask]) -> bool: - for task in ready: + prefix_end = index + 1 + + if prefix_end > 0: + ready[partition] = partition_pending[:prefix_end] + del partition_pending[:prefix_end] + if not partition_pending: + empty_partitions.append(partition) + + for k in empty_partitions: + del pending[k] + return ready + + async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitTask]]) -> bool: + flat: typing.Final[list[KafkaCommitTask]] = [t for tasks in ready.values() for t in tasks] + for task in flat: if task.asyncio_task.cancelled(): continue exc = task.asyncio_task.exception() @@ -186,7 +156,7 @@ async def _commit_ready_slice(self, ready: list[KafkaCommitTask]) -> bool: # Group by consumer instance — each AIOKafkaConsumer can only commit its own partitions consumers_tasks: dict[int, list[KafkaCommitTask]] = {} - for task in ready: + for task in flat: consumers_tasks.setdefault(id(task.consumer), []).append(task) all_succeeded = True @@ -195,34 +165,123 @@ async def _commit_ready_slice(self, ready: list[KafkaCommitTask]) -> bool: if not await self._call_committer(consumer_tasks, partitions_to_offsets): all_succeeded = False - for _ in ready: + for _ in flat: self._messages_queue.task_done() return all_succeeded - async def _commit_tasks_batch(self, tasks_batch: list[KafkaCommitTask]) -> bool: - pending: list[KafkaCommitTask] = list(tasks_batch) - all_succeeded = True - - while 
pending: - ready, still_blocked = self._partition_ready(pending) - if ready: - if not await self._commit_ready_slice(ready): - all_succeeded = False - pending = still_blocked - continue - - # _partition_ready places every done/cancelled task in ready, so an empty - # ready implies every pending task is still in-flight. - await asyncio.wait([t.asyncio_task for t in pending], return_when=asyncio.FIRST_COMPLETED) + @staticmethod + def _pending_head_tasks( + pending: dict[TopicPartition, list[KafkaCommitTask]], + ) -> list[asyncio.Task[typing.Any]]: + # Watch only the first not-done task per partition. A cancelled head is treated as + # done by _extract_ready_prefixes, so it is intentionally not watched (would busy-loop). + heads: list[asyncio.Task[typing.Any]] = [] + for partition_pending in pending.values(): + for ct in partition_pending: + if ct.asyncio_task.cancelled(): + break + if not ct.asyncio_task.done(): + heads.append(ct.asyncio_task) + break + return heads - return all_succeeded + def _reset_timeout( + self, + timeout_task: asyncio.Task[None] | None, + pending_non_empty: bool, + ) -> asyncio.Task[None] | None: + if timeout_task is not None and not timeout_task.done(): + timeout_task.cancel() + if pending_non_empty: + return asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) + return None async def _run_commit_process(self) -> None: - should_shutdown = False - while not should_shutdown: - commit_batch, should_shutdown = await self._populate_commit_batch() - if commit_batch: - await self._commit_tasks_batch(commit_batch) + # Streaming committer: one loop continuously absorbs queue items into per-partition + # pending state and commits each partition's contiguous-done prefix when total pending + # crosses commit_batch_size, when the timeout fires, or when commit_all/close sets the + # flush event. Queue depth no longer correlates with stuck-batch wait time. 
+ state: typing.Final = _StreamingState( + queue_get_task=asyncio.create_task(self._messages_queue.get()), + flush_wait_task=asyncio.create_task(self._flush_batch_event.wait()), + ) + + try: + while not (state.should_shutdown and not state.pending): + await self._streaming_iteration(state) + finally: + state.cancel_outstanding() + + async def _streaming_iteration(self, state: "_StreamingState") -> None: + wait_targets: list[asyncio.Future[typing.Any]] = [state.flush_wait_task] + if not state.should_shutdown: + wait_targets.append(state.queue_get_task) + if state.timeout_task is not None: + wait_targets.append(state.timeout_task) + wait_targets.extend(self._pending_head_tasks(state.pending)) + + await asyncio.wait(wait_targets, return_when=asyncio.FIRST_COMPLETED) + + if not state.should_shutdown and state.queue_get_task.done(): + new_ct = state.queue_get_task.result() + state.pending.setdefault(new_ct.topic_partition, []).append(new_ct) + state.queue_get_task = asyncio.create_task(self._messages_queue.get()) + if state.timeout_task is None: + state.timeout_task = asyncio.create_task(asyncio.sleep(self._commit_batch_timeout_sec)) + + timeout_fired: typing.Final = state.timeout_task is not None and state.timeout_task.done() + flush_fired: typing.Final = state.flush_wait_task.done() + + if flush_fired: + self._handle_flush_fired(state) + + ready: typing.Final = await self._maybe_commit(state, timeout_fired) + if state.flush_in_progress and not state.pending: + state.flush_in_progress = False + + # Reset the timer after any commit OR on timeout firing. Let it tick otherwise. + # Invariant: pending empty ⇒ timeout_task is None (guaranteed by _reset_timeout + # always being called when pending is mutated to empty), so no separate cleanup is needed. 
+ if ready or timeout_fired: + state.timeout_task = self._reset_timeout(state.timeout_task, bool(state.pending)) + + def _handle_flush_fired(self, state: "_StreamingState") -> None: + if self._stop_requested: + state.should_shutdown = True + # Drain anything still buffered in messages_queue into pending so close() + # can commit it. Without this, items put before close() but not yet absorbed + # by queue_get would be silently dropped (offsets stay uncommitted; redelivered + # on restart, but commit_all/close() callers expect everything enqueued to be + # processed). + while True: + try: + ct = self._messages_queue.get_nowait() + except asyncio.QueueEmpty: + break + state.pending.setdefault(ct.topic_partition, []).append(ct) + if not state.queue_get_task.done(): + state.queue_get_task.cancel() + else: + state.flush_in_progress = True + self._flush_batch_event.clear() + state.flush_wait_task = asyncio.create_task(self._flush_batch_event.wait()) + + async def _maybe_commit( + self, state: "_StreamingState", timeout_fired: bool + ) -> dict[TopicPartition, list[KafkaCommitTask]]: + total_pending: typing.Final = sum(len(p) for p in state.pending.values()) + commit_triggered: typing.Final = ( + total_pending >= self._commit_batch_size + or timeout_fired + or state.flush_in_progress + or state.should_shutdown + ) + if not commit_triggered: + return {} + ready: typing.Final = self._extract_ready_prefixes(state.pending) + if ready: + await self._commit_partitions(ready) + return ready async def commit_all(self) -> None: """Flush and commit all pending tasks without stopping the committer loop. 
diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index adc35ed..2801950 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -10,7 +10,6 @@ from aiokafka.errors import CommitFailedError, KafkaError from faststream.kafka import TopicPartition -from faststream_concurrent_aiokafka import batch_committer from faststream_concurrent_aiokafka.batch_committer import CommitterIsDeadError, KafkaBatchCommitter, KafkaCommitTask from tests.mocks import MockAIOKafkaConsumer, MockAsyncioTask @@ -41,6 +40,9 @@ async def committer() -> typing.AsyncIterator[KafkaBatchCommitter]: await committer._commit_task +# ---------- _check_is_commit_task_running ---------- + + def test_committer_raises_when_not_spawned(committer: KafkaBatchCommitter) -> None: with pytest.raises(CommitterIsDeadError, match="Committer main task is not running"): committer._check_is_commit_task_running() @@ -68,6 +70,9 @@ async def test_committer_passes_when_running(committer: KafkaBatchCommitter) -> committer._check_is_commit_task_running() +# ---------- spawn / send_task / is_healthy / close ---------- + + async def test_committer_spawn_creates_task(committer: KafkaBatchCommitter) -> None: committer.spawn() @@ -95,45 +100,29 @@ async def test_committer_send_task_raises_when_dead( await committer.send_task(sample_task) -async def test_committer_commit_all_flush(committer: KafkaBatchCommitter) -> None: - committer.spawn() +async def test_committer_close_graceful_shutdown() -> None: + """close() drives the streaming loop to drain pending and exit cleanly.""" + consumer: typing.Final = MockAIOKafkaConsumer() + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=10) - task: typing.Final = asyncio.create_task(asyncio.sleep(10)) - commit_task: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=task, + async def quick_handler() -> str: + return "ok" + + real_task: typing.Final = 
asyncio.create_task(quick_handler()) + commit_task: typing.Final = KafkaCommitTask( + asyncio_task=real_task, offset=100, - consumer=MockAIOKafkaConsumer(), - topic_partition=TopicPartition(topic="t1", partition=0), + consumer=consumer, + topic_partition=TopicPartition(topic="t", partition=0), ) - await committer.send_task(commit_task) - - with patch.object(committer, "_commit_tasks_batch", new_callable=AsyncMock) as mock_commit: - mock_commit.return_value = True - - async def mock_commit_side_effect(batch: list[str]) -> bool: - for _ in batch: - committer._messages_queue.task_done() - return True - - mock_commit.side_effect = mock_commit_side_effect - await committer.commit_all() - -async def test_committer_close_graceful_shutdown(committer: KafkaBatchCommitter, sample_task: KafkaCommitTask) -> None: committer.spawn() - await committer.send_task(sample_task) - - with patch.object(committer, "_commit_tasks_batch", new_callable=AsyncMock) as mock_commit: - - async def side_effect(batch: list[str]) -> bool: - for _ in batch: - committer._messages_queue.task_done() - return True + await committer.send_task(commit_task) + await committer.close() - mock_commit.side_effect = side_effect - await committer.close() - assert committer._commit_task - assert committer._commit_task.done() + assert committer._commit_task + assert committer._commit_task.done() + consumer.commit.assert_called_once() async def test_committer_close_timeout_cancels_task(committer: KafkaBatchCommitter) -> None: @@ -168,100 +157,53 @@ async def test_committer_is_healthy(committer: KafkaBatchCommitter) -> None: assert not committer.is_healthy -async def test_committer_returns_empty_on_empty_queue(committer: KafkaBatchCommitter) -> None: - committer._commit_batch_timeout_sec = 0.01 - tasks, should_shutdown = await committer._populate_commit_batch() - assert tasks == [] - assert should_shutdown is False - - -async def test_committer_collects_batch_size(committer: KafkaBatchCommitter, sample_task: 
KafkaCommitTask) -> None: - committer._commit_batch_size = 2 - committer._commit_batch_timeout_sec = 10.0 - - for _ in range(3): - await committer._messages_queue.put(sample_task) - - tasks, should_shutdown = await committer._populate_commit_batch() - assert len(tasks) == committer._commit_batch_size - assert should_shutdown is False - - -async def test_committer_returns_on_flush_event(committer: KafkaBatchCommitter, sample_task: KafkaCommitTask) -> None: - """commit_all() flushes without stopping the loop (should_shutdown stays False).""" - committer._commit_batch_timeout_sec = 10.0 - - for _ in range(6): - await committer._messages_queue.put(sample_task) - committer._flush_batch_event.set() - - tasks, should_shutdown = await committer._populate_commit_batch() - assert len(tasks) > committer._commit_batch_size - assert should_shutdown is False - - -async def test_committer_clears_flush_event_on_exit( - committer: KafkaBatchCommitter, sample_task: KafkaCommitTask -) -> None: - """The flush event must be cleared on every populate exit, not just the flush branch.""" - committer._commit_batch_size = 2 - committer._commit_batch_timeout_sec = 10.0 - - for _ in range(2): - await committer._messages_queue.put(sample_task) - committer._flush_batch_event.set() - - await committer._populate_commit_batch() - assert not committer._flush_batch_event.is_set() - +async def test_committer_uses_shutdown_timeout_kwarg() -> None: + committer: typing.Final = KafkaBatchCommitter(shutdown_timeout_sec=0.05) + assert committer._shutdown_timeout == 0.05 -async def test_committer_clears_flush_event_on_timeout_exit(committer: KafkaBatchCommitter) -> None: - """Cleanup-section clear runs even when the loop exits via timeout (no flush branch).""" - committer._commit_batch_timeout_sec = 0.01 - # No items in queue, flush not set — exits via timeout. Then we set the event manually - # to assert the cleanup clear ran (it would clear our just-set event too). 
- tasks, _ = await committer._populate_commit_batch() - assert tasks == [] - assert not committer._flush_batch_event.is_set() +async def test_committer_close_logs_when_task_already_died(caplog: pytest.LogCaptureFixture) -> None: + """If the committer task crashed before close() is called, the exception is logged.""" + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10) -async def test_committer_returns_shutdown_on_close_flush( - committer: KafkaBatchCommitter, sample_task: KafkaCommitTask -) -> None: - """close() sets _stop_requested so flush triggers shutdown.""" - committer._commit_batch_timeout_sec = 10.0 - committer._stop_requested = True + async def crashing() -> typing.Never: + msg: typing.Final = "boom" + raise RuntimeError(msg) - for _ in range(2): - await committer._messages_queue.put(sample_task) - committer._flush_batch_event.set() + committer._commit_task = asyncio.create_task(crashing()) + with contextlib.suppress(RuntimeError): + await committer._commit_task - tasks, should_shutdown = await committer._populate_commit_batch() - assert len(tasks) == 2 - assert should_shutdown is True + await committer.close() + assert "Committer task had already died before close()" in caplog.text -async def test_committer_flush_empty_list_when_queue_empty(committer: KafkaBatchCommitter) -> None: - result: typing.Final = committer._flush_tasks_queue() - assert result == [] +async def test_committer_close_but_timeout_error(caplog: pytest.LogCaptureFixture) -> None: + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10000) + committer._commit_task = asyncio.create_task(asyncio.sleep(30)) + committer._shutdown_timeout = 0.1 + await committer.close() + await asyncio.sleep(0.5) + assert "Committer main task shutdown timed out, forcing cancellation" in caplog.text + assert not committer.is_healthy -async def test_committer_handles_cancelled_error(committer: KafkaBatchCommitter, 
sample_task: KafkaCommitTask) -> None: - await committer._messages_queue.put(sample_task) +async def test_committer_close_but_unexpected_error() -> None: + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10) + mock_task: typing.Final = MagicMock() + mock_task.done.return_value = False + mock_task.cancelled.return_value = False + committer._commit_task = mock_task - async def mock_wait(*_: list[str], **__: dict[str, str]) -> typing.Never: - raise asyncio.CancelledError + original_exception: typing.Final = RuntimeError("Original error") - with patch("asyncio.wait", side_effect=mock_wait): - tasks, should_shutdown = await committer._populate_commit_batch() + with patch("asyncio.wait_for", side_effect=original_exception), pytest.raises(RuntimeError) as exc_info: + await committer.close() - assert should_shutdown is True - assert len(tasks) == 1 + assert exc_info.value is original_exception -async def test_committer_check_on_timeout_working_correctly(committer: KafkaBatchCommitter) -> None: - committer._commit_batch_timeout_sec = 0.01 - _tasks, _ = await committer._populate_commit_batch() +# ---------- _call_committer ---------- async def test_committer_returns_true_on_empty_offsets(committer: KafkaBatchCommitter) -> None: @@ -287,11 +229,11 @@ async def test_committer_commits_to_kafka(committer: KafkaBatchCommitter, mock_c async def test_committer_retries_on_kafka_error( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, sample_task: KafkaCommitTask + committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer ) -> None: """KafkaError re-queues the batch for retry on the next cycle.""" mock_task: typing.Final = MockAsyncioTask(result="success") - sample_task = KafkaCommitTask( + sample_task: typing.Final = KafkaCommitTask( asyncio_task=mock_task, # ty: ignore[invalid-argument-type] offset=100, consumer=mock_consumer, @@ -322,620 +264,797 @@ async def test_committer_ignores_commit_failed_error( 
assert committer._messages_queue.empty() -async def test_committer_waits_for_all_tasks( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - task1: typing.Final = MockAsyncioTask(result="result1", done=False) - task2: typing.Final = MockAsyncioTask(result="result2", done=False) - expected_offset: typing.Final = 100 +# ---------- _map_offsets_per_partition ---------- - commit_task1: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=task1, # ty: ignore[invalid-argument-type] - offset=100, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), - ) - commit_task2: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=task2, # ty: ignore[invalid-argument-type] - offset=101, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), - ) - task1._done = True - task2._done = True - - await committer._messages_queue.put(commit_task1) - await committer._messages_queue.put(commit_task2) - - batch: typing.Final = [await committer._messages_queue.get(), await committer._messages_queue.get()] - - with patch.object(committer, "_call_committer", new_callable=AsyncMock) as mock_commit: - mock_commit.return_value = True - - result: typing.Final = await committer._commit_tasks_batch(batch) +def test_committer_map_offsets_skips_cancelled_tasks(mock_consumer: MockAIOKafkaConsumer) -> None: + """Offset must not advance past a cancelled task — that message was never processed.""" + tp: typing.Final = TopicPartition(topic="t1", partition=0) + ok_task: typing.Final = MockAsyncioTask(result="ok") + cancelled_task: typing.Final = MockAsyncioTask(cancelled=True) + later_ok_task: typing.Final = MockAsyncioTask(result="ok") - assert result is True - mock_commit.assert_called_once() - call_args: typing.Final = mock_commit.call_args[0][1] - assert call_args[TopicPartition(topic="t1", partition=0)] == expected_offset + 2 + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=ok_task, # 
ty: ignore[invalid-argument-type] + offset=10, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=cancelled_task, # ty: ignore[invalid-argument-type] + offset=11, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=later_ok_task, # ty: ignore[invalid-argument-type] + offset=12, + consumer=mock_consumer, + topic_partition=tp, + ), + ] + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks) + # Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it. + assert offsets[tp] == 11 # max safe offset (10) + 1 -async def test_committer_logs_task_exceptions( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture -) -> None: - task: typing.Final = MockAsyncioTask(exception=ValueError("Task failed"), done=True) - commit_task: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=100, +def test_committer_map_offsets_skips_partition_when_all_cancelled(mock_consumer: MockAIOKafkaConsumer) -> None: + """If all tasks on a partition are cancelled, that partition is omitted from the commit map.""" + tp: typing.Final = TopicPartition(topic="t1", partition=0) + cancelled_task: typing.Final = MockAsyncioTask(cancelled=True) + task: typing.Final = KafkaCommitTask( + asyncio_task=cancelled_task, # ty: ignore[invalid-argument-type] + offset=5, consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), + topic_partition=tp, ) - await committer._messages_queue.put(commit_task) - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True): - await committer._commit_tasks_batch([await committer._messages_queue.get()]) - - assert "Task has finished with an exception" in caplog.text + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task]) + assert tp not in offsets -async def 
test_committer_groups_by_partition( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - tasks: typing.Final = [] +def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: MockAIOKafkaConsumer) -> None: + """Offset advances to (max processed + 1) per partition.""" first_offset: typing.Final = 100 second_offset: typing.Final = 999 - for partition in [0, 0, 1, 1]: - task = MockAsyncioTask(done=True) - commit_task = batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=first_offset + partition * 10, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=partition), + partition: typing.Final = 1 + + tasks: typing.Final[list[KafkaCommitTask]] = [] + # partition 0: offsets 100, 110 (two tasks) + for off in (first_offset, first_offset + 10): + tasks.append( + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=off, + consumer=mock_consumer, + topic_partition=TopicPartition(topic="t1", partition=0), + ) + ) + # partition 1: offsets 100, 110, 999 (three tasks; max=999) + for off in (first_offset, first_offset + 10, second_offset): + tasks.append( + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=off, + consumer=mock_consumer, + topic_partition=TopicPartition(topic="t1", partition=partition), + ) ) - await committer._messages_queue.put(commit_task) - tasks.append(commit_task) - task = MockAsyncioTask(done=True) - commit_task = batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=second_offset, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), - ) - await committer._messages_queue.put(commit_task) - while not committer._messages_queue.empty(): - await committer._messages_queue.get() - tasks.append(commit_task) + offsets: typing.Final = 
KafkaBatchCommitter._map_offsets_per_partition(tasks) + assert offsets[TopicPartition(topic="t1", partition=0)] == first_offset + 10 + 1 + assert offsets[TopicPartition(topic="t1", partition=partition)] == second_offset + 1 - with patch.object(committer, "_call_committer", new_callable=AsyncMock) as mock_commit: - mock_commit.return_value = True - await committer._commit_tasks_batch(tasks) - partitions: typing.Final = mock_commit.call_args[0][1] - assert partitions[TopicPartition(topic="t1", partition=0)] == second_offset + 1 - assert partitions[TopicPartition(topic="t1", partition=1)] == first_offset + partition * 10 + 1 +# ---------- _extract_ready_prefixes ---------- -async def test_committer_runs_until_shutdown(committer: KafkaBatchCommitter, sample_task: KafkaCommitTask) -> None: - await committer._messages_queue.put(sample_task) - committer._commit_batch_timeout_sec = 0.01 +def test_extract_ready_prefixes_empty_pending() -> None: + pending: dict[TopicPartition, list[KafkaCommitTask]] = {} + ready: typing.Final = KafkaBatchCommitter._extract_ready_prefixes(pending) + assert ready == {} + assert pending == {} - with patch.object(committer, "_populate_commit_batch") as mock_populate: - mock_populate.side_effect = [ - ([sample_task], True), - ] - with patch.object(committer, "_commit_tasks_batch", new_callable=AsyncMock) as mock_commit: - mock_commit.return_value = True - await committer._run_commit_process() +def test_extract_ready_prefixes_all_done(mock_consumer: MockAIOKafkaConsumer) -> None: + tp: typing.Final = TopicPartition(topic="t", partition=0) + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=offset, + consumer=mock_consumer, + topic_partition=tp, + ) + for offset in (10, 11, 12) + ] + pending: dict[TopicPartition, list[KafkaCommitTask]] = {tp: list(tasks)} - mock_populate.assert_called_once() - mock_commit.assert_called_once_with([sample_task]) + ready: typing.Final = 
KafkaBatchCommitter._extract_ready_prefixes(pending) + assert ready == {tp: tasks} + assert pending == {} # partition emptied -async def test_committer_skips_empty_batches(committer: KafkaBatchCommitter) -> None: - committer._commit_batch_timeout_sec = 0.01 - with patch.object(committer, "_populate_commit_batch") as mock_populate: - mock_populate.side_effect = [ - ([], False), - ([], True), - ] - with patch.object(committer, "_commit_tasks_batch", new_callable=AsyncMock) as mock_commit: - await committer._run_commit_process() - mock_commit.assert_not_called() +def test_extract_ready_prefixes_blocks_on_first_pending(mock_consumer: MockAIOKafkaConsumer) -> None: + tp: typing.Final = TopicPartition(topic="t", partition=0) + pending_task: typing.Final = MockAsyncioTask(done=False) + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=10, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=pending_task, # ty: ignore[invalid-argument-type] + offset=11, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=12, + consumer=mock_consumer, + topic_partition=tp, + ), + ] + pending: dict[TopicPartition, list[KafkaCommitTask]] = {tp: list(tasks)} + ready: typing.Final = KafkaBatchCommitter._extract_ready_prefixes(pending) -async def test_committer_full_flow_single_task() -> None: - committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10) - consumer: typing.Final = MockAIOKafkaConsumer() + assert ready == {tp: [tasks[0]]} # only the prefix before offset 11 + assert pending[tp] == [tasks[1], tasks[2]] - async def handler() -> str: - await asyncio.sleep(0.01) - return "processed" - real_task: typing.Final = asyncio.create_task(handler()) +def test_extract_ready_prefixes_cancelled_drops_partition(mock_consumer: 
MockAIOKafkaConsumer) -> None: + """Cancelled task drops cancelled + everything after from pending into ready. - commit_task: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=real_task, - offset=100, - consumer=consumer, - topic_partition=TopicPartition(topic="test", partition=0), - ) + task_done() balances messages_queue.join() that way. _map_offsets_per_partition + separately stops the offset advance at the cancelled task so it gets redelivered. + """ + tp: typing.Final = TopicPartition(topic="t", partition=0) + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=10, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=MockAsyncioTask(cancelled=True), # ty: ignore[invalid-argument-type] + offset=11, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=12, + consumer=mock_consumer, + topic_partition=tp, + ), + ] + pending: dict[TopicPartition, list[KafkaCommitTask]] = {tp: list(tasks)} - committer.spawn() - await committer.send_task(commit_task) + ready: typing.Final = KafkaBatchCommitter._extract_ready_prefixes(pending) - await asyncio.sleep(0.2) - await committer.close() - assert consumer.commit.called - real_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await real_task + assert ready == {tp: tasks} # all three included in ready + assert pending == {} # partition emptied -async def test_committer_multiple_topics_and_partitions( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - expected_amount_topic_partitions: typing.Final = 4 - tasks: typing.Final = [] - for topic in ["topic-a", "topic-b"]: - for partition in [0, 1]: - for offset in [100, 200]: - task = MockAsyncioTask(done=True) - commit_task = batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: 
ignore[invalid-argument-type] - offset=offset, - consumer=mock_consumer, - topic_partition=TopicPartition(topic=topic, partition=partition), - ) - await committer._messages_queue.put(commit_task) - tasks.append(commit_task) - - with patch.object(committer, "_call_committer", new_callable=AsyncMock) as mock_commit: - mock_commit.return_value = True - await committer._commit_tasks_batch(tasks) - call_args: typing.Final = mock_commit.call_args[0][1] - assert len(call_args) == expected_amount_topic_partitions - - -async def test_committer_task_with_none_result( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - task: typing.Final = MockAsyncioTask(result=None, done=True) - commit_task: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=100, +def test_pending_head_tasks_skips_cancelled_head(mock_consumer: MockAIOKafkaConsumer) -> None: + """A cancelled head must not be added to the wait set (would busy-loop).""" + cancelled: typing.Final = MockAsyncioTask(cancelled=True) + tp: typing.Final = TopicPartition(topic="t", partition=0) + ct: typing.Final = KafkaCommitTask( + asyncio_task=cancelled, # ty: ignore[invalid-argument-type] + offset=10, consumer=mock_consumer, - topic_partition=TopicPartition(topic="t", partition=0), + topic_partition=tp, ) - await committer._messages_queue.put(commit_task) + heads: typing.Final = KafkaBatchCommitter._pending_head_tasks({tp: [ct]}) + assert heads == [] - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True): - result: typing.Final = await committer._commit_tasks_batch([commit_task]) - assert result is True - -async def test_committer_very_large_batch_size(mock_consumer: MockAIOKafkaConsumer) -> None: - committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10000) - - task: typing.Final = MockAsyncioTask(done=True) - commit_task: typing.Final = 
batch_committer.KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=1, +def test_pending_head_tasks_returns_first_not_done(mock_consumer: MockAIOKafkaConsumer) -> None: + not_done_head: typing.Final = MockAsyncioTask(done=False) + tp: typing.Final = TopicPartition(topic="t", partition=0) + pending: typing.Final = { + tp: [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=10, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=not_done_head, # ty: ignore[invalid-argument-type] + offset=11, + consumer=mock_consumer, + topic_partition=tp, + ), + ], + } + heads: typing.Final = KafkaBatchCommitter._pending_head_tasks(pending) + assert heads == [not_done_head] + + +def test_extract_ready_prefixes_sorts_by_offset(mock_consumer: MockAIOKafkaConsumer) -> None: + """Tasks appended out of offset order are sorted before extraction. + + Re-queued-after-transient-KafkaError tasks land at the queue tail and may arrive + after newer same-partition tasks; the lazy sort tolerates that. 
+ """ + tp: typing.Final = TopicPartition(topic="t", partition=0) + out_of_order_task: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=5, # earlier offset, but appended after later ones consumer=mock_consumer, - topic_partition=TopicPartition(topic="t", partition=0), + topic_partition=tp, ) + later_tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=offset, + consumer=mock_consumer, + topic_partition=tp, + ) + for offset in (10, 11) + ] + pending: dict[TopicPartition, list[KafkaCommitTask]] = {tp: [*later_tasks, out_of_order_task]} - await committer._messages_queue.put(commit_task) - - tasks, _ = await committer._populate_commit_batch() - assert len(tasks) == 1 + ready: typing.Final = KafkaBatchCommitter._extract_ready_prefixes(pending) + assert ready[tp][0] == out_of_order_task # sorted prefix starts at offset 5 + assert [t.offset for t in ready[tp]] == [5, 10, 11] -async def test_committer_uses_shutdown_timeout_kwarg() -> None: - committer: typing.Final = KafkaBatchCommitter(shutdown_timeout_sec=0.05) - assert committer._shutdown_timeout == 0.05 +# ---------- _commit_partitions ---------- -async def test_committer_close_logs_when_task_already_died(caplog: pytest.LogCaptureFixture) -> None: - """If the committer task crashed before close() is called, the exception is logged.""" - committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10) - async def crashing() -> typing.Never: - msg: typing.Final = "boom" - raise RuntimeError(msg) - - committer._commit_task = asyncio.create_task(crashing()) - with contextlib.suppress(RuntimeError): - await committer._commit_task +async def test_commit_partitions_calls_commit_per_partition_max( + committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer +) -> None: + """All-done tasks → one commit call with max offset per partition 
(next-to-fetch = max + 1).""" + expected_offset: typing.Final = 100 - await committer.close() - assert "Committer task had already died before close()" in caplog.text + tp: typing.Final = TopicPartition(topic="t1", partition=0) + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(result="ok", done=True), # ty: ignore[invalid-argument-type] + offset=offset, + consumer=mock_consumer, + topic_partition=tp, + ) + for offset in (expected_offset, expected_offset + 1) + ] + for t in tasks: + await committer._messages_queue.put(t) + [await committer._messages_queue.get() for _ in tasks] + result: typing.Final = await committer._commit_partitions({tp: tasks}) -async def test_committer_close_but_timeout_error(caplog: pytest.LogCaptureFixture) -> None: - committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10000) - committer._commit_task = asyncio.create_task(asyncio.sleep(30)) - committer._shutdown_timeout = 0.1 - await committer.close() - await asyncio.sleep(0.5) - assert "Committer main task shutdown timed out, forcing cancellation" in caplog.text - assert not committer.is_healthy + assert result is True + mock_consumer.commit.assert_called_once_with({tp: expected_offset + 2}) -async def test_committer_partial_batch_failure_still_commits( +async def test_commit_partitions_logs_task_exceptions( committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture ) -> None: - """When one task raises, the batch still commits offsets and the error is logged.""" caplog.set_level(logging.ERROR) + failing: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) - failing_task: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) - succeeding_task: typing.Final = MockAsyncioTask(result="ok", done=True) - - commit_task1: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=failing_task, # ty: ignore[invalid-argument-type] 
+ commit_task: typing.Final = KafkaCommitTask( + asyncio_task=failing, # ty: ignore[invalid-argument-type] offset=100, consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), - ) - commit_task2: typing.Final = batch_committer.KafkaCommitTask( - asyncio_task=succeeding_task, # ty: ignore[invalid-argument-type] - offset=101, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), + topic_partition=TopicPartition(topic="t", partition=0), ) + await committer._messages_queue.put(commit_task) + await committer._messages_queue.get() + + await committer._commit_partitions({commit_task.topic_partition: [commit_task]}) + + assert "Task has finished with an exception" in caplog.text - # Must put items in queue before calling _commit_tasks_batch so task_done() is balanced - await committer._messages_queue.put(commit_task1) - await committer._messages_queue.put(commit_task2) - batch: typing.Final = [ - await committer._messages_queue.get(), - await committer._messages_queue.get(), + +async def test_commit_partitions_partial_failure_still_commits_offset( + committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture +) -> None: + """One task raising still commits the partition's max offset, with the error logged.""" + caplog.set_level(logging.ERROR) + + failing: typing.Final = MockAsyncioTask(exception=ValueError("handler failed"), done=True) + succeeding: typing.Final = MockAsyncioTask(result="ok", done=True) + tp: typing.Final = TopicPartition(topic="t1", partition=0) + + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=failing, # ty: ignore[invalid-argument-type] + offset=100, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=succeeding, # ty: ignore[invalid-argument-type] + offset=101, + consumer=mock_consumer, + topic_partition=tp, + ), ] + for t in tasks: + await committer._messages_queue.put(t) + [await 
committer._messages_queue.get() for _ in tasks] - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True) as mock_commit: - await committer._commit_tasks_batch(batch) + await committer._commit_partitions({tp: tasks}) - mock_commit.assert_called_once() - call_args: typing.Final = mock_commit.call_args[0][1] - # Max offset is 101; Kafka commits next-to-fetch offset = max + 1 - assert call_args[TopicPartition(topic="t1", partition=0)] == 102 + mock_consumer.commit.assert_called_once_with({tp: 102}) assert "Task has finished with an exception" in caplog.text -async def test_committer_handles_multiple_consumers(committer: KafkaBatchCommitter) -> None: +async def test_commit_partitions_handles_multiple_consumers(committer: KafkaBatchCommitter) -> None: """Each consumer's commit is called with only its own partitions — no cross-consumer commits.""" consumer_a: typing.Final = MockAIOKafkaConsumer() consumer_b: typing.Final = MockAIOKafkaConsumer() + tp_a: typing.Final = TopicPartition(topic="topic-a", partition=0) + tp_b: typing.Final = TopicPartition(topic="topic-b", partition=0) + task_a: typing.Final = KafkaCommitTask( asyncio_task=MockAsyncioTask(result="ok"), # ty: ignore[invalid-argument-type] offset=10, consumer=consumer_a, - topic_partition=TopicPartition(topic="topic-a", partition=0), + topic_partition=tp_a, ) task_b: typing.Final = KafkaCommitTask( asyncio_task=MockAsyncioTask(result="ok"), # ty: ignore[invalid-argument-type] offset=20, consumer=consumer_b, - topic_partition=TopicPartition(topic="topic-b", partition=0), + topic_partition=tp_b, ) - for t in (task_a, task_b): await committer._messages_queue.put(t) - batch: typing.Final = [await committer._messages_queue.get(), await committer._messages_queue.get()] - - await committer._commit_tasks_batch(batch) + [await committer._messages_queue.get() for _ in range(2)] - consumer_a.commit.assert_called_once_with({TopicPartition(topic="topic-a", partition=0): 11}) - 
consumer_b.commit.assert_called_once_with({TopicPartition(topic="topic-b", partition=0): 21}) + await committer._commit_partitions({tp_a: [task_a], tp_b: [task_b]}) + consumer_a.commit.assert_called_once_with({tp_a: 11}) + consumer_b.commit.assert_called_once_with({tp_b: 21}) -async def test_committer_close_but_unexpected_error() -> None: - committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.1, commit_batch_size=10) - mock_task: typing.Final = MagicMock() - mock_task.done.return_value = False - mock_task.cancelled.return_value = False - committer._commit_task = mock_task - - original_exception: typing.Final = RuntimeError("Original error") - with patch("asyncio.wait_for", side_effect=original_exception), pytest.raises(RuntimeError) as exc_info: - await committer.close() - - assert exc_info.value is original_exception - - -async def test_committer_cancelled_task_stops_offset_advance( +async def test_commit_partitions_returns_false_on_commit_failure( committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer ) -> None: - """Offset must not advance past a cancelled task — that message was never processed.""" - ok_task: typing.Final = MockAsyncioTask(result="ok") - cancelled_task: typing.Final = MockAsyncioTask(cancelled=True) - later_ok_task: typing.Final = MockAsyncioTask(result="ok") - - tp: typing.Final = TopicPartition(topic="t1", partition=0) - commit_task_ok: typing.Final = KafkaCommitTask( - asyncio_task=ok_task, # ty: ignore[invalid-argument-type] - offset=10, - consumer=mock_consumer, - topic_partition=tp, - ) - commit_task_cancelled: typing.Final = KafkaCommitTask( - asyncio_task=cancelled_task, # ty: ignore[invalid-argument-type] - offset=11, - consumer=mock_consumer, - topic_partition=tp, - ) - commit_task_later: typing.Final = KafkaCommitTask( - asyncio_task=later_ok_task, # ty: ignore[invalid-argument-type] - offset=12, + """_commit_partitions returns False when _call_committer fails.""" + task: typing.Final = 
MockAsyncioTask(result="ok") + tp: typing.Final = TopicPartition(topic="t", partition=0) + commit_task: typing.Final = KafkaCommitTask( + asyncio_task=task, # ty: ignore[invalid-argument-type] + offset=1, consumer=mock_consumer, topic_partition=tp, ) + await committer._messages_queue.put(commit_task) + await committer._messages_queue.get() - for t in (commit_task_ok, commit_task_cancelled, commit_task_later): - await committer._messages_queue.put(t) - batch: typing.Final = [await committer._messages_queue.get() for _ in range(3)] - - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True) as mock_commit: - await committer._commit_tasks_batch(batch) + with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=False): + result: typing.Final = await committer._commit_partitions({tp: [commit_task]}) - offsets: typing.Final = mock_commit.call_args[0][1] - # Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it - assert offsets[tp] == 11 # max safe offset (10) + 1 + assert result is False -async def test_committer_all_cancelled_tasks_skips_commit( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer +async def test_commit_partitions_returns_false_if_any_consumer_group_fails( + committer: KafkaBatchCommitter, ) -> None: - """If all tasks on a partition are cancelled, nothing is committed for that partition.""" - tp: typing.Final = TopicPartition(topic="t1", partition=0) - cancelled_task: typing.Final = MockAsyncioTask(cancelled=True) - commit_task: typing.Final = KafkaCommitTask( - asyncio_task=cancelled_task, # ty: ignore[invalid-argument-type] - offset=5, - consumer=mock_consumer, - topic_partition=tp, - ) + """If any consumer's commit slice fails, the overall return is False.""" + consumer_a: typing.Final = MockAIOKafkaConsumer() + consumer_b: typing.Final = MockAIOKafkaConsumer() + consumer_a.commit.side_effect = KafkaError("transient") # consumer_a fails - await 
committer._messages_queue.put(commit_task) - batch: typing.Final = [await committer._messages_queue.get()] + tp_a: typing.Final = TopicPartition(topic="t", partition=0) + tp_b: typing.Final = TopicPartition(topic="t", partition=1) + task_a: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(result="ok"), # ty: ignore[invalid-argument-type] + offset=1, + consumer=consumer_a, + topic_partition=tp_a, + ) + task_b: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(result="ok"), # ty: ignore[invalid-argument-type] + offset=2, + consumer=consumer_b, + topic_partition=tp_b, + ) + # Both go through the queue once so task_done() balance is maintained. + for t in (task_a, task_b): + await committer._messages_queue.put(t) + [await committer._messages_queue.get() for _ in range(2)] - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True) as mock_commit: - await committer._commit_tasks_batch(batch) + result: typing.Final = await committer._commit_partitions({tp_a: [task_a], tp_b: [task_b]}) - offsets: typing.Final = mock_commit.call_args[0][1] - assert tp not in offsets + assert result is False + consumer_b.commit.assert_called_once() # b still committed independently -async def test_committer_cancelled_task_not_logged_as_error( +async def test_commit_partitions_cancelled_task_not_logged_as_error( committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer, caplog: pytest.LogCaptureFixture ) -> None: - """CancelledError from a task is expected during shutdown and must not be logged as an error.""" + """A cancelled task is not an error — it must not produce an error log line.""" caplog.set_level(logging.ERROR) - cancelled_task: typing.Final = MockAsyncioTask(cancelled=True) + cancelled: typing.Final = MockAsyncioTask(cancelled=True) + tp: typing.Final = TopicPartition(topic="t", partition=0) commit_task: typing.Final = KafkaCommitTask( - asyncio_task=cancelled_task, # ty: ignore[invalid-argument-type] + 
asyncio_task=cancelled, # ty: ignore[invalid-argument-type] offset=5, consumer=mock_consumer, - topic_partition=TopicPartition(topic="t1", partition=0), + topic_partition=tp, ) - await committer._messages_queue.put(commit_task) - batch: typing.Final = [await committer._messages_queue.get()] + await committer._messages_queue.get() - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=True): - await committer._commit_tasks_batch(batch) + await committer._commit_partitions({tp: [commit_task]}) assert "Task has finished with an exception" not in caplog.text -async def test_committer_commit_tasks_batch_returns_false_on_commit_failure( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - """_commit_tasks_batch returns False when _call_committer fails.""" - task: typing.Final = MockAsyncioTask(result="ok") - commit_task: typing.Final = KafkaCommitTask( - asyncio_task=task, # ty: ignore[invalid-argument-type] - offset=1, - consumer=mock_consumer, - topic_partition=TopicPartition(topic="t", partition=0), - ) - await committer._messages_queue.put(commit_task) - batch: typing.Final = [await committer._messages_queue.get()] +# ---------- streaming loop end-to-end ---------- - with patch.object(committer, "_call_committer", new_callable=AsyncMock, return_value=False): - result: typing.Final = await committer._commit_tasks_batch(batch) - assert result is False +async def _drive_until(predicate: typing.Callable[[], bool], deadline_sec: float = 1.0, poll: float = 0.01) -> None: + """Yield to the event loop until ``predicate()`` returns True or ``deadline_sec`` elapses.""" + deadline: typing.Final = asyncio.get_event_loop().time() + deadline_sec + while asyncio.get_event_loop().time() < deadline: + if predicate(): + return + await asyncio.sleep(poll) + msg: typing.Final = "predicate did not become true in time" # pragma: no cover + raise AssertionError(msg) # pragma: no cover -async def 
test_committer_pipelines_fast_partition_ahead_of_slow_partition( - committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer -) -> None: - """A slow task on one partition must not block commits for already-done tasks on another.""" - fast_done_task: typing.Final = MockAsyncioTask(result="fast", done=True) - slow_event: typing.Final = asyncio.Event() +async def test_streaming_commits_when_batch_size_reached() -> None: + """When total pending crosses commit_batch_size, the loop commits the contiguous-done prefix.""" + consumer: typing.Final = MockAIOKafkaConsumer() + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=3) + committer.spawn() - async def slow_handler() -> str: - await slow_event.wait() - return "slow" + async def quick() -> None: + return None - slow_task: typing.Final = asyncio.create_task(slow_handler()) + tp: typing.Final = TopicPartition(topic="t", partition=0) + tasks: typing.Final = [asyncio.create_task(quick()) for _ in range(3)] + for index, task in enumerate(tasks): + await committer.send_task( + KafkaCommitTask( + asyncio_task=task, + offset=10 + index, + consumer=consumer, + topic_partition=tp, + ) + ) + + await _drive_until(lambda: consumer.commit.called) + consumer.commit.assert_called_once_with({tp: 13}) # offsets 10,11,12 → next-to-fetch 13 + await committer.close() - fast_partition: typing.Final = TopicPartition(topic="t", partition=0) - slow_partition: typing.Final = TopicPartition(topic="t", partition=1) - fast_commit_task: typing.Final = KafkaCommitTask( - asyncio_task=fast_done_task, # ty: ignore[invalid-argument-type] - offset=10, - consumer=mock_consumer, - topic_partition=fast_partition, +async def test_streaming_commits_on_timeout() -> None: + """When timeout fires before batch_size, the loop commits whatever's ready.""" + consumer: typing.Final = MockAIOKafkaConsumer() + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.05, commit_batch_size=100) + 
committer.spawn() + + async def quick() -> None: + return None + + tp: typing.Final = TopicPartition(topic="t", partition=0) + real_task: typing.Final = asyncio.create_task(quick()) + await committer.send_task( + KafkaCommitTask( + asyncio_task=real_task, + offset=42, + consumer=consumer, + topic_partition=tp, + ) ) - slow_commit_task: typing.Final = KafkaCommitTask( - asyncio_task=slow_task, - offset=20, - consumer=mock_consumer, - topic_partition=slow_partition, + + await _drive_until(lambda: consumer.commit.called, deadline_sec=2.0) + consumer.commit.assert_called_once_with({tp: 43}) + await committer.close() + + +async def test_streaming_commits_on_flush_event_without_stop() -> None: + """commit_all() must flush without shutting down the committer (for rebalance use).""" + consumer: typing.Final = MockAIOKafkaConsumer() + committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=100) + committer.spawn() + + async def quick() -> None: + return None + + tp: typing.Final = TopicPartition(topic="t", partition=0) + real_task: typing.Final = asyncio.create_task(quick()) + await committer.send_task( + KafkaCommitTask( + asyncio_task=real_task, + offset=1, + consumer=consumer, + topic_partition=tp, + ) ) - for t in (fast_commit_task, slow_commit_task): - await committer._messages_queue.put(t) - batch: typing.Final = [await committer._messages_queue.get() for _ in range(2)] + await committer.commit_all() - commit_calls: list[dict[TopicPartition, int]] = [] + consumer.commit.assert_called_once_with({tp: 2}) + assert committer.is_healthy # still running after flush - async def record_commit(_: list[KafkaCommitTask], offsets: dict[TopicPartition, int]) -> bool: - commit_calls.append(dict(offsets)) - return True + await committer.close() - with patch.object(committer, "_call_committer", side_effect=record_commit): - commit_coro: typing.Final = asyncio.create_task(committer._commit_tasks_batch(batch)) - # Give the committer a chance to 
# NOTE(review): reconstructed from a whitespace-collapsed patch hunk; this is the
# post-patch ("+") side of the streaming-committer test suite. The removed ("-")
# legacy batch-pipelining tests are intentionally dropped, as the patch deletes them.


async def test_streaming_commits_on_close_flush() -> None:
    """close() sets _stop_requested → flush triggers commit + shutdown."""
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=100)
    committer.spawn()

    async def quick() -> None:
        return None

    tp: typing.Final = TopicPartition(topic="t", partition=0)
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=asyncio.create_task(quick()),
            offset=99,
            consumer=consumer,
            topic_partition=tp,
        )
    )

    await committer.close()

    # Committed offset is next-to-fetch: processed 99 → commit 100.
    consumer.commit.assert_called_once_with({tp: 100})
    assert not committer.is_healthy


async def test_streaming_clears_flush_event_after_commit_all() -> None:
    """After commit_all(), the flush event must be cleared so subsequent triggers work."""
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=100)
    committer.spawn()

    async def quick() -> None:
        return None

    tp: typing.Final = TopicPartition(topic="t", partition=0)
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=asyncio.create_task(quick()),
            offset=1,
            consumer=consumer,
            topic_partition=tp,
        )
    )
    await committer.commit_all()
    assert not committer._flush_batch_event.is_set()
    await committer.close()


async def test_streaming_handles_cancelled_error_in_loop(committer: KafkaBatchCommitter) -> None:
    """Re-raise CancelledError from the loop without leaking pending tasks."""
    committer.spawn()
    await asyncio.sleep(0.02)  # let loop reach asyncio.wait

    assert committer._commit_task is not None
    committer._commit_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await committer._commit_task

    assert not committer.is_healthy


async def test_streaming_cancel_with_active_timeout() -> None:
    """Cancelling the loop while a timeout_task is ticking cleans up via finally."""
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=100)
    committer.spawn()

    async def gated() -> None:
        await asyncio.sleep(30)

    real_task: typing.Final = asyncio.create_task(gated())
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=real_task,
            offset=10,
            consumer=consumer,
            topic_partition=TopicPartition(topic="t", partition=0),
        )
    )
    await asyncio.sleep(0.02)  # let the loop absorb and start timeout_task

    assert committer._commit_task is not None
    committer._commit_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await committer._commit_task

    assert not committer.is_healthy
    real_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await real_task


async def test_streaming_skips_cancelled_head_in_wait_set() -> None:
    """A cancelled task at the head of pending is not added to the wait set (avoids busy-loop)."""
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=0.05, commit_batch_size=100)
    committer.spawn()

    # Fix: get_running_loop() instead of the deprecated get_event_loop() —
    # we are inside a running coroutine, where get_running_loop() is the supported API.
    cancelled_task: typing.Final = asyncio.get_running_loop().create_future()
    cancelled_task.cancel()

    tp: typing.Final = TopicPartition(topic="t", partition=0)
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=cancelled_task,  # ty: ignore[invalid-argument-type]
            offset=10,
            consumer=consumer,
            topic_partition=tp,
        )
    )
    # Timeout fires → loop's _pending_head_tasks sees the cancelled head and skips it; the
    # subsequent commit drops the cancelled task without advancing offsets.
    await committer.close()
    # No commit was issued for this partition — _map_offsets_per_partition produced an empty
    # map (cancelled task), and _call_committer returns early when the offsets dict is empty.
    consumer.commit.assert_not_called()


async def test_streaming_idle_loop_blocks_without_busy_spin(committer: KafkaBatchCommitter) -> None:
    """An idle committer with no pending must not busy-loop on already-done heads."""
    committer.spawn()
    await asyncio.sleep(0.01)

    # No task work, no flush, no shutdown — the loop should be blocked on queue.get / flush.
    assert committer.is_healthy


# ---------- pipelining (the streaming win) ----------


async def test_committer_absorbs_queue_during_slow_handler() -> None:
    """A slow task on one partition must not stall queue absorption for other partitions.

    This exercises the across-batch fix: with the old populate-then-commit loop the
    queue would grow while one batch waited on a slow handler.
    """
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=2)
    committer.spawn()

    slow_event: typing.Final = asyncio.Event()

    async def slow_handler() -> None:
        await slow_event.wait()

    async def fast_handler() -> None:
        return None

    slow_partition: typing.Final = TopicPartition(topic="t", partition=0)
    fast_partition: typing.Final = TopicPartition(topic="t", partition=1)

    # 1) Send the slow task first — it goes into pending and stays not-done.
    slow_task: typing.Final = asyncio.create_task(slow_handler())
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=slow_task,
            offset=1000,
            consumer=consumer,
            topic_partition=slow_partition,
        )
    )

    # 2) Send several fast tasks for a different partition. The streaming committer must
    #    absorb these into pending despite the slow task being in flight, then commit the
    #    fast partition's prefix once batch_size is reached.
    fast_tasks: typing.Final[list[asyncio.Task[None]]] = [asyncio.create_task(fast_handler()) for _ in range(2)]
    for index, t in enumerate(fast_tasks):
        await committer.send_task(
            KafkaCommitTask(
                asyncio_task=t,
                offset=10 + index,
                consumer=consumer,
                topic_partition=fast_partition,
            )
        )

    await _drive_until(lambda: consumer.commit.called, deadline_sec=1.0)
    # The fast partition committed without waiting on the slow one.
    consumer.commit.assert_called_with({fast_partition: 12})

    slow_event.set()
    await committer.close()
    # After close, the slow partition also commits.
    assert consumer.commit.call_args_list[-1].args[0] == {slow_partition: 1001}


async def test_committer_streaming_drains_on_close() -> None:
    """close() must commit every already-done task before exiting.

    In-flight tasks are dropped (asyncio.wait_for inside close() bounds the wait
    via _shutdown_timeout).
    """
    consumer: typing.Final = MockAIOKafkaConsumer()
    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=100)
    committer.spawn()

    async def quick() -> None:
        return None

    tp: typing.Final = TopicPartition(topic="t", partition=0)
    done_tasks: typing.Final = [asyncio.create_task(quick()) for _ in range(5)]
    for index, t in enumerate(done_tasks):
        await committer.send_task(
            KafkaCommitTask(
                asyncio_task=t,
                offset=100 + index,
                consumer=consumer,
                topic_partition=tp,
            )
        )

    await committer.close()
    # All five commits collapsed into one call (max offset 104 → next-to-fetch 105).
    consumer.commit.assert_called_once_with({tp: 105})


async def test_committer_streaming_handles_requeue_offset_order() -> None:
    """Lazy offset sort tolerates re-queued tasks landing after higher-offset arrivals.

    Transient KafkaError re-queues a batch; meanwhile new same-partition tasks arrive
    with higher offsets. The final commit must reflect the correct max offset.
    """
    consumer: typing.Final = MockAIOKafkaConsumer()
    # First commit attempt: transient KafkaError → re-queue. Second attempt: succeeds.
    consumer.commit.side_effect = [KafkaError("transient"), None, None]

    committer: typing.Final = KafkaBatchCommitter(commit_batch_timeout_sec=10.0, commit_batch_size=2)
    committer.spawn()

    async def quick() -> None:
        return None

    tp: typing.Final = TopicPartition(topic="t", partition=0)

    # Send two tasks that trigger the (failing) first commit.
    early_tasks: typing.Final[list[asyncio.Task[None]]] = [asyncio.create_task(quick()) for _ in range(2)]
    for index, t in enumerate(early_tasks):
        await committer.send_task(
            KafkaCommitTask(
                asyncio_task=t,
                offset=100 + index,
                consumer=consumer,
                topic_partition=tp,
            )
        )

    # Wait until the failing commit attempt has occurred and the batch was re-queued.
    await _drive_until(lambda: consumer.commit.call_count >= 1, deadline_sec=1.0)

    # Now send a task with a higher offset BEFORE the re-queued tasks land back in pending.
    late_task: typing.Final = asyncio.create_task(quick())
    await committer.send_task(
        KafkaCommitTask(
            asyncio_task=late_task,
            offset=200,  # much higher than the re-queued 100/101
            consumer=consumer,
            topic_partition=tp,
        )
    )

    await committer.close()

    # The final commit must reflect the max processed offset (200 → next-to-fetch 201)
    # despite the requeued 100/101 arriving after offset 200 in queue order.
    final_call: typing.Final = consumer.commit.call_args_list[-1]
    assert final_call.args[0] == {tp: 201}