50 changes: 35 additions & 15 deletions faststream_concurrent_aiokafka/batch_committer.py
@@ -84,11 +84,18 @@ def __init__(
self._commit_batch_timeout_sec = commit_batch_timeout_sec
self._commit_batch_size = commit_batch_size
self._shutdown_timeout = shutdown_timeout_sec
# Per-partition floor for the smallest cancelled offset seen since the partition was
# last assigned. Once set, the committer will not advance Kafka's committed offset for
# that partition until clear_cancellation_watermarks() is called on rebalance — so the
# cancelled-and-after offsets get redelivered on restart (at-least-once).
self._cancellation_watermarks: dict[TopicPartition, int] = {}
# Per-(consumer, partition) floor for the smallest cancelled offset seen since the
# partition was last assigned to that consumer. Scoping by id(consumer) prevents a
# cancelled task on one consumer group from blocking commits on another group that
# happens to subscribe to the same (topic, partition). Once set, the committer will
# not advance Kafka's committed offset for that consumer/partition until
# clear_cancellation_watermarks() is called on rebalance — so the cancelled-and-after
# offsets get redelivered on restart (at-least-once).
self._cancellation_watermarks: dict[tuple[int, TopicPartition], int] = {}
# Most-recent consumer (by id()) that absorbed a task for each partition. Lets
# clear_cancellation_watermarks(partitions) resolve which consumer's watermark to
# drop on rebalance without changing the listener's API.
self._partition_owner: dict[TopicPartition, int] = {}

def _on_user_task_done(self, _task: asyncio.Future[typing.Any]) -> None:
"""Done-callback target for user tasks; wakes the streaming loop."""
@@ -129,32 +136,35 @@ async def _call_committer(

@staticmethod
def _map_offsets_per_partition(
consumer_id: int,
consumer_tasks: list[KafkaCommitTask],
watermarks: dict[TopicPartition, int],
watermarks: dict[tuple[int, TopicPartition], int],
) -> dict[TopicPartition, int]:
# `watermarks` is mutated: any cancelled task seen here records (or lowers) the
# partition's watermark. Subsequent batches will see the watermark and skip
# advancing past it. Caller (the committer) is the watermark dict's owner.
# (consumer, partition) watermark. Subsequent batches for the same consumer will see
# it and skip advancing past it. Other consumers (different group, same partition)
# have their own keys and are unaffected. Caller (the committer) owns the dict.
by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {}
for task in consumer_tasks:
by_partition.setdefault(task.topic_partition, []).append(task)

partitions_to_offsets: dict[TopicPartition, int] = {}
for partition, tasks in by_partition.items():
wm_key: tuple[int, TopicPartition] = (consumer_id, partition)
max_offset: int | None = None
for task in sorted(tasks, key=_OFFSET_KEY):
if task.asyncio_task.cancelled():
# Earliest cancelled wins: a later batch may not see the earlier
# cancellation, so without min() we could forget it and accidentally
# advance past the boundary.
existing = watermarks.get(partition)
existing = watermarks.get(wm_key)
if existing is None or task.offset < existing:
watermarks[partition] = task.offset
watermarks[wm_key] = task.offset
break
max_offset = task.offset
if max_offset is None:
continue
wm = watermarks.get(partition)
wm = watermarks.get(wm_key)
if wm is not None and (max_offset + 1) > wm:
# Advancing would jump past the cancelled boundary — skip this partition
# until the watermark is cleared on rebalance.
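
Worked through concretely, the rule this function applies per partition behaves as follows; `map_offsets` is a hypothetical standalone restatement for illustration, not the PR's code:

```python
# Hypothetical, simplified restatement of the per-partition rule, for illustration only.
def map_offsets(tasks: list[tuple[int, bool]], watermark: int | None) -> tuple[int | None, int | None]:
    """tasks are (offset, cancelled) pairs; returns (commit_offset, watermark)."""
    max_offset: int | None = None
    for offset, cancelled in sorted(tasks):
        if cancelled:
            # Earliest cancelled wins: only lower the floor, never raise it.
            watermark = offset if watermark is None else min(watermark, offset)
            break
        max_offset = offset
    if max_offset is None:
        return None, watermark  # nothing safe to commit
    if watermark is not None and max_offset + 1 > watermark:
        return None, watermark  # advancing would jump past the cancelled boundary
    return max_offset + 1, watermark


# Offsets 10 (done), 11 (cancelled), 12 (done): commit 11, so 11+ get redelivered.
assert map_offsets([(10, False), (11, True), (12, False)], None) == (11, 11)
# A later batch of only successes stays blocked until the watermark is cleared.
assert map_offsets([(12, False)], 11) == (None, 11)
```
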
@@ -212,8 +222,11 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT

results: typing.Final = await asyncio.gather(
*(
self._call_committer(ct, self._map_offsets_per_partition(ct, self._cancellation_watermarks))
for ct in consumers_tasks.values()
self._call_committer(
ct,
self._map_offsets_per_partition(consumer_id, ct, self._cancellation_watermarks),
)
for consumer_id, ct in consumers_tasks.items()
)
)
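
The grouping that feeds this `gather` is keyed the same way; a shape sketch (assumed, since the bucketing code sits outside this diff's hunks):

```python
# Shape sketch only; the actual bucketing code is outside this diff's hunks.
# Tasks are grouped per consumer identity, then each bucket is committed
# concurrently, with watermark lookups under that consumer's own keys.
consumers_tasks: dict[int, list[KafkaCommitTask]] = {}
for task in drained_batch:  # drained_batch: hypothetical name for the ready tasks
    consumers_tasks.setdefault(id(task.consumer), []).append(task)
```
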

@@ -261,6 +274,7 @@ async def _streaming_iteration(self, state: "_StreamingState") -> None:
new_ct = state.queue_get_task.result()
self._track_user_task(new_ct)
_insert_sorted(state.pending.setdefault(new_ct.topic_partition, []), new_ct)
self._partition_owner[new_ct.topic_partition] = id(new_ct.consumer)
state.pending_count += 1
state.queue_get_task = asyncio.create_task(self._messages_queue.get())
if state.timeout_deadline is None:
@@ -302,6 +316,7 @@ def _handle_flush_fired(self, state: "_StreamingState") -> None:
break
self._track_user_task(ct)
_insert_sorted(state.pending.setdefault(ct.topic_partition, []), ct)
self._partition_owner[ct.topic_partition] = id(ct.consumer)
state.pending_count += 1
if not state.queue_get_task.done():
state.queue_get_task.cancel()
@@ -340,13 +355,18 @@ def clear_cancellation_watermarks(self, partitions: typing.Iterable[TopicPartiti
"""Forget cancellation watermarks for ``partitions`` (or all if ``None``).

Called on partition revocation by the rebalance listener — the partition's
next assignment starts fresh, with no inherited "do not advance" floor.
next assignment starts fresh, with no inherited "do not advance" floor. The
consumer to clear is resolved via the per-partition owner tracked in the
streaming loop, so the listener's API stays partition-only.
"""
if partitions is None:
self._cancellation_watermarks.clear()
self._partition_owner.clear()
return
for partition in partitions:
self._cancellation_watermarks.pop(partition, None)
owner = self._partition_owner.pop(partition, None)
if owner is not None:
self._cancellation_watermarks.pop((owner, partition), None)

async def send_task(self, new_task: KafkaCommitTask) -> None:
self._check_is_commit_task_running()
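
For completeness, a sketch of how a rebalance listener could drive the clearing; the listener class, its wiring, and the import path are assumptions, and only `clear_cancellation_watermarks()` itself comes from this diff:

```python
# Hypothetical wiring sketch; only clear_cancellation_watermarks() is from this PR,
# and the import path is assumed from the repo layout.
import typing

from aiokafka import TopicPartition
from aiokafka.abc import ConsumerRebalanceListener

from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter


class CommitterRebalanceListener(ConsumerRebalanceListener):
    def __init__(self, committer: KafkaBatchCommitter) -> None:
        self._committer = committer

    def on_partitions_revoked(self, revoked: typing.Iterable[TopicPartition]) -> None:
        # A revoked partition starts fresh on its next assignment, so its
        # "do not advance" floor (and owner entry) is dropped here.
        self._committer.clear_cancellation_watermarks(revoked)

    def on_partitions_assigned(self, assigned: typing.Iterable[TopicPartition]) -> None:
        pass
```
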
95 changes: 73 additions & 22 deletions tests/test_kafka_committer.py
@@ -300,7 +300,7 @@ def test_committer_map_offsets_skips_cancelled_tasks(mock_consumer: MockAIOKafka
),
]

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {})
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, {})
# Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it.
assert offsets[tp] == 11 # max safe offset (10) + 1

@@ -316,7 +316,7 @@ def test_committer_map_offsets_skips_partition_when_all_cancelled(mock_consumer:
topic_partition=tp,
)

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task], {})
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [task], {})
assert tp not in offsets


@@ -334,10 +334,10 @@ def test_map_offsets_records_cancellation_watermark(mock_consumer: MockAIOKafkaC
topic_partition=tp,
)

watermarks: dict[TopicPartition, int] = {}
KafkaBatchCommitter._map_offsets_per_partition([task], watermarks)
watermarks: dict[tuple[int, TopicPartition], int] = {}
KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [task], watermarks)

assert watermarks == {tp: cancelled_offset}
assert watermarks == {(id(mock_consumer), tp): cancelled_offset}


def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: MockAIOKafkaConsumer) -> None:
@@ -350,11 +350,11 @@ def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: Mock
topic_partition=tp,
)

watermarks: dict[TopicPartition, int] = {tp: 11}
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([new_task], watermarks)
watermarks: dict[tuple[int, TopicPartition], int] = {(id(mock_consumer), tp): 11}
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [new_task], watermarks)

assert tp not in offsets
assert watermarks == {tp: 11} # unchanged
assert watermarks == {(id(mock_consumer), tp): 11} # unchanged


def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsumer) -> None:
@@ -367,10 +367,10 @@ def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsume
topic_partition=tp,
)

watermarks: dict[TopicPartition, int] = {tp: 11}
KafkaBatchCommitter._map_offsets_per_partition([later_cancelled], watermarks)
watermarks: dict[tuple[int, TopicPartition], int] = {(id(mock_consumer), tp): 11}
KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), [later_cancelled], watermarks)

assert watermarks == {tp: 11}
assert watermarks == {(id(mock_consumer), tp): 11}


def test_map_offsets_commits_max_before_cancellation_records_watermark(
@@ -395,31 +395,82 @@ def test_map_offsets_commits_max_before_cancellation_records_watermark(
),
]

watermarks: dict[TopicPartition, int] = {}
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, watermarks)
watermarks: dict[tuple[int, TopicPartition], int] = {}
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, watermarks)

assert offsets == {tp: successful_offset + 1}
assert watermarks == {tp: cancelled_offset}
assert watermarks == {(id(mock_consumer), tp): cancelled_offset}


def test_clear_cancellation_watermarks_specific_partitions(committer: KafkaBatchCommitter) -> None:
def test_map_offsets_watermark_isolated_per_consumer() -> None:
"""A cancelled task on one consumer must not block commits on another consumer.

Regression: previously the watermark dict was keyed by partition only, so a single
handler shared across consumer groups subscribing to the same (topic, partition)
would have one group's cancellation block the other group's commit.
"""
consumer_a: typing.Final = MockAIOKafkaConsumer(group_id="group-a")
consumer_b: typing.Final = MockAIOKafkaConsumer(group_id="group-b")
tp: typing.Final = TopicPartition(topic="shared", partition=0)

cancelled_on_a: typing.Final = KafkaCommitTask(
asyncio_task=MockAsyncioTask(cancelled=True), # ty: ignore[invalid-argument-type]
offset=5,
consumer=consumer_a,
topic_partition=tp,
)
success_on_b: typing.Final = KafkaCommitTask(
asyncio_task=MockAsyncioTask(done=True), # ty: ignore[invalid-argument-type]
offset=20,
consumer=consumer_b,
topic_partition=tp,
)

watermarks: dict[tuple[int, TopicPartition], int] = {}
KafkaBatchCommitter._map_offsets_per_partition(id(consumer_a), [cancelled_on_a], watermarks)
assert watermarks == {(id(consumer_a), tp): 5}

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(
id(consumer_b), [success_on_b], watermarks
)
assert offsets == {tp: 21}, "Consumer B's commit must not be blocked by consumer A's watermark"
# Consumer A's watermark stays intact for consumer A.
assert watermarks == {(id(consumer_a), tp): 5}


def test_clear_cancellation_watermarks_specific_partitions(
committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer
) -> None:
tp_a: typing.Final = TopicPartition(topic="t", partition=0)
tp_b: typing.Final = TopicPartition(topic="t", partition=1)
committer._cancellation_watermarks[tp_a] = 5
committer._cancellation_watermarks[tp_b] = 7
consumer_id: typing.Final = id(mock_consumer)
# Pre-seed owner so clear can resolve which consumer's watermark to drop.
committer._partition_owner[tp_a] = consumer_id
committer._partition_owner[tp_b] = consumer_id
committer._cancellation_watermarks[(consumer_id, tp_a)] = 5
committer._cancellation_watermarks[(consumer_id, tp_b)] = 7

committer.clear_cancellation_watermarks([tp_a])

assert committer._cancellation_watermarks == {tp_b: 7}
assert committer._cancellation_watermarks == {(consumer_id, tp_b): 7}
assert committer._partition_owner == {tp_b: consumer_id}


def test_clear_cancellation_watermarks_all_when_none(committer: KafkaBatchCommitter) -> None:
committer._cancellation_watermarks[TopicPartition(topic="t", partition=0)] = 5
committer._cancellation_watermarks[TopicPartition(topic="t", partition=1)] = 7
def test_clear_cancellation_watermarks_all_when_none(
committer: KafkaBatchCommitter, mock_consumer: MockAIOKafkaConsumer
) -> None:
consumer_id: typing.Final = id(mock_consumer)
tp_a: typing.Final = TopicPartition(topic="t", partition=0)
tp_b: typing.Final = TopicPartition(topic="t", partition=1)
committer._partition_owner[tp_a] = consumer_id
committer._partition_owner[tp_b] = consumer_id
committer._cancellation_watermarks[(consumer_id, tp_a)] = 5
committer._cancellation_watermarks[(consumer_id, tp_b)] = 7

committer.clear_cancellation_watermarks()

assert committer._cancellation_watermarks == {}
assert committer._partition_owner == {}


def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: MockAIOKafkaConsumer) -> None:
Expand Down Expand Up @@ -450,7 +501,7 @@ def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: Mock
)
)

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {})
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(id(mock_consumer), tasks, {})
assert offsets[TopicPartition(topic="t1", partition=0)] == first_offset + 10 + 1
assert offsets[TopicPartition(topic="t1", partition=partition)] == second_offset + 1
