From 0616ef6dc28ad376c053d40139e8972fe762f419 Mon Sep 17 00:00:00 2001 From: Artur Shiriev Date: Sun, 3 May 2026 23:04:16 +0300 Subject: [PATCH] add cancellation watermark to close at-least-once gap on cancelled tasks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Today _extract_ready_prefixes drops a cancelled task and everything after from per-partition pending state, and _map_offsets_per_partition stops the offset advance at the cancellation. That keeps cancelled-and-after offsets redeliverable on restart — but only because cancellation never actually occurs mid-stream today (it's gated by shutdown, after which no new tasks are absorbed). If a future change ever allowed mid-stream cancellation, a new task arriving for the same partition with a higher offset would slip past the boundary: the cancelled task is gone from pending, _map_offsets_per_partition has no memory of the cancellation, and the new task's offset would be committed, silently skipping the cancelled-and-after window. This change adds a per-partition cancellation watermark on the committer. When _map_offsets_per_partition sees a cancelled task at offset N for partition P, it records watermarks[P] = N (keeping the earliest if multiple batches see cancellations). On every subsequent batch the watermark blocks that partition from advancing — the partition's pending still drains for task_done() balance, but no commit is issued for it until the watermark is cleared. The rebalance listener clears the watermark for revoked partitions after commit_all() runs, so the next assignment starts fresh. Trace: tasks 9 (✓), 10 (✗), 11 (✓), 12 (✓) all in pending → first commit produces {tp: 10}, sets wm[tp] = 10. Second batch sees task 13 (✓), but 13+1 > 10, so {tp} is dropped. On restart, fetch from 10 — re-process 10, 11, 12, 13. At-least-once preserved. Co-Authored-By: Claude Opus 4.7 --- .../batch_committer.py | 50 +++++++- faststream_concurrent_aiokafka/rebalance.py | 11 +- tests/mocks.py | 1 + tests/test_kafka_committer.py | 108 +++++++++++++++++- tests/test_rebalance.py | 38 ++++++ 5 files changed, 198 insertions(+), 10 deletions(-) diff --git a/faststream_concurrent_aiokafka/batch_committer.py b/faststream_concurrent_aiokafka/batch_committer.py index 538ce24..8846c5a 100644 --- a/faststream_concurrent_aiokafka/batch_committer.py +++ b/faststream_concurrent_aiokafka/batch_committer.py @@ -84,6 +84,11 @@ def __init__( self._commit_batch_timeout_sec = commit_batch_timeout_sec self._commit_batch_size = commit_batch_size self._shutdown_timeout = shutdown_timeout_sec + # Per-partition floor for the smallest cancelled offset seen since the partition was + # last assigned. Once set, the committer will not advance Kafka's committed offset for + # that partition until clear_cancellation_watermarks() is called on rebalance — so the + # cancelled-and-after offsets get redelivered on restart (at-least-once). + self._cancellation_watermarks: dict[TopicPartition, int] = {} def _on_user_task_done(self, _task: asyncio.Future[typing.Any]) -> None: """Done-callback target for user tasks; wakes the streaming loop.""" @@ -121,7 +126,13 @@ async def _call_committer( return True @staticmethod - def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[TopicPartition, int]: + def _map_offsets_per_partition( + consumer_tasks: list[KafkaCommitTask], + watermarks: dict[TopicPartition, int], + ) -> dict[TopicPartition, int]: + # `watermarks` is mutated: any cancelled task seen here records (or lowers) the + # partition's watermark. Subsequent batches will see the watermark and skip + # advancing past it. Caller (the committer) is the watermark dict's owner. by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {} for task in consumer_tasks: by_partition.setdefault(task.topic_partition, []).append(task) @@ -131,11 +142,23 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To max_offset: int | None = None for task in sorted(tasks, key=_OFFSET_KEY): if task.asyncio_task.cancelled(): - break # stop committing at first cancelled task — message was not processed + # Earliest cancelled wins: a later batch may not see the earlier + # cancellation, so without min() we could forget it and accidentally + # advance past the boundary. + existing = watermarks.get(partition) + if existing is None or task.offset < existing: + watermarks[partition] = task.offset + break max_offset = task.offset - if max_offset is not None: - # Kafka commits the *next* offset to fetch, so committed = processed_max + 1 - partitions_to_offsets[partition] = max_offset + 1 + if max_offset is None: + continue + wm = watermarks.get(partition) + if wm is not None and (max_offset + 1) > wm: + # Advancing would jump past the cancelled boundary — skip this partition + # until the watermark is cleared on rebalance. + continue + # Kafka commits the *next* offset to fetch, so committed = processed_max + 1 + partitions_to_offsets[partition] = max_offset + 1 return partitions_to_offsets @staticmethod @@ -186,7 +209,10 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT consumers_tasks.setdefault(id(task.consumer), []).append(task) results: typing.Final = await asyncio.gather( - *(self._call_committer(ct, self._map_offsets_per_partition(ct)) for ct in consumers_tasks.values()) + *( + self._call_committer(ct, self._map_offsets_per_partition(ct, self._cancellation_watermarks)) + for ct in consumers_tasks.values() + ) ) for _ in flat: @@ -308,6 +334,18 @@ async def commit_all(self) -> None: self._flush_batch_event.set() await self._messages_queue.join() + def clear_cancellation_watermarks(self, partitions: typing.Iterable[TopicPartition] | None = None) -> None: + """Forget cancellation watermarks for ``partitions`` (or all if ``None``). + + Called on partition revocation by the rebalance listener — the partition's + next assignment starts fresh, with no inherited "do not advance" floor. + """ + if partitions is None: + self._cancellation_watermarks.clear() + return + for partition in partitions: + self._cancellation_watermarks.pop(partition, None) + async def send_task(self, new_task: KafkaCommitTask) -> None: self._check_is_commit_task_running() await self._messages_queue.put(new_task) diff --git a/faststream_concurrent_aiokafka/rebalance.py b/faststream_concurrent_aiokafka/rebalance.py index ff86792..b5d2290 100644 --- a/faststream_concurrent_aiokafka/rebalance.py +++ b/faststream_concurrent_aiokafka/rebalance.py @@ -1,8 +1,14 @@ +import typing + from aiokafka import ConsumerRebalanceListener as BaseConsumerRebalanceListener from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter +if typing.TYPE_CHECKING: + from faststream.kafka import TopicPartition + + class ConsumerRebalanceListener(BaseConsumerRebalanceListener): # type: ignore[misc] """Commits all pending offsets when Kafka revokes partitions during rebalance. @@ -32,5 +38,8 @@ def __init__(self, committer: KafkaBatchCommitter) -> None: async def on_partitions_assigned(self, _assigned: object) -> None: # ty: ignore[invalid-method-override] pass - async def on_partitions_revoked(self, _revoked: object) -> None: # ty: ignore[invalid-method-override] + async def on_partitions_revoked(self, revoked: object) -> None: await self._committer.commit_all() + # The revoked partitions' next assignment (possibly to another consumer) starts + # fresh, so the cancellation floor — if any was set — must not carry over. + self._committer.clear_cancellation_watermarks(typing.cast("typing.Iterable[TopicPartition]", revoked)) diff --git a/tests/mocks.py b/tests/mocks.py index 0251d93..b6e6ea0 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -57,6 +57,7 @@ def __init__(self, *_args: object, **_kwargs: object) -> None: self.close = AsyncMock() self.spawn = Mock() self.commit_all = AsyncMock() + self.clear_cancellation_watermarks = Mock() self.notify_task_completed = Mock() self._healthy = True diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index b0ebccf..9890db0 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -300,7 +300,7 @@ def test_committer_map_offsets_skips_cancelled_tasks(mock_consumer: MockAIOKafka ), ] - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks) + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {}) # Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it. assert offsets[tp] == 11 # max safe offset (10) + 1 @@ -316,10 +316,112 @@ def test_committer_map_offsets_skips_partition_when_all_cancelled(mock_consumer: topic_partition=tp, ) - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task]) + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task], {}) assert tp not in offsets +# ---------- cancellation watermark ---------- + + +def test_map_offsets_records_cancellation_watermark(mock_consumer: MockAIOKafkaConsumer) -> None: + """A cancelled task records its offset as the partition's watermark.""" + tp: typing.Final = TopicPartition(topic="t", partition=0) + cancelled_offset: typing.Final = 11 + task: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(cancelled=True), # ty: ignore[invalid-argument-type] + offset=cancelled_offset, + consumer=mock_consumer, + topic_partition=tp, + ) + + watermarks: dict[TopicPartition, int] = {} + KafkaBatchCommitter._map_offsets_per_partition([task], watermarks) + + assert watermarks == {tp: cancelled_offset} + + +def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: MockAIOKafkaConsumer) -> None: + """A successful task whose offset would advance past the watermark is dropped.""" + tp: typing.Final = TopicPartition(topic="t", partition=0) + new_task: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=20, + consumer=mock_consumer, + topic_partition=tp, + ) + + watermarks: dict[TopicPartition, int] = {tp: 11} + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([new_task], watermarks) + + assert tp not in offsets + assert watermarks == {tp: 11} # unchanged + + +def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsumer) -> None: + """When a partition sees a second cancellation at a higher offset, the earlier wins.""" + tp: typing.Final = TopicPartition(topic="t", partition=0) + later_cancelled: typing.Final = KafkaCommitTask( + asyncio_task=MockAsyncioTask(cancelled=True), # ty: ignore[invalid-argument-type] + offset=50, + consumer=mock_consumer, + topic_partition=tp, + ) + + watermarks: dict[TopicPartition, int] = {tp: 11} + KafkaBatchCommitter._map_offsets_per_partition([later_cancelled], watermarks) + + assert watermarks == {tp: 11} + + +def test_map_offsets_commits_max_before_cancellation_records_watermark( + mock_consumer: MockAIOKafkaConsumer, +) -> None: + """The pre-cancellation max is still committed in the same batch the watermark is recorded.""" + tp: typing.Final = TopicPartition(topic="t", partition=0) + successful_offset: typing.Final = 9 + cancelled_offset: typing.Final = 10 + tasks: typing.Final = [ + KafkaCommitTask( + asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type] + offset=successful_offset, + consumer=mock_consumer, + topic_partition=tp, + ), + KafkaCommitTask( + asyncio_task=MockAsyncioTask(cancelled=True), # ty: ignore[invalid-argument-type] + offset=cancelled_offset, + consumer=mock_consumer, + topic_partition=tp, + ), + ] + + watermarks: dict[TopicPartition, int] = {} + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, watermarks) + + assert offsets == {tp: successful_offset + 1} + assert watermarks == {tp: cancelled_offset} + + +def test_clear_cancellation_watermarks_specific_partitions(committer: KafkaBatchCommitter) -> None: + tp_a: typing.Final = TopicPartition(topic="t", partition=0) + tp_b: typing.Final = TopicPartition(topic="t", partition=1) + committer._cancellation_watermarks[tp_a] = 5 + committer._cancellation_watermarks[tp_b] = 7 + + committer.clear_cancellation_watermarks([tp_a]) + + assert committer._cancellation_watermarks == {tp_b: 7} + + +def test_clear_cancellation_watermarks_all_when_none(committer: KafkaBatchCommitter) -> None: + committer._cancellation_watermarks[TopicPartition(topic="t", partition=0)] = 5 + committer._cancellation_watermarks[TopicPartition(topic="t", partition=1)] = 7 + + committer.clear_cancellation_watermarks() + + assert committer._cancellation_watermarks == {} + + def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: MockAIOKafkaConsumer) -> None: """Offset advances to (max processed + 1) per partition.""" first_offset: typing.Final = 100 @@ -348,7 +450,7 @@ def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: Mock ) ) - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks) + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {}) assert offsets[TopicPartition(topic="t1", partition=0)] == first_offset + 10 + 1 assert offsets[TopicPartition(topic="t1", partition=partition)] == second_offset + 1 diff --git a/tests/test_rebalance.py b/tests/test_rebalance.py index 81f1dd6..f7db3fa 100644 --- a/tests/test_rebalance.py +++ b/tests/test_rebalance.py @@ -2,6 +2,7 @@ from unittest.mock import AsyncMock import pytest +from faststream.kafka import TopicPartition from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener from tests.mocks import MockKafkaBatchCommitter @@ -43,3 +44,40 @@ async def track_commit() -> None: await listener.on_partitions_revoked(set()) assert flush_done, "commit_all was not awaited before returning" + + +async def test_rebalance_on_partitions_revoked_clears_watermarks( + listener: ConsumerRebalanceListener, committer: MockKafkaBatchCommitter +) -> None: + """On revoke, the cancelled-offset watermarks for the revoked partitions must be cleared. + + The next assignment of those partitions starts fresh. + """ + revoked: typing.Final = {TopicPartition(topic="t", partition=0), TopicPartition(topic="t", partition=1)} + + await listener.on_partitions_revoked(revoked) + + committer.clear_cancellation_watermarks.assert_called_once_with(revoked) + + +async def test_rebalance_clear_runs_after_commit_all(committer: MockKafkaBatchCommitter) -> None: + """clear_cancellation_watermarks must run after commit_all. + + Committing relies on the watermark to know which partitions to skip, so clearing first + would let an outgoing consumer commit past a cancelled boundary. + """ + order: typing.Final[list[str]] = [] + + async def track_commit_all() -> None: + order.append("commit_all") + + def track_clear(_partitions: object) -> None: + order.append("clear") + + committer.commit_all = AsyncMock(side_effect=track_commit_all) + committer.clear_cancellation_watermarks = track_clear # ty: ignore[invalid-assignment] + listener: typing.Final = ConsumerRebalanceListener(committer) # ty: ignore[invalid-argument-type] + + await listener.on_partitions_revoked(set()) + + assert order == ["commit_all", "clear"]