diff --git a/CLAUDE.md b/CLAUDE.md index 436498b..7c48fea 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -28,12 +28,13 @@ The library provides concurrent Kafka message processing for FastStream. Modules **`processing.py` — `KafkaConcurrentHandler`** The core engine. One handler is created per `initialize_concurrent_processing` call and stored in FastStream's `ContextRepo` under the key `"concurrent_processing"`. It is *not* a singleton — calling `stop_concurrent_processing` clears the context entry so a fresh handler can be initialized. The handler manages: - An `asyncio.Semaphore` for concurrency limiting (minimum: 1) -- In-flight task tracking via a counter (`_tracked_count`) + `asyncio.Event` (`_all_done_event`); the per-task done-callback (`_finish_task`) releases the semaphore, decrements the counter, and sets the event when it reaches zero. `wait_for_subtasks` awaits the event with a timeout -- Signal handlers (SIGTERM/SIGINT) that trigger graceful shutdown +- A `set[asyncio.Task]` (`_tracked_tasks`) holding in-flight user tasks; the per-task done-callback (`_finish_task`) releases the semaphore and removes the task from the set - A `KafkaBatchCommitter` for offset commits Key design: `handle_task()` fires-and-forgets the user coroutine as an asyncio task and enqueues a `KafkaCommitTask` on the committer. Offsets are not committed until the user task finishes (at-least-once semantics). +`stop()` cancels every in-flight tracked task, then awaits `committer.close()`. The committer treats cancelled tasks as a hard offset boundary (see `batch_committer.py`), so cancelled-and-after offsets stay uncommitted and get redelivered on restart. Total wall-clock time for shutdown is bounded by the committer's own `shutdown_timeout_sec` (default 20 s) and is sub-second in normal conditions. The handler does *not* install signal handlers — shutdown is driven by the FastStream lifespan calling `stop_concurrent_processing`. 
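The tracked-set-plus-done-callback design described above can be sketched in isolation. This is a simplified illustration under invented names (`MiniHandler`, `submit`, and `demo` are not the library's API), with the committer left out:

```python
import asyncio


class MiniHandler:
    """Illustrative sketch of the track-and-cancel pattern, not the real handler."""

    def __init__(self, concurrency_limit: int = 10) -> None:
        self._limiter = asyncio.Semaphore(concurrency_limit)
        self._tracked: set[asyncio.Task] = set()

    def _finish(self, task: asyncio.Task) -> None:
        # Per-task done-callback: free a semaphore slot, drop the reference.
        self._limiter.release()
        self._tracked.discard(task)

    async def submit(self, coro) -> None:
        await self._limiter.acquire()       # enforce the concurrency limit
        task = asyncio.ensure_future(coro)  # fire-and-forget the user coroutine
        self._tracked.add(task)
        task.add_done_callback(self._finish)

    async def stop(self) -> None:
        # Cancel everything still in flight, then let the cancellations land.
        for task in list(self._tracked):
            if not task.done():
                task.cancel()
        await asyncio.gather(*self._tracked, return_exceptions=True)


async def demo() -> tuple[list[int], int]:
    handler = MiniHandler(concurrency_limit=2)
    finished: list[int] = []

    async def work(i: int) -> None:
        await asyncio.sleep(60)  # simulates a slow user handler
        finished.append(i)

    for i in range(2):
        await handler.submit(work(i))
    await handler.stop()  # returns almost immediately despite the 60 s sleeps
    return finished, len(handler._tracked)


result = asyncio.run(demo())
```

A plain `set` suffices here because shutdown no longer waits for tasks to drain; in the library, the committer, not this set, is the source of truth for offset progress.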
+ **`middleware.py` — FastStream middleware + lifecycle functions** - `KafkaConcurrentProcessingMiddleware`: FastStream `BaseMiddleware` subclass. Its `consume_scope` retrieves the handler from `self.context`. It passes through (a) FakeConsumer (TestKafkaBroker) and (b) any subscriber whose ack policy is not MANUAL (`kafka_message.committed is not None`). It refuses if `_enable_auto_commit=True` on the consumer. If the handler has been stopped, it logs a warning and skips the message (the offset stays uncommitted, so the message is redelivered on restart). - `initialize_concurrent_processing(context, ...)`: create and start a handler, store it in context. diff --git a/README.md b/README.md index f6252e8..eb36c54 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,8 @@ By default FastStream processes Kafka messages sequentially — one message at a - Configurable concurrency limit (semaphore-based) - Batch offset committing per partition after each task completes - Rebalance-safe: pending offsets are flushed on partition revocation via `ConsumerRebalanceListener` -- Graceful shutdown: waits up to `shutdown_timeout_sec` (default 20 s) for in-flight tasks before exiting -- Signal handling (SIGTERM / SIGINT) triggers graceful shutdown +- Fast shutdown: cancels in-flight tasks; uncommitted offsets are redelivered on restart (at-least-once) +- Signal handling owned by your lifespan / process manager — this lib does not register SIGTERM/SIGINT handlers - Handler exceptions are logged but do not crash the consumer - Health check helper to probe handler status from a `ContextRepo` @@ -110,13 +110,13 @@ Create and start the concurrent processing handler; store it in FastStream's con | `concurrency_limit` | `10` | Max concurrent asyncio tasks (minimum: 1) | | `commit_batch_size` | `10` | Max messages per commit batch | | `commit_batch_timeout_sec` | `10.0` | Max seconds before flushing a batch | -| `shutdown_timeout_sec` | `20.0` | Max seconds to wait for the batch committer to 
flush AND for in-flight handlers to finish during graceful shutdown | +| `shutdown_timeout_sec` | `20.0` | Max seconds the batch committer waits for its background task to drain before forcing cancellation | Returns the `KafkaConcurrentHandler` instance. ### `stop_concurrent_processing(context)` -Flush pending commits, wait for in-flight tasks (up to `shutdown_timeout_sec`), then stop the handler. +Cancel all in-flight handler tasks, flush completed offsets via the committer, then stop the handler. Uncommitted offsets (from cancelled tasks or anything queued past a cancelled offset) are redelivered on restart — at-least-once. ### `is_kafka_handler_healthy(context)` @@ -150,7 +150,20 @@ modern_di_faststream.setup_di(app, container=container) # registered after 4. **Rebalance handling**: When Kafka revokes a partition, the `ConsumerRebalanceListener` (returned by `handler.create_rebalance_listener()`) calls `committer.commit_all()` to flush pending offsets before the partition is reassigned. This prevents in-flight messages from being redelivered to the new owner. -5. **Graceful shutdown**: `stop_concurrent_processing` flushes the committer, then awaits all in-flight tasks via an `asyncio.Event` (set when the in-flight counter reaches zero) bounded by `shutdown_timeout_sec` (default 20 s), then removes the signal handlers. +5. **Shutdown**: `stop_concurrent_processing` cancels every in-flight asyncio task, then awaits `committer.close()`. The committer treats cancelled tasks as a hard offset boundary — cancelled-and-after offsets stay uncommitted and get redelivered on restart. Total wall-clock is sub-second in normal conditions and bounded by `shutdown_timeout_sec` only as a safety net for stuck network commits. + +## Migration from < 0.x + +Previously, `stop_concurrent_processing` waited up to `2 × shutdown_timeout_sec` for in-flight handlers to drain to completion. The new behavior cancels them immediately. 
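Because stop now cancels in-flight handlers instead of draining them, handler bodies should be cancellation-safe. A minimal sketch under assumed names (the `handle` coroutine and its messages are invented, not part of this library):

```python
import asyncio

events: list[str] = []


async def handle(msg: str) -> None:
    # Hypothetical user handler: cleanup must run even when shutdown cancels it.
    events.append(f"open:{msg}")
    try:
        await asyncio.sleep(60)          # slow work, interrupted by cancellation
        events.append(f"done:{msg}")
    finally:
        events.append(f"cleanup:{msg}")  # runs on CancelledError too


async def demo() -> list[str]:
    task = asyncio.create_task(handle("m1"))
    await asyncio.sleep(0)  # let the handler start and reach its await
    task.cancel()           # what stop() does to each in-flight task
    try:
        await task
    except asyncio.CancelledError:
        pass
    return events


result = asyncio.run(demo())
```

Without the `finally`, the cleanup line would be skipped whenever `stop_concurrent_processing` cancels the task mid-await.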
The at-least-once contract is unchanged — uncommitted offsets are redelivered on restart, the same way they always were when the handler crashed mid-task. + +| What changed | Old | New | +|---|---|---| +| In-flight handler tasks on stop | drained to completion | **cancelled** | +| `KafkaConcurrentHandler.wait_for_subtasks()` | public method | removed | +| `shutdown_timeout_sec` | applied separately to handler and committer | applied to committer only | +| Signal handler installation | installed automatically | removed — own them via your lifespan / process manager | + +If your handlers do non-idempotent work that's expensive to repeat, ensure your handlers are wrapped in `try/finally` so cleanup runs on `CancelledError`, or pin to the previous version of this library. To trigger shutdown on SIGTERM/SIGINT, your lifespan or main entry point must catch the signal and call `stop_concurrent_processing(broker.context)` — under uvicorn / AsgiFastStream this happens automatically through the lifespan `finally` block. ## Requirements diff --git a/faststream_concurrent_aiokafka/middleware.py b/faststream_concurrent_aiokafka/middleware.py index ddcefa0..0ef7388 100644 --- a/faststream_concurrent_aiokafka/middleware.py +++ b/faststream_concurrent_aiokafka/middleware.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import dataclasses import logging @@ -105,6 +106,12 @@ async def consume_scope( # ty: ignore[invalid-method-override] # The user handler already fired; the offset stays uncommitted, so the message # will be redelivered on restart (at-least-once). logger.warning("Kafka middleware. Handler is shutting down, skipping message") + except asyncio.CancelledError: + # stop() cancelled this task while handle_task was awaiting send_task. Offset + # stays uncommitted → redelivered on restart. Propagate so FastStream's chain + # can run its own cleanup. + logger.warning("Kafka middleware. 
Task cancelled during shutdown") + raise return None @@ -127,7 +134,6 @@ async def initialize_concurrent_processing( shutdown_timeout_sec=shutdown_timeout_sec, ), concurrency_limit=concurrency_limit, - shutdown_timeout_sec=shutdown_timeout_sec, ) await concurrent_processing.start() context.set_global(_PROCESSING_CONTEXT_KEY, concurrent_processing) diff --git a/faststream_concurrent_aiokafka/processing.py b/faststream_concurrent_aiokafka/processing.py index 0612821..24344f2 100644 --- a/faststream_concurrent_aiokafka/processing.py +++ b/faststream_concurrent_aiokafka/processing.py @@ -1,7 +1,5 @@ import asyncio -import functools import logging -import signal import typing from faststream.kafka import ConsumerRecord, TopicPartition @@ -15,7 +13,6 @@ logger = logging.getLogger(__name__) -SIGNALS: typing.Final = (signal.SIGTERM, signal.SIGINT) DEFAULT_CONCURRENCY_LIMIT: typing.Final = 10 DEFAULT_SHUTDOWN_TIMEOUT_SEC: typing.Final = 20.0 @@ -25,42 +22,25 @@ def __init__( self, committer: KafkaBatchCommitter, concurrency_limit: int = DEFAULT_CONCURRENCY_LIMIT, - shutdown_timeout_sec: float = DEFAULT_SHUTDOWN_TIMEOUT_SEC, ) -> None: if concurrency_limit < 1: msg = f"concurrency_limit must be >= 1, got {concurrency_limit}" raise ValueError(msg) self._limiter = asyncio.Semaphore(concurrency_limit) - # Counter + Event replace the old _current_tasks set: shutdown waits on the event, - # which is set once every tracked task has fired its done-callback. - self._tracked_count: int = 0 - self._all_done_event: asyncio.Event = asyncio.Event() - self._all_done_event.set() # 0 tasks ⇒ "all done" is True + # Tracked only so stop() can cancel them. The committer is the source of truth for + # offset progress; this set just lets us reach in-flight tasks at shutdown. 
+ self._tracked_tasks: set[asyncio.Task[typing.Any]] = set() self._is_running: bool = False self._committer: KafkaBatchCommitter = committer - self._stop_task: asyncio.Task[typing.Any] | None = None - self._shutdown_timeout_sec: float = shutdown_timeout_sec - - async def wait_for_subtasks(self) -> None: - logger.info("Kafka middleware. Gracefully waiting for tasks to end...") - try: - await asyncio.wait_for( - self._all_done_event.wait(), - timeout=self._shutdown_timeout_sec, - ) - except TimeoutError: - logger.exception("Kafka middleware. Whoops, some tasks haven't finished in graceful time, sorry") def _finish_task(self, task: asyncio.Task[typing.Any]) -> None: self._limiter.release() + self._tracked_tasks.discard(task) if not task.cancelled(): exc: typing.Final[BaseException | None] = task.exception() if exc: logger.error("Kafka middleware. Task has failed with the exception", exc_info=exc) - self._tracked_count -= 1 - if self._tracked_count == 0: - self._all_done_event.set() async def handle_task( self, @@ -70,13 +50,7 @@ async def handle_task( ) -> None: await self._limiter.acquire() task: typing.Final = asyncio.ensure_future(coroutine) - # Increment + clear before add_done_callback so the counter already reflects this - # task by the time _finish_task can run. add_done_callback always schedules via - # loop.call_soon (never synchronous), but the callback could fire on the very next - # tick — once we yield at the send_task await below — so the bookkeeping must be - # consistent before that point. 
- self._tracked_count += 1 - self._all_done_event.clear() + self._tracked_tasks.add(task) task.add_done_callback(self._finish_task) try: await self._committer.send_task( @@ -92,19 +66,6 @@ async def handle_task( await self.stop() raise - def _setup_signal_handlers(self) -> None: - loop: typing.Final = asyncio.get_running_loop() - for sig in SIGNALS: - loop.add_signal_handler( - sig, - functools.partial(self._signal_handler, sig), - ) - logger.debug(f"Kafka middleware. Registered handler for {sig.name}") - - def _signal_handler(self, sig: signal.Signals) -> None: - logger.info(f"Kafka middleware. Received signal {sig.name}, initiating graceful shutdown...") - self._stop_task = asyncio.create_task(self.stop()) - async def start(self) -> None: if self._is_running: return @@ -113,7 +74,6 @@ async def start(self) -> None: self._is_running = True self._committer.spawn() - self._setup_signal_handlers() logger.info("Kafka middleware is ready to process messages.") async def stop(self) -> None: @@ -122,15 +82,15 @@ async def stop(self) -> None: logger.info("Kafka middleware. Shutting down middleware handler") self._is_running = False + # Cancel in-flight user tasks. The committer treats cancelled tasks as a hard + # offset boundary (batch_committer._extract_ready_prefixes / _map_offsets_per_partition): + # cancelled-and-after offsets stay uncommitted and get redelivered on restart. + for task in list(self._tracked_tasks): + if not task.done(): + task.cancel() + await self._committer.close() - await self.wait_for_subtasks() - try: - loop = asyncio.get_running_loop() - for sig in SIGNALS: - loop.remove_signal_handler(sig) - except Exception: # noqa: BLE001 - logger.warning("Kafka middleware. Exception raised while removing signal handlers", exc_info=True) logger.info("Kafka middleware. 
Complete shutting down middleware handler") def create_rebalance_listener(self) -> ConsumerRebalanceListener: diff --git a/tests/test_concurrent_processing.py b/tests/test_concurrent_processing.py index 016585b..1b2edfb 100644 --- a/tests/test_concurrent_processing.py +++ b/tests/test_concurrent_processing.py @@ -2,9 +2,8 @@ import asyncio import contextlib import logging -import signal import typing -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest import pytest_asyncio @@ -47,16 +46,6 @@ def sample_record() -> MockConsumerRecord: return MockConsumerRecord() -def _track_external(handler: KafkaConcurrentHandler, task: asyncio.Task[typing.Any]) -> None: - """Register an externally-created task with the handler's count+event tracking. - - Mirrors the bookkeeping that handle_task does so wait_for_subtasks waits for it. - """ - handler._tracked_count += 1 - handler._all_done_event.clear() - task.add_done_callback(handler._finish_task) - - def test_concurrent_init_zero_concurrency_limit_raises() -> None: with pytest.raises(ValueError, match="concurrency_limit must be >= 1"): KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=0) # ty: ignore[invalid-argument-type] @@ -92,17 +81,14 @@ async def test_concurrent_failed_task_exception( assert "Task has failed with the exception" in caplog.text -async def test_concurrent_finish_task_decrements_and_sets_done_event(handler: KafkaConcurrentHandler) -> None: - mock_task: typing.Final = MagicMock() - mock_task.cancelled.return_value = False - mock_task.exception.return_value = None - handler._tracked_count = 1 - handler._all_done_event.clear() +async def test_concurrent_finish_task_discards_from_tracked_set(handler: KafkaConcurrentHandler) -> None: + real_task: typing.Final = asyncio.create_task(asyncio.sleep(0)) + await real_task + handler._tracked_tasks.add(real_task) - handler._finish_task(mock_task) + handler._finish_task(real_task) - assert 
handler._tracked_count == 0 - assert handler._all_done_event.is_set() + assert real_task not in handler._tracked_tasks async def test_concurrent_creates_task( @@ -112,8 +98,7 @@ async def coro() -> str: return "result" await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type] - assert handler._tracked_count == 1 - assert not handler._all_done_event.is_set() + assert len(handler._tracked_tasks) == 1 async def test_concurrent_task_passed_to_committer( @@ -172,33 +157,12 @@ async def test_concurrent_handles_committer_dead_error( assert handler._committer handler._committer.send_task.side_effect = CommitterIsDeadError("Dead") # ty: ignore[unresolved-attribute] - async def coro() -> str: - return "result" - with pytest.raises(CommitterIsDeadError): - await handler.handle_task(coro(), sample_record, sample_message) # ty: ignore[invalid-argument-type] + await handler.handle_task(asyncio.sleep(0), sample_record, sample_message) # ty: ignore[invalid-argument-type] assert not handler._is_running -async def test_concurrent_signal_handler_triggers_stop(handler: KafkaConcurrentHandler) -> None: - await handler.start() - - with patch.object(handler, "stop", new_callable=AsyncMock) as mock_stop: - handler._signal_handler(signal.SIGTERM) - await asyncio.sleep(0) - mock_stop.assert_called_once() - - -async def test_concurrent_signal_handler_logs_signal( - handler: KafkaConcurrentHandler, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.INFO) - handler._signal_handler(signal.SIGINT) - assert "Received signal" in caplog.text - assert "SIGINT" in caplog.text - - async def test_concurrent_start_sets_running(handler: KafkaConcurrentHandler) -> None: await handler.start() assert handler.is_running @@ -238,24 +202,6 @@ async def test_concurrent_stop_closes_committer(handler_with_committer: KafkaCon handler._committer.close.assert_called_once() # ty: ignore[unresolved-attribute] -async def 
test_concurrent_stop_waits_for_subtasks(handler: KafkaConcurrentHandler) -> None: - await handler.start() - with patch.object(handler, "wait_for_subtasks", new_callable=AsyncMock) as mock_wait: - await handler.stop() - mock_wait.assert_called_once() - - -async def test_concurrent_stop_handles_handler_removal_error( - handler: KafkaConcurrentHandler, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.WARNING) - await handler.start() - with patch("asyncio.get_running_loop", side_effect=Exception("Loop error")): - await handler.stop() - - assert "Exception raised" in caplog.text - - async def test_concurrent_stop_when_not_running( handler: KafkaConcurrentHandler, caplog: pytest.LogCaptureFixture ) -> None: @@ -264,109 +210,67 @@ async def test_concurrent_stop_when_not_running( assert "Shutting down" not in caplog.text -async def test_concurrent_waits_for_all_subtasks(handler: KafkaConcurrentHandler) -> None: - results: typing.Final = [] - expected_tasks_len: typing.Final = 2 - - async def task1() -> str: - await asyncio.sleep(0.01) - results.append(1) - return "task1" - - async def task2() -> str: - await asyncio.sleep(0.02) - results.append(2) - return "task2" - - _track_external(handler, asyncio.create_task(task1())) - _track_external(handler, asyncio.create_task(task2())) - await handler.wait_for_subtasks() - assert len(results) == expected_tasks_len - assert handler._tracked_count == 0 - - -async def test_concurrent_handles_task_exceptions( - handler: KafkaConcurrentHandler, caplog: pytest.LogCaptureFixture -) -> None: - caplog.set_level(logging.ERROR) - - async def failing_task() -> typing.Never: - msg: typing.Final = "Task failed" - raise ValueError(msg) +async def test_concurrent_stop_cancels_in_flight_tasks(handler: KafkaConcurrentHandler) -> None: + await handler.start() + sample_record: typing.Final = MockConsumerRecord() + sample_message: typing.Final = MockKafkaMessage() - failing: typing.Final = asyncio.create_task(failing_task()) - 
_track_external(handler, failing) - await handler.wait_for_subtasks() - assert failing.done() + started: typing.Final = asyncio.Event() + async def slow() -> None: + started.set() + await asyncio.sleep(60) -async def test_concurrent_wait_for_subtasks_drains_tasks_added_during_wait( - handler: KafkaConcurrentHandler, -) -> None: - handler._shutdown_timeout_sec = 1.0 - initial_done: typing.Final = asyncio.Event() - late_done: typing.Final = asyncio.Event() + await handler.handle_task(slow(), sample_record, sample_message) # ty: ignore[invalid-argument-type] + await started.wait() - async def initial() -> None: - await asyncio.sleep(0.02) - initial_done.set() + assert len(handler._tracked_tasks) == 1 + in_flight: typing.Final = next(iter(handler._tracked_tasks)) - async def late() -> None: - await asyncio.sleep(0.05) - late_done.set() + await handler.stop() - async def inject_during_wait() -> None: - await asyncio.sleep(0.01) - _track_external(handler, asyncio.create_task(late())) + # MockKafkaBatchCommitter.close() is an AsyncMock and returns instantly without + # giving the event loop a chance to deliver the cancellation. Await the task so + # the CancelledError propagates and observable state settles. 
+ with contextlib.suppress(asyncio.CancelledError): + await in_flight + assert in_flight.cancelled() + assert handler._tracked_tasks == set() - _track_external(handler, asyncio.create_task(initial())) - injector: typing.Final = asyncio.create_task(inject_during_wait()) - await handler.wait_for_subtasks() - await injector +async def test_concurrent_stop_returns_quickly_with_slow_handlers(handler: KafkaConcurrentHandler) -> None: + """Cancelling in-flight tasks lets stop() return well under any per-task latency.""" + await handler.start() + sample_record: typing.Final = MockConsumerRecord() + sample_message: typing.Final = MockKafkaMessage() - assert initial_done.is_set() - assert late_done.is_set() - assert handler._tracked_count == 0 + async def slow() -> None: + await asyncio.sleep(60) + for _ in range(5): + await handler.handle_task(slow(), sample_record, sample_message) # ty: ignore[invalid-argument-type] -async def test_concurrent_logs_timeout(caplog: pytest.LogCaptureFixture) -> None: - caplog.set_level(logging.ERROR) - handler: typing.Final = KafkaConcurrentHandler( - committer=MockKafkaBatchCommitter(), # ty: ignore[invalid-argument-type] - shutdown_timeout_sec=0.1, - ) - - async def slow_task() -> None: - await asyncio.sleep(100) - - slow: typing.Final = asyncio.create_task(slow_task()) - _track_external(handler, slow) - await handler.wait_for_subtasks() - assert "haven't finished in graceful time" in caplog.text - slow.cancel() - with contextlib.suppress(asyncio.CancelledError): - await slow + # Yield so scheduled slow() coroutines reach their sleep before stop() cancels them. 
+ await asyncio.sleep(0) + loop: typing.Final = asyncio.get_running_loop() + started: typing.Final = loop.time() + await handler.stop() + elapsed: typing.Final = loop.time() - started -async def test_handler_uses_shutdown_timeout_kwarg() -> None: - handler: typing.Final = KafkaConcurrentHandler( - committer=MockKafkaBatchCommitter(), # ty: ignore[invalid-argument-type] - shutdown_timeout_sec=7.5, - ) - assert handler._shutdown_timeout_sec == 7.5 + assert elapsed < 1.0, f"stop() took {elapsed:.3f}s with slow handlers" async def test_concurrent_finish_task_does_not_crash_on_cancelled_task( handler_with_limit: KafkaConcurrentHandler, ) -> None: task: typing.Final = asyncio.create_task(asyncio.sleep(10)) - _track_external(handler_with_limit, task) + handler_with_limit._tracked_tasks.add(task) + task.add_done_callback(handler_with_limit._finish_task) task.cancel() with contextlib.suppress(asyncio.CancelledError): await task - assert handler_with_limit._tracked_count == 0 - assert handler_with_limit._all_done_event.is_set() + assert task not in handler_with_limit._tracked_tasks async def test_concurrent_full_lifecycle() -> None: @@ -387,49 +291,14 @@ async def process_msg(msg_id: int) -> None: for i in range(5): await handler.handle_task(process_msg(i), record, msg) # ty: ignore[invalid-argument-type] - await handler.wait_for_subtasks() - await handler.stop() - - assert not handler.is_running - assert len(processed) > 0 - - -async def test_concurrent_message_processing() -> None: - target_value: typing.Final = 5 - handler: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter(), concurrency_limit=target_value) # ty: ignore[invalid-argument-type] - await handler.start() - - start_times: typing.Final = [] - end_times: typing.Final = [] - - async def tracked_task(idx: int) -> None: - start_times.append((idx, asyncio.get_event_loop().time())) - await asyncio.sleep(0.05) - end_times.append((idx, asyncio.get_event_loop().time())) - - msg: typing.Final = 
MockKafkaMessage() - record: typing.Final = MockConsumerRecord() + # Let tasks complete naturally before stop, then assert lifecycle is clean. + if handler._tracked_tasks: + await asyncio.gather(*list(handler._tracked_tasks), return_exceptions=True) - for i in range(target_value): - await handler.handle_task(tracked_task(i), record, msg) # ty: ignore[invalid-argument-type] - - await handler.wait_for_subtasks() await handler.stop() - if len(start_times) == target_value and len(end_times) == target_value: - max_start: typing.Final = max(t for _, t in start_times) - min_end: typing.Final = min(t for _, t in end_times) - assert max_start < min_end - - -async def test_concurrent_signal_handling_integration() -> None: - handler: typing.Final = KafkaConcurrentHandler(committer=MockKafkaBatchCommitter()) # ty: ignore[invalid-argument-type] - await handler.start() - - handler._signal_handler(signal.SIGTERM) - assert handler._stop_task is not None - await handler._stop_task assert not handler.is_running + assert len(processed) == 5 def test_concurrent_create_rebalance_listener(handler: KafkaConcurrentHandler) -> None: diff --git a/tests/test_healthcheck.py b/tests/test_healthcheck.py index 8b27970..e5e4e90 100644 --- a/tests/test_healthcheck.py +++ b/tests/test_healthcheck.py @@ -12,7 +12,7 @@ async def test_healthy_when_handler_is_running() -> None: broker: typing.Final = KafkaBroker("localhost:9092") - async with TestKafkaBroker(broker) as test_broker: + async with TestKafkaBroker(broker, connect_only=False) as test_broker: await initialize_concurrent_processing(context=test_broker.context) try: assert is_kafka_handler_healthy(test_broker.context) is True @@ -22,13 +22,13 @@ async def test_healthy_when_handler_is_running() -> None: async def test_unhealthy_when_no_handler_in_context() -> None: broker: typing.Final = KafkaBroker("localhost:9092") - async with TestKafkaBroker(broker) as test_broker: + async with TestKafkaBroker(broker, connect_only=False) as test_broker: 
assert is_kafka_handler_healthy(test_broker.context) is False async def test_unhealthy_when_handler_stopped() -> None: broker: typing.Final = KafkaBroker("localhost:9092") - async with TestKafkaBroker(broker) as test_broker: + async with TestKafkaBroker(broker, connect_only=False) as test_broker: await initialize_concurrent_processing(context=test_broker.context) await stop_concurrent_processing(test_broker.context) assert is_kafka_handler_healthy(test_broker.context) is False @@ -36,7 +36,7 @@ async def test_unhealthy_when_handler_stopped() -> None: async def test_unhealthy_when_is_healthy_returns_false() -> None: broker: typing.Final = KafkaBroker("localhost:9092") - async with TestKafkaBroker(broker) as test_broker: + async with TestKafkaBroker(broker, connect_only=False) as test_broker: mock_handler: typing.Final = MagicMock() mock_handler.is_healthy = False test_broker.context.set_global("concurrent_processing", mock_handler) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 882a6ad..af709a4 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -33,8 +33,8 @@ async def test_middleware_simple_message_processing(setup_broker: KafkaBroker) - async def handler(msg: typing.Any) -> None: processed_messages.append(msg) - async with TestKafkaBroker(setup_broker) as test_broker: - hdl: typing.Final = await initialize_concurrent_processing( + async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker: + await initialize_concurrent_processing( context=test_broker.context, commit_batch_size=10, commit_batch_timeout_sec=5, @@ -43,7 +43,6 @@ async def handler(msg: typing.Any) -> None: try: await test_broker.publish({"id": 1, "data": "test"}, topic="test-topic") - await hdl.wait_for_subtasks() finally: await stop_concurrent_processing(test_broker.context) @@ -63,16 +62,15 @@ async def handler(msg: typing.Any) -> None: timestamps.append(("end", msg["id"], asyncio.get_event_loop().time())) processed.append(msg) - hdl: 
typing.Final = await initialize_concurrent_processing( + await initialize_concurrent_processing( context=setup_broker.context, commit_batch_size=10, commit_batch_timeout_sec=5, concurrency_limit=3 ) async def test(inner_broker: KafkaBroker) -> None: for i in range(expected_size): await inner_broker.publish({"id": i}, topic="parallel-topic") - await hdl.wait_for_subtasks() - async with TestKafkaBroker(setup_broker) as test_broker: + async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker: await test(test_broker) # TestKafkaBroker uses FakeConsumer — middleware passes through directly (sequential) @@ -93,16 +91,15 @@ async def handler(msg: typing.Any) -> None: concurrent[0] -= 1 assert msg - hdl: typing.Final = await initialize_concurrent_processing( + await initialize_concurrent_processing( context=setup_broker.context, commit_batch_size=10, commit_batch_timeout_sec=5, concurrency_limit=2 ) async def test(inner_broker: KafkaBroker) -> None: for i in range(5): await inner_broker.publish({"id": i}, topic="limited-topic") - await hdl.wait_for_subtasks() - async with TestKafkaBroker(setup_broker) as test_broker: + async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker: await test(test_broker) assert max_concurrent[0] <= concurrent_size, f"Concurrency limit exceeded: {max_concurrent[0]}" @@ -115,8 +112,8 @@ async def test_middleware_handler_context_instance_stable(setup_broker: KafkaBro async def handler(msg: typing.Any) -> None: processed.append(msg) - async with TestKafkaBroker(setup_broker) as test_broker: - hdl: typing.Final = await initialize_concurrent_processing( + async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker: + await initialize_concurrent_processing( context=test_broker.context, commit_batch_size=10, commit_batch_timeout_sec=5, @@ -126,7 +123,6 @@ async def handler(msg: typing.Any) -> None: try: for i in range(3): await test_broker.publish({"id": i}, topic="stable-topic") - await 
hdl.wait_for_subtasks()
         finally:
             await stop_concurrent_processing(test_broker.context)
@@ -135,7 +131,7 @@ async def handler(msg: typing.Any) -> None:
 
 async def test_middleware_initialize_start_failure_raises(setup_broker: KafkaBroker) -> None:
     with patch.object(KafkaConcurrentHandler, "start", side_effect=Exception("Start failed")):
-        async with TestKafkaBroker(setup_broker) as test_broker:
+        async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
             with pytest.raises(Exception, match="Start failed"):
                 await initialize_concurrent_processing(
                     context=test_broker.context,
@@ -149,7 +145,7 @@ async def test_middleware_initialize_skips_when_already_running(
 ) -> None:
     caplog.set_level(logging.WARNING)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -180,7 +176,7 @@ async def test_middleware_shutting_down_skips_message(
     @setup_broker.subscriber("shutting-down-topic", group_id="shutting-down-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         handler_instance: typing.Final = await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -223,7 +219,7 @@ async def test_middleware_catches_committer_is_dead_during_race(
     @setup_broker.subscriber("dead-committer-topic", group_id="dead-committer-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         handler_instance: typing.Final = await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -256,12 +252,55 @@ def mock_get(key: str, default: typing.Any = None) -> typing.Any:
         await stop_concurrent_processing(test_broker.context)
 
 
+async def test_middleware_logs_and_propagates_cancelled_error(
+    setup_broker: KafkaBroker, caplog: pytest.LogCaptureFixture
+) -> None:
+    """CancelledError raised by handle_task is logged and re-raised.
+
+    When stop() cancels a task while handle_task is awaiting send_task, the resulting
+    CancelledError must be logged and re-raised so FastStream's chain can clean up.
+    """
+    caplog.set_level(logging.WARNING)
+
+    @setup_broker.subscriber("cancel-topic", group_id="cancel-group")
+    async def handler(msg: typing.Any) -> None: ...
+
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
+        handler_instance: typing.Final = await initialize_concurrent_processing(
+            context=test_broker.context,
+        )
+
+        original_handle_task: typing.Final = handler_instance.handle_task
+
+        async def raising_handle_task(coro: typing.Any, *_args: typing.Any) -> None:
+            coro.close()
+            raise asyncio.CancelledError
+
+        handler_instance.handle_task = raising_handle_task  # ty: ignore[invalid-assignment]
+
+        original_get: typing.Final = test_broker.context.get
+
+        def mock_get(key: str, default: typing.Any = None) -> typing.Any:
+            if key == "message":
+                return MockKafkaMessage()
+            return original_get(key, default)
+
+        test_broker.context.get = mock_get  # ty: ignore[invalid-assignment]
+
+        with pytest.raises(asyncio.CancelledError):
+            await test_broker.publish({"id": 1}, topic="cancel-topic")
+        assert "Task cancelled during shutdown" in caplog.text
+
+        handler_instance.handle_task = original_handle_task  # ty: ignore[invalid-assignment]
+        await stop_concurrent_processing(test_broker.context)
+
+
 async def test_middleware_no_kafka_message_with_batch_processing_raises(setup_broker: KafkaBroker) -> None:
     @setup_broker.subscriber("no-kafka-msg-topic", group_id="no-kafka-msg-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -288,7 +327,7 @@ async def test_middleware_raises_if_auto_commit_enabled(setup_broker: KafkaBroke
     @setup_broker.subscriber("auto-commit-topic", group_id="auto-commit-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -318,7 +357,7 @@ async def test_middleware_no_handler_in_context_raises(setup_broker: KafkaBroker
     @setup_broker.subscriber("no-handler-topic", group_id="no-handler-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         # Override message with a MANUAL-ack mock so the middleware reaches the
         # is_running check (FakeConsumer and non-MANUAL messages pass through first).
         original_get: typing.Final = test_broker.context.get
@@ -348,7 +387,7 @@ async def test_middleware_non_manual_ack_passes_through_without_concurrent_proce
     async def handler(msg: typing.Any) -> None:
         processed.append(msg)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         original_get: typing.Final = test_broker.context.get
 
         def mock_get(key: str, default: typing.Any = None) -> typing.Any:
@@ -374,7 +413,7 @@ async def test_middleware_batch_processing_has_committer(setup_broker: KafkaBrok
     async def handler(msg: typing.Any) -> None:
         processed.append(msg)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         handler_instance: typing.Final = await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -384,7 +423,6 @@ async def handler(msg: typing.Any) -> None:
         try:
             for i in range(expected_size):
                 await test_broker.publish({"id": i}, topic="batch-topic")
-            await handler_instance.wait_for_subtasks()
         finally:
             await stop_concurrent_processing(test_broker.context)
@@ -397,19 +435,18 @@ async def test_middleware_stop_without_start_is_noop(
 ) -> None:
     caplog.set_level(logging.WARNING)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await stop_concurrent_processing(test_broker.context)
 
     assert "Concurrent processing is not running" in caplog.text
 
 
 async def test_middleware_initialize_passes_shutdown_timeout(setup_broker: KafkaBroker) -> None:
-    """initialize_concurrent_processing forwards shutdown_timeout_sec to handler and committer."""
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    """initialize_concurrent_processing forwards shutdown_timeout_sec to the committer."""
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         handler: typing.Final = await initialize_concurrent_processing(
             context=test_broker.context, shutdown_timeout_sec=5.0
         )
         try:
-            assert handler._shutdown_timeout_sec == 5.0
             assert handler._committer._shutdown_timeout == 5.0
         finally:
             await stop_concurrent_processing(test_broker.context)
@@ -417,7 +454,7 @@ async def test_middleware_initialize_passes_shutdown_timeout(setup_broker: Kafka
 
 async def test_middleware_stop_cleans_up_when_committer_dead(setup_broker: KafkaBroker) -> None:
     """If the committer task has died, stop_concurrent_processing must still tear down the handler."""
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         handler: typing.Final = await initialize_concurrent_processing(context=test_broker.context)
 
         committer_task: typing.Final = handler._committer._commit_task
@@ -445,7 +482,7 @@ async def handler(msg: typing.Any) -> None:
         msg = "Handler failed"
         raise ValueError(msg)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -471,7 +508,7 @@ async def handler(msg: typing.Any) -> None:
         msg = "Failed"
         raise ValueError(msg)
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -494,7 +531,7 @@ async def handler(msg: typing.Any) -> None:
 
 async def test_middleware_start_stop_reinitialize(setup_broker: KafkaBroker) -> None:
     """Handler can be stopped and re-initialized; the second instance is fresh and healthy."""
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         first_handler: typing.Final = await initialize_concurrent_processing(
             context=test_broker.context, concurrency_limit=5
         )
@@ -520,7 +557,7 @@ async def test_middleware_general_exception_wrapped(
     @setup_broker.subscriber("general-error-topic", group_id="general-error-group")
     async def handler(msg: typing.Any) -> None: ...
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
         await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
@@ -545,8 +582,8 @@ async def test_middleware_fake_consumer_no_commit_error(
     async def handler(msg: typing.Any) -> None:
         pass
 
-    async with TestKafkaBroker(setup_broker) as test_broker:
-        hdl: typing.Final = await initialize_concurrent_processing(
+    async with TestKafkaBroker(setup_broker, connect_only=False) as test_broker:
+        await initialize_concurrent_processing(
             context=test_broker.context,
             commit_batch_size=10,
             commit_batch_timeout_sec=5,
@@ -554,7 +591,6 @@ async def handler(msg: typing.Any) -> None:
         try:
             await test_broker.publish({"id": 1}, topic="fake-consumer-topic")
-            await hdl.wait_for_subtasks()
         finally:
             await stop_concurrent_processing(test_broker.context)