Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions faststream_concurrent_aiokafka/batch_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def __init__(
self._commit_batch_timeout_sec = commit_batch_timeout_sec
self._commit_batch_size = commit_batch_size
self._shutdown_timeout = shutdown_timeout_sec
# Per-partition floor for the smallest cancelled offset seen since the partition was
# last assigned. Once set, the committer will not advance Kafka's committed offset for
# that partition until clear_cancellation_watermarks() is called on rebalance — so the
# cancelled-and-after offsets get redelivered on restart (at-least-once).
self._cancellation_watermarks: dict[TopicPartition, int] = {}

def _on_user_task_done(self, _task: asyncio.Future[typing.Any]) -> None:
"""Done-callback target for user tasks; wakes the streaming loop."""
Expand Down Expand Up @@ -121,7 +126,13 @@ async def _call_committer(
return True

@staticmethod
def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[TopicPartition, int]:
def _map_offsets_per_partition(
consumer_tasks: list[KafkaCommitTask],
watermarks: dict[TopicPartition, int],
) -> dict[TopicPartition, int]:
# `watermarks` is mutated: any cancelled task seen here records (or lowers) the
# partition's watermark. Subsequent batches will see the watermark and skip
# advancing past it. Caller (the committer) is the watermark dict's owner.
by_partition: dict[TopicPartition, list[KafkaCommitTask]] = {}
for task in consumer_tasks:
by_partition.setdefault(task.topic_partition, []).append(task)
Expand All @@ -131,11 +142,23 @@ def _map_offsets_per_partition(consumer_tasks: list[KafkaCommitTask]) -> dict[To
max_offset: int | None = None
for task in sorted(tasks, key=_OFFSET_KEY):
if task.asyncio_task.cancelled():
break # stop committing at first cancelled task — message was not processed
# Earliest cancelled wins: a later batch may not see the earlier
# cancellation, so without min() we could forget it and accidentally
# advance past the boundary.
existing = watermarks.get(partition)
if existing is None or task.offset < existing:
watermarks[partition] = task.offset
break
max_offset = task.offset
if max_offset is not None:
# Kafka commits the *next* offset to fetch, so committed = processed_max + 1
partitions_to_offsets[partition] = max_offset + 1
if max_offset is None:
continue
wm = watermarks.get(partition)
if wm is not None and (max_offset + 1) > wm:
# Advancing would jump past the cancelled boundary — skip this partition
# until the watermark is cleared on rebalance.
continue
# Kafka commits the *next* offset to fetch, so committed = processed_max + 1
partitions_to_offsets[partition] = max_offset + 1
return partitions_to_offsets

@staticmethod
Expand Down Expand Up @@ -186,7 +209,10 @@ async def _commit_partitions(self, ready: dict[TopicPartition, list[KafkaCommitT
consumers_tasks.setdefault(id(task.consumer), []).append(task)

results: typing.Final = await asyncio.gather(
*(self._call_committer(ct, self._map_offsets_per_partition(ct)) for ct in consumers_tasks.values())
*(
self._call_committer(ct, self._map_offsets_per_partition(ct, self._cancellation_watermarks))
for ct in consumers_tasks.values()
)
)

for _ in flat:
Expand Down Expand Up @@ -308,6 +334,18 @@ async def commit_all(self) -> None:
self._flush_batch_event.set()
await self._messages_queue.join()

def clear_cancellation_watermarks(self, partitions: typing.Iterable[TopicPartition] | None = None) -> None:
    """Drop the recorded "do not advance" floors for ``partitions``.

    Passing ``None`` discards every watermark. The rebalance listener invokes
    this on revocation so that a partition's next assignment starts with a
    clean slate instead of inheriting a stale cancellation floor.
    """
    if partitions is not None:
        for revoked_partition in partitions:
            # Missing keys are fine — not every revoked partition had a cancellation.
            self._cancellation_watermarks.pop(revoked_partition, None)
        return
    self._cancellation_watermarks.clear()

async def send_task(self, new_task: KafkaCommitTask) -> None:
self._check_is_commit_task_running()
await self._messages_queue.put(new_task)
Expand Down
11 changes: 10 additions & 1 deletion faststream_concurrent_aiokafka/rebalance.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import typing

from aiokafka import ConsumerRebalanceListener as BaseConsumerRebalanceListener

from faststream_concurrent_aiokafka.batch_committer import KafkaBatchCommitter


if typing.TYPE_CHECKING:
from faststream.kafka import TopicPartition


class ConsumerRebalanceListener(BaseConsumerRebalanceListener): # type: ignore[misc]
"""Commits all pending offsets when Kafka revokes partitions during rebalance.

Expand Down Expand Up @@ -32,5 +38,8 @@ def __init__(self, committer: KafkaBatchCommitter) -> None:
async def on_partitions_assigned(self, _assigned: object) -> None: # ty: ignore[invalid-method-override]
pass

async def on_partitions_revoked(self, _revoked: object) -> None: # ty: ignore[invalid-method-override]
async def on_partitions_revoked(self, revoked: object) -> None:
    # Flush pending offsets first: the commit path consults the watermarks to
    # know which partitions to skip, so clearing must happen only afterwards.
    await self._committer.commit_all()
    # Whoever gets these partitions next (possibly another consumer) starts
    # fresh — any cancellation floor must not carry over.
    revoked_partitions = typing.cast("typing.Iterable[TopicPartition]", revoked)
    self._committer.clear_cancellation_watermarks(revoked_partitions)
1 change: 1 addition & 0 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, *_args: object, **_kwargs: object) -> None:
self.close = AsyncMock()
self.spawn = Mock()
self.commit_all = AsyncMock()
self.clear_cancellation_watermarks = Mock()
self.notify_task_completed = Mock()
self._healthy = True

Expand Down
108 changes: 105 additions & 3 deletions tests/test_kafka_committer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_committer_map_offsets_skips_cancelled_tasks(mock_consumer: MockAIOKafka
),
]

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks)
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {})
# Only offset 10 is safe to commit; 11 was cancelled, 12 is beyond it.
assert offsets[tp] == 11 # max safe offset (10) + 1

Expand All @@ -316,10 +316,112 @@ def test_committer_map_offsets_skips_partition_when_all_cancelled(mock_consumer:
topic_partition=tp,
)

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task])
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition([task], {})
assert tp not in offsets


# ---------- cancellation watermark ----------


def test_map_offsets_records_cancellation_watermark(mock_consumer: MockAIOKafkaConsumer) -> None:
    """A cancelled task's offset is written into the partition's watermark dict."""
    partition = TopicPartition(topic="t", partition=0)
    cancelled_task = KafkaCommitTask(
        asyncio_task=MockAsyncioTask(cancelled=True),  # ty: ignore[invalid-argument-type]
        offset=11,
        consumer=mock_consumer,
        topic_partition=partition,
    )

    recorded: dict[TopicPartition, int] = {}
    KafkaBatchCommitter._map_offsets_per_partition([cancelled_task], recorded)

    assert recorded == {partition: 11}


def test_map_offsets_blocks_partition_when_watermark_present(mock_consumer: MockAIOKafkaConsumer) -> None:
    """With a watermark in place, a done task beyond it must not yield a commit."""
    partition = TopicPartition(topic="t", partition=0)
    done_task = KafkaCommitTask(
        asyncio_task=MockAsyncioTask(done=True),  # ty: ignore[invalid-argument-type]
        offset=20,
        consumer=mock_consumer,
        topic_partition=partition,
    )

    floors: dict[TopicPartition, int] = {partition: 11}
    committable = KafkaBatchCommitter._map_offsets_per_partition([done_task], floors)

    assert partition not in committable
    # The watermark itself must survive untouched.
    assert floors == {partition: 11}


def test_map_offsets_keeps_earliest_watermark(mock_consumer: MockAIOKafkaConsumer) -> None:
    """A second cancellation at a higher offset must not raise the existing floor."""
    partition = TopicPartition(topic="t", partition=0)
    higher_cancelled = KafkaCommitTask(
        asyncio_task=MockAsyncioTask(cancelled=True),  # ty: ignore[invalid-argument-type]
        offset=50,
        consumer=mock_consumer,
        topic_partition=partition,
    )

    floors: dict[TopicPartition, int] = {partition: 11}
    KafkaBatchCommitter._map_offsets_per_partition([higher_cancelled], floors)

    # Earliest cancellation wins.
    assert floors == {partition: 11}


def test_map_offsets_commits_max_before_cancellation_records_watermark(
    mock_consumer: MockAIOKafkaConsumer,
) -> None:
    """Recording a watermark still commits the pre-cancellation max in the same batch."""
    partition = TopicPartition(topic="t", partition=0)
    batch = [
        KafkaCommitTask(
            asyncio_task=MockAsyncioTask(done=True),  # ty: ignore[invalid-argument-type]
            offset=9,
            consumer=mock_consumer,
            topic_partition=partition,
        ),
        KafkaCommitTask(
            asyncio_task=MockAsyncioTask(cancelled=True),  # ty: ignore[invalid-argument-type]
            offset=10,
            consumer=mock_consumer,
            topic_partition=partition,
        ),
    ]

    floors: dict[TopicPartition, int] = {}
    committable = KafkaBatchCommitter._map_offsets_per_partition(batch, floors)

    # Kafka commits the next offset to fetch: processed max (9) + 1.
    assert committable == {partition: 10}
    # ...and the cancelled offset becomes the partition's floor.
    assert floors == {partition: 10}


def test_clear_cancellation_watermarks_specific_partitions(committer: KafkaBatchCommitter) -> None:
    """Clearing a subset removes only those partitions' floors."""
    first = TopicPartition(topic="t", partition=0)
    second = TopicPartition(topic="t", partition=1)
    committer._cancellation_watermarks.update({first: 5, second: 7})

    committer.clear_cancellation_watermarks([first])

    assert committer._cancellation_watermarks == {second: 7}


def test_clear_cancellation_watermarks_all_when_none(committer: KafkaBatchCommitter) -> None:
    """Calling without arguments wipes every recorded floor."""
    for index, floor in ((0, 5), (1, 7)):
        committer._cancellation_watermarks[TopicPartition(topic="t", partition=index)] = floor

    committer.clear_cancellation_watermarks()

    assert not committer._cancellation_watermarks


def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: MockAIOKafkaConsumer) -> None:
"""Offset advances to (max processed + 1) per partition."""
first_offset: typing.Final = 100
Expand Down Expand Up @@ -348,7 +450,7 @@ def test_committer_map_offsets_advances_to_max_per_partition(mock_consumer: Mock
)
)

offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks)
offsets: typing.Final = KafkaBatchCommitter._map_offsets_per_partition(tasks, {})
assert offsets[TopicPartition(topic="t1", partition=0)] == first_offset + 10 + 1
assert offsets[TopicPartition(topic="t1", partition=partition)] == second_offset + 1

Expand Down
38 changes: 38 additions & 0 deletions tests/test_rebalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import AsyncMock

import pytest
from faststream.kafka import TopicPartition

from faststream_concurrent_aiokafka.rebalance import ConsumerRebalanceListener
from tests.mocks import MockKafkaBatchCommitter
Expand Down Expand Up @@ -43,3 +44,40 @@ async def track_commit() -> None:

await listener.on_partitions_revoked(set())
assert flush_done, "commit_all was not awaited before returning"


async def test_rebalance_on_partitions_revoked_clears_watermarks(
    listener: ConsumerRebalanceListener, committer: MockKafkaBatchCommitter
) -> None:
    """Revocation must drop the cancelled-offset floors for the revoked partitions.

    Their next assignment then starts with a clean slate.
    """
    revoked_set = {TopicPartition(topic="t", partition=index) for index in (0, 1)}

    await listener.on_partitions_revoked(revoked_set)

    committer.clear_cancellation_watermarks.assert_called_once_with(revoked_set)


async def test_rebalance_clear_runs_after_commit_all(committer: MockKafkaBatchCommitter) -> None:
    """``commit_all`` must precede ``clear_cancellation_watermarks``.

    The commit path consults the watermark to decide which partitions to skip;
    clearing first would let the outgoing consumer commit past a cancelled boundary.
    """
    call_log: list[str] = []

    async def record_commit_all() -> None:
        call_log.append("commit_all")

    def record_clear(_partitions: object) -> None:
        call_log.append("clear")

    committer.commit_all = AsyncMock(side_effect=record_commit_all)
    committer.clear_cancellation_watermarks = record_clear  # ty: ignore[invalid-assignment]
    listener = ConsumerRebalanceListener(committer)  # ty: ignore[invalid-argument-type]

    await listener.on_partitions_revoked(set())

    assert call_log == ["commit_all", "clear"]
Loading