From 1e47e1067bf6e14e77a1bb23af83abe157674566 Mon Sep 17 00:00:00 2001
From: Artur Shiriev
Date: Mon, 4 May 2026 08:49:07 +0300
Subject: [PATCH] scope cancellation watermarks per consumer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The watermark dict was keyed by TopicPartition only. When a single
KafkaConcurrentHandler is shared by subscribers in different consumer groups
on the same topic, a cancelled task in one group blocks commits for the other
group on the same partition.

Re-key _cancellation_watermarks by (id(consumer), TopicPartition) and track a
_partition_owner: dict[TopicPartition, int] inside the streaming loop so that
clear_cancellation_watermarks(partitions) can resolve which consumer's entry
to drop on rebalance; the listener API stays the same.

Add a focused regression test that fails under the old keying.

Co-Authored-By: Claude Opus 4.7
---
 .../batch_committer.py        | 50 +++++++---
 tests/test_kafka_committer.py | 95 ++++++++++++++-----
 2 files changed, 108 insertions(+), 37 deletions(-)
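
A reviewer-facing sketch, kept below the "---" scissors line so git am drops
it: the rebalance wiring this patch assumes. Only
clear_cancellation_watermarks() is real; the listener class, its name, and the
wiring are illustrative assumptions, not part of the commit:

    from aiokafka.abc import ConsumerRebalanceListener

    class WatermarkClearingListener(ConsumerRebalanceListener):
        """Hypothetical wiring: forget "do not advance" floors on revocation."""

        def __init__(self, committer) -> None:
            self._committer = committer  # the KafkaBatchCommitter instance

        async def on_partitions_revoked(self, revoked) -> None:
            # Partition-only call: the committer resolves the owning consumer
            # internally via the _partition_owner map this patch adds.
            self._committer.clear_cancellation_watermarks(revoked)

        async def on_partitions_assigned(self, assigned) -> None:
            pass  # a fresh assignment starts with no inherited floor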
diff --git a/faststream_concurrent_aiokafka/batch_committer.py b/faststream_concurrent_aiokafka/batch_committer.py
index 094d255..100c08f 100644
--- a/faststream_concurrent_aiokafka/batch_committer.py
+++ b/faststream_concurrent_aiokafka/batch_committer.py
@@ -84,11 +84,18 @@ def __init__(
         self._commit_batch_timeout_sec = commit_batch_timeout_sec
         self._commit_batch_size = commit_batch_size
         self._shutdown_timeout = shutdown_timeout_sec
-        # Per-partition floor for the smallest cancelled offset seen since the partition was
-        # last assigned. Once set, the committer will not advance Kafka's committed offset for
-        # that partition until clear_cancellation_watermarks() is called on rebalance — so the
-        # cancelled-and-after offsets get redelivered on restart (at-least-once).
-        self._cancellation_watermarks: dict[TopicPartition, int] = {}
+        # Per-(consumer, partition) floor for the smallest cancelled offset seen since the
+        # partition was last assigned to that consumer. Scoping by id(consumer) prevents a
+        # cancelled task on one consumer group from blocking commits on another group that
+        # happens to subscribe to the same (topic, partition). Once set, the committer will
+        # not advance Kafka's committed offset for that consumer/partition until
+        # clear_cancellation_watermarks() is called on rebalance — so the cancelled-and-after
+        # offsets get redelivered on restart (at-least-once).
+        self._cancellation_watermarks: dict[tuple[int, TopicPartition], int] = {}
+        # Most-recent consumer (by id()) that absorbed a task for each partition. Lets
+        # clear_cancellation_watermarks(partitions) resolve which consumer's watermark to
+        # drop on rebalance without changing the listener's API.
+        self._partition_owner: dict[TopicPartition, int] = {}
 
     def _on_user_task_done(self, _task: asyncio.Future[typing.Any]) -> None:
         """Done-callback target for user tasks; wakes the streaming loop."""
@@ -129,32 +136,35 @@ async def _call_committer(
 
     @staticmethod
     def _map_offsets_per_partition(
+        consumer_id: int,
         consumer_tasks: list[KafkaCommitTask],
-        watermarks: dict[TopicPartition, int],
+        watermarks: dict[tuple[int, TopicPartition], int],
     ) -> dict[TopicPartition, int]:
         # `watermarks` is mutated: any cancelled task seen here records (or lowers) the
-        # partition's watermark. Subsequent batches will see the watermark and skip
-        # advancing past it. Caller (the committer) is the watermark dict's owner.
+        # (consumer, partition) watermark. Subsequent batches for the same consumer will see
+        # it and skip advancing past it. Other consumers (different group, same partition)
+        # have their own keys and are unaffected. Caller (the committer) owns the dict.
         by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {}
         for task in consumer_tasks:
             by_partition.setdefault(task.topic_partition, []).append(task)
 
         partitions_to_offsets: dict[TopicPartition, int] = {}
         for partition, tasks in by_partition.items():
+            wm_key: tuple[int, TopicPartition] = (consumer_id, partition)
             max_offset: int | None = None
             for task in sorted(tasks, key=_OFFSET_KEY):
                 if task.asyncio_task.cancelled():
                     # Earliest cancelled wins: a later batch may not see the earlier
                     # cancellation, so without min() we could forget it and accidentally
                     # advance past the boundary.
-                    existing = watermarks.get(partition)
+                    existing = watermarks.get(wm_key)
                     if existing is None or task.offset < existing:
-                        watermarks[partition] = task.offset
+                        watermarks[wm_key] = task.offset
                     break
                 max_offset = task.offset
             if max_offset is None:
                 continue
-            wm = watermarks.get(partition)
+            wm = watermarks.get(wm_key)
             if wm is not None and (max_offset + 1) > wm:
                 # Advancing would jump past the cancelled boundary — skip this partition
                 # until the watermark is cleared on rebalance.
@@ -212,8 +222,11 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT
 
         results: typing.Final = await asyncio.gather(
             *(
-                self._call_committer(ct, self._map_offsets_per_partition(ct, self._cancellation_watermarks))
-                for ct in consumers_tasks.values()
+                self._call_committer(
+                    ct,
+                    self._map_offsets_per_partition(consumer_id, ct, self._cancellation_watermarks),
+                )
+                for consumer_id, ct in consumers_tasks.items()
             )
         )
 
@@ -261,6 +274,7 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None:
             new_ct = state.queue_get_task.result()
             self._track_user_task(new_ct)
             _insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct)
+            self._partition_owner[new_ct.topic_partition] = id(new_ct.consumer)
             state.pending_count += 1
             state.queue_get_task = asyncio.create_task(self._messages_queue.get())
             if state.timeout_deadline is None:
@@ -302,6 +316,7 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None:
                 break
             self._track_user_task(ct)
             _insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct)
+            self._partition_owner[ct.topic_partition] = id(ct.consumer)
             state.pending_count += 1
         if not state.queue_get_task.done():
             state.queue_get_task.cancel()
@@ -340,13 +355,18 @@ def clear_cancellation_watermarks(self, partitions: typing.Iterable[TopicPartiti
         """Forget cancellation watermarks for ``partitions`` (or all if ``None``).
 
         Called on partition revocation by the rebalance listener — the partition's
-        next assignment starts fresh, with no inherited "do not advance" floor.
+        next assignment starts fresh, with no inherited "do not advance" floor. The
+        consumer to clear is resolved via the per-partition owner tracked in the
+        streaming loop, so the listener's API stays partition-only.
         """
         if partitions is None:
            self._cancellation_watermarks.clear()
+            self._partition_owner.clear()
             return
         for partition in partitions:
-            self._cancellation_watermarks.pop(partition, None)
+            owner = self._partition_owner.pop(partition, None)
+            if owner is not None:
+                self._cancellation_watermarks.pop((owner, partition), None)
 
     async def send_task(self, new_task: KafkaCommitTask) -> None:
         self._check_is_commit_task_running()
""" if partitions is None: self._cancellation_watermarks.clear() + self._partition_owner.clear() return for partition in partitions: - self._cancellation_watermarks.pop(partition, None) + owner = self._partition_owner.pop(partition, None) + if owner is not None: + self._cancellation_watermarks.pop((owner, partition), None) async def send_task(self, new_task: KafkaCommitTask) -> None: self._check_is_commit_task_running() diff --git a/tests/test_kafka_committer.py b/tests/test_kafka_committer.py index 9890db0..5d6cad7 100644 --- a/tests/test_kafka_committer.py +++ b/tests/test_kafka_committer.py @@ -300,7 +300,7 @@ def test_committer_map_offsets_skips_cancelled_tasks(mock_consumer: MockAIOKafka ), ] - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {}) + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, {}) # Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it. assert offsets[tp] == 11 # max safe offset (10) + 1 @@ -316,7 +316,7 @@ def test_committer_map_offsets_skips_partition_when_all_cancelled(mock_consumer: topic_partition=tp, ) - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task], {}) + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [task], {}) assert tp not in offsets @@ -334,10 +334,10 @@ def test_map_offsets_records_cancellation_watermark(mock_consumer: MockAIOKafkaC topic_partition=tp, ) - watermarks: dict[TopicPartition, int] = {} - KafkaBatchCommitter._map_offsets_per_partition([task], watermarks) + watermarks: dict[tuple[int, TopicPartition], int] = {} + KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [task], watermarks) - assert watermarks == {tp: cancelled_offset} + assert watermarks == {(id(mock_consumer), tp): cancelled_offset} def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: MockAIOKafkaConsumer) -> None: @@ -350,11 +350,11 @@ def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: Mock topic_partition=tp, ) - watermarks: dict[TopicPartition, int] = {tp: 11} - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([new_task], watermarks) + watermarks: dict[tuple[int, TopicPartition], int] = {(id(mock_consumer), tp): 11} + offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [new_task], watermarks) assert tp not in offsets - assert watermarks == {tp: 11} # unchanged + assert watermarks == {(id(mock_consumer), tp): 11} # unchanged def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsumer) -> None: @@ -367,10 +367,10 @@ def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsume topic_partition=tp, ) - watermarks: dict[TopicPartition, int] = {tp: 11} - KafkaBatchCommitter._map_offsets_per_partition([later_cancelled], watermarks) + watermarks: dict[tuple[int, TopicPartition], int] = {(id(mock_consumer), tp): 11} + KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [later_cancelled], watermarks) - assert watermarks == {tp: 11} + assert watermarks == {(id(mock_consumer), tp): 11} def test_map_offsets_commits_max_before_cancellation_records_watermark( @@ -395,31 +395,82 @@ def test_map_offsets_commits_max_before_cancellation_records_watermark( ), ] - watermarks: dict[TopicPartition, int] = {} - offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, watermarks) + watermarks: dict[tuple[int, TopicPartition], int] = {} + 
+    offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, watermarks)
 
     assert offsets == {tp: successful_offset + 1}
-    assert watermarks == {tp: cancelled_offset}
+    assert watermarks == {(id(mock_consumer), tp): cancelled_offset}
 
 
-def test_clear_cancellation_watermarks_specific_partitions(committer: KafkaBatchCommitter) -> None:
+def test_map_offsets_watermark_isolated_per_consumer() -> None:
+    """A cancelled task on one consumer must not block commits on another consumer.
+
+    Regression: previously the watermark dict was keyed by partition only, so a single
+    handler shared across consumer groups subscribing to the same (topic, partition)
+    would have one group's cancellation block the other group's commit.
+    """
+    consumer_a: typing.Final = MockAIOKafkaConsumer(group_id="group-a")
+    consumer_b: typing.Final = MockAIOKafkaConsumer(group_id="group-b")
+    tp: typing.Final = TopicPartition(topic="shared", partition=0)
+
+    cancelled_on_a: typing.Final = KafkaCommitTask(
+        asyncio_task=MockAsyncioTask(cancelled=True),  # ty: ignore[invalid-argument-type]
+        offset=5,
+        consumer=consumer_a,
+        topic_partition=tp,
+    )
+    success_on_b: typing.Final = KafkaCommitTask(
+        asyncio_task=MockAsyncioTask(done=True),  # ty: ignore[invalid-argument-type]
+        offset=20,
+        consumer=consumer_b,
+        topic_partition=tp,
+    )
+
+    watermarks: dict[tuple[int, TopicPartition], int] = {}
+    KafkaBatchCommitter._map_offsets_per_partition(id(consumer_a), [cancelled_on_a], watermarks)
+    assert watermarks == {(id(consumer_a), tp): 5}
+
+    offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(
+        id(consumer_b), [success_on_b], watermarks
+    )
+    assert offsets == {tp: 21}, "Consumer B's commit must not be blocked by consumer A's watermark"
+    # Consumer A's watermark stays intact for consumer A.
+    assert watermarks == {(id(consumer_a), tp): 5}
+
+
+def test_clear_cancellation_watermarks_specific_partitions(
+    committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer
+) -> None:
     tp_a: typing.Final = TopicPartition(topic="t", partition=0)
     tp_b: typing.Final = TopicPartition(topic="t", partition=1)
-    committer._cancellation_watermarks[tp_a] = 5
-    committer._cancellation_watermarks[tp_b] = 7
+    consumer_id: typing.Final = id(mock_consumer)
+    # Pre-seed owner so clear can resolve which consumer's watermark to drop.
+    committer._partition_owner[tp_a] = consumer_id
+    committer._partition_owner[tp_b] = consumer_id
+    committer._cancellation_watermarks[(consumer_id, tp_a)] = 5
+    committer._cancellation_watermarks[(consumer_id, tp_b)] = 7
 
     committer.clear_cancellation_watermarks([tp_a])
 
-    assert committer._cancellation_watermarks == {tp_b: 7}
+    assert committer._cancellation_watermarks == {(consumer_id, tp_b): 7}
+    assert committer._partition_owner == {tp_b: consumer_id}
 
 
-def test_clear_cancellation_watermarks_all_when_none(committer: KafkaBatchCommitter) -> None:
-    committer._cancellation_watermarks[TopicPartition(topic="t", partition=0)] = 5
-    committer._cancellation_watermarks[TopicPartition(topic="t", partition=1)] = 7
+def test_clear_cancellation_watermarks_all_when_none(
+    committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer
+) -> None:
+    consumer_id: typing.Final = id(mock_consumer)
+    tp_a: typing.Final = TopicPartition(topic="t", partition=0)
+    tp_b: typing.Final = TopicPartition(topic="t", partition=1)
+    committer._partition_owner[tp_a] = consumer_id
+    committer._partition_owner[tp_b] = consumer_id
+    committer._cancellation_watermarks[(consumer_id, tp_a)] = 5
+    committer._cancellation_watermarks[(consumer_id, tp_b)] = 7
 
     committer.clear_cancellation_watermarks()
 
     assert committer._cancellation_watermarks == {}
+    assert committer._partition_owner == {}
 
 
 def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: MockAIOKafkaConsumer) -> None:
@@ -450,7 +501,7 @@ def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: Mock
         )
     )
 
-    offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {})
+    offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, {})
 
     assert offsets[TopicPartition(topic="t1", partition=0)] == first_offset + 10 + 1
     assert offsets[TopicPartition(topic="t1", partition=partition)] == second_offset + 1
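
-- 
Everything below this signature line is ignored by git am. A standalone
illustration of the commit-boundary rule that _map_offsets_per_partition
applies per (consumer, partition); the helper below is a hypothetical,
dependency-free sketch, not code from this repo:

    def safe_commit_offset(
        finished: list[tuple[int, bool]],  # (offset, was_cancelled), sorted by offset
        watermark: int | None,  # smallest cancelled offset seen so far, if any
    ) -> tuple[int | None, int | None]:
        """Return (offset safe to commit or None, updated watermark)."""
        max_offset: int | None = None
        for offset, was_cancelled in finished:
            if was_cancelled:
                # Earliest cancelled offset becomes (or lowers) the floor.
                watermark = offset if watermark is None else min(watermark, offset)
                break
            max_offset = offset
        if max_offset is None:
            return None, watermark
        if watermark is not None and max_offset + 1 > watermark:
            return None, watermark  # advancing would jump past the cancelled boundary
        return max_offset + 1, watermark

    # 10 done, 11 cancelled: commit 11 (covers 10 only) and record the floor at 11,
    # so 11 and everything after it are redelivered on restart (at-least-once).
    assert safe_commit_offset([(10, False), (11, True), (12, False)], None) == (11, 11)
    # A later batch that would advance past the floor is held back until rebalance.
    assert safe_commit_offset([(12, False)], 11) == (None, 11)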