diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..01a59e0
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,14 @@
+.DEFAULT:
+	@echo "No such command (or you passed two or more targets). List of possible commands: make help"
+
+.DEFAULT_GOAL := help
+
+##@ Local development
+
+.PHONY: help
+help: ## Show this help
+	@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m \033[0m\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m %s\033[0m\n\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
+
+.PHONY: clear_rabbit
+clear_rabbit: ## Clear RabbitMQ data volume and restart container
+	@docker stop taskiq_aio_pika_rabbitmq && docker rm taskiq_aio_pika_rabbitmq && docker volume rm taskiq-aio-pika_rabbitmq_data && docker compose up -d
diff --git a/README.md b/README.md
index 39ebbee..a9b0db1 100644
--- a/README.md
+++ b/README.md
@@ -9,13 +9,14 @@ This library provides you with aio-pika broker for taskiq.
 Features:
 - Supports delayed messages using dead-letter queues or RabbitMQ delayed message exchange plugin.
 - Supports message priorities.
+- Supports multiple queues and custom routing.
 
 Usage example:
 
 ```python
 from taskiq_aio_pika import AioPikaBroker
 
-broker = AioPikaBroker()
+broker = AioPikaBroker(...)
 
 @broker.task
 async def test() -> None:
@@ -32,7 +33,7 @@ To send delayed message, you have to specify delay label. You can do it with `ta
 In this type of delay we are using additional queue with `expiration` parameter. After declared time message will be deleted from `delay` queue and sent to the main queue. For example:
 
 ```python
-broker = AioPikaBroker()
+broker = AioPikaBroker(...)
 
 @broker.task(delay=3)
 async def delayed_task() -> int:
@@ -86,13 +87,12 @@ async def main():
 ## Priorities
 
 You can define priorities for messages using `priority` label. Messages with higher priorities are delivered faster.
-But to use priorities you need to define `max_priority` of the main queue, by passing `max_priority` parameter in broker's init. This parameter sets maximum priority for the queue and declares it as the priority queue.
 
 Before doing so please read the [documentation](https://www.rabbitmq.com/priority.html#behaviour) about what downsides you get by using prioritized queues.
 
 ```python
-broker = AioPikaBroker(max_priority=10)
+broker = AioPikaBroker(...)
 
 # We can define default priority for tasks.
 @broker.task(priority=2)
@@ -111,42 +111,43 @@ async def main():
     await prio_task.kicker().with_labels(priority=None).kiq()
 ```
 
-## Configuration
+## Custom Queue and Exchange arguments
 
-AioPikaBroker parameters:
-
-* `url` - url to rabbitmq. If None, "amqp://guest:guest@localhost:5672" is used.
-* `result_backend` - custom result backend.
-* `task_id_generator` - custom task_id genertaor.
-* `exchange_name` - name of exchange that used to send messages.
-* `exchange_type` - type of the exchange. Used only if `declare_exchange` is True.
-* `queue_name` - queue that used to get incoming messages.
-* `routing_key` - that used to bind that queue to the exchange.
-* `declare_exchange` - whether you want to declare new exchange if it doesn't exist.
-* `max_priority` - maximum priority for messages.
-* `delay_queue_name` - custom delay queue name. This queue is used to deliver messages with delays.
-* `dead_letter_queue_name` - custom dead letter queue name.
-    This queue is used to receive negatively acknowledged messages from the main queue.
-* `qos` - number of messages that worker can prefetch.
-* `declare_queues` - whether you want to declare queues even on client side. May be useful for message persistence.
-* `declare_queues_kwargs` - see [Custom Queue Arguments](#custom-queue-arguments) for more details.
-
-## Custom Queue Arguments
-
-You can pass custom arguments to the underlying RabbitMQ queue declaration by using the `declare_queues_kwargs` parameter of `AioPikaBroker`. If you want to set specific queue arguments (such as RabbitMQ extensions or custom behaviors), provide them in the `arguments` dictionary inside `declare_queues_kwargs`.
+You can pass custom arguments to the underlying RabbitMQ queues and exchange declaration by using the `Queue`/`Exchange` classes from `taskiq_aio_pika`. If you have used `faststream` before, you are probably familiar with this concept.
 
 These arguments will be merged with the default arguments used by the broker (such as dead-lettering and priority settings). If there are any conflicts, the values you provide will take precedence over the broker's defaults.
 
 Example:
 
 ```python
+from taskiq_aio_pika import AioPikaBroker, Queue, QueueType, Exchange
+from aio_pika.abc import ExchangeType
+
 broker = AioPikaBroker(
-    declare_queues_kwargs={
-        "arguments": {
-            "x-message-ttl": 60000,  # Set message TTL to 60 seconds
-            "x-queue-type": "quorum",  # Use quorum queue type
-        }
-    }
+    exchange=Exchange(
+        name="custom_exchange",
+        type=ExchangeType.TOPIC,
+        declare=True,
+        durable=True,
+        auto_delete=False,
+    ),
+    task_queues=[
+        Queue(
+            name="custom_queue",
+            type=QueueType.CLASSIC,
+            declare=True,
+            durable=True,
+            max_priority=10,
+            routing_key="custom_queue",
+        )
+    ]
 )
 ```
 
 This will ensure that the queue is created with your custom arguments, in addition to the broker's defaults.
+
+
+## Multiqueue support
+
+You can define multiple queues for your tasks. Each queue can have its own routing key and other settings. Your workers can listen to multiple queues (or a specific queue) as well.
+
+You can check the [multiqueue usage example](./examples/topic_with_two_queues.py) in the examples folder for more details.
diff --git a/docker-compose.yaml b/docker-compose.yaml
index e8d5dc3..acca9dc 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -2,6 +2,7 @@ services:
   rabbitmq:
     container_name: taskiq_aio_pika_rabbitmq
     image: heidiks/rabbitmq-delayed-message-exchange:latest
+    # image: rabbitmq:3.13.7-management # rabbit with management UI for debugging
     environment:
       RABBITMQ_DEFAULT_USER: "guest"
       RABBITMQ_DEFAULT_PASS: "guest"
@@ -14,4 +15,13 @@
     ports:
       - "5672:5672"
      - "15672:15672"
-      - "61613:61613"
+    volumes:
+      - rabbitmq_data:/var/lib/rabbitmq
+  redis:
+    container_name: taskiq_aio_pika_redis
+    image: redis:latest
+    ports:
+      - "6379:6379"
+
+volumes:
+  rabbitmq_data:
diff --git a/examples/basic.py b/examples/basic.py
new file mode 100644
index 0000000..7da8b4f
--- /dev/null
+++ b/examples/basic.py
@@ -0,0 +1,40 @@
+"""
+Basic example of using Taskiq with AioPika broker.
+
+How to run:
+    1. Run worker: taskiq worker examples.basic:broker -w 1
+    2. Send a task: uv run examples/basic.py
+"""
+
+import asyncio
+
+from taskiq_redis import RedisAsyncResultBackend
+
+from taskiq_aio_pika import AioPikaBroker
+
+broker = AioPikaBroker(
+    "amqp://guest:guest@localhost:5672/",
+).with_result_backend(RedisAsyncResultBackend("redis://localhost:6379/0"))
+
+
+@broker.task
+async def add_one(value: int) -> int:
+    return value + 1
+
+
+async def main() -> None:
+    await broker.startup()
+    # Send the task to the broker.
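+    # kiq() publishes the task message to RabbitMQ and returns a handle
+    # that can be awaited for the result via the configured result backend.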
+    task = await add_one.kiq(1)
+    # Wait for the result.
+    result = await task.wait_result(timeout=2)
+    print(f"Task execution took: {result.execution_time} seconds.")
+    if not result.is_err:
+        print(f"Returned value: {result.return_value}")
+    else:
+        print("Error found while executing task.")
+    await broker.shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/delayed_task.py b/examples/delayed_task.py
new file mode 100644
index 0000000..c91f056
--- /dev/null
+++ b/examples/delayed_task.py
@@ -0,0 +1,41 @@
+"""
+Example of delayed task execution using Taskiq with AioPika broker.
+
+How to run:
+    1. Run worker: taskiq worker examples.delayed_task:broker -w 1
+    2. Send a task: uv run examples/delayed_task.py
+"""
+
+import asyncio
+
+from taskiq_redis import RedisAsyncResultBackend
+
+from taskiq_aio_pika import AioPikaBroker
+
+broker = AioPikaBroker(
+    "amqp://guest:guest@localhost:5672/",
+).with_result_backend(RedisAsyncResultBackend("redis://localhost:6379/0"))
+
+
+@broker.task
+async def add_one(value: int) -> int:
+    return value + 1
+
+
+async def main() -> None:
+    await broker.startup()
+    # Send the task to the broker.
+    task = await add_one.kicker().with_labels(delay=2).kiq(1)
+    print("Task sent with 2 seconds delay.")
+    # Wait for the result.
+    result = await task.wait_result(timeout=3)
+    print(f"Task execution took: {result.execution_time} seconds.")
+    if not result.is_err:
+        print(f"Returned value: {result.return_value}")
+    else:
+        print("Error found while executing task.")
+    await broker.shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/topic_with_two_queues.py b/examples/topic_with_two_queues.py
new file mode 100644
index 0000000..3ba3d82
--- /dev/null
+++ b/examples/topic_with_two_queues.py
@@ -0,0 +1,96 @@
+"""
+Example with two queues for different workers and one topic exchange.
+
+It can be useful when you want to have two workers consuming tasks from different queues.
+
+How to run:
+    1. Run worker for queue_1: taskiq worker examples.topic_with_two_queues:get_broker_for_queue_1 -w 1
+    2. Run worker for queue_2: taskiq worker examples.topic_with_two_queues:get_broker_for_queue_2 -w 1
+    3. Send a task: uv run examples/topic_with_two_queues.py --queue 1
+    4. Optionally send a task to the other queue: uv run examples/topic_with_two_queues.py --queue 2
+"""
+
+import argparse
+import asyncio
+import uuid
+
+from aio_pika.abc import ExchangeType
+from taskiq_redis import RedisAsyncResultBackend
+
+from taskiq_aio_pika import AioPikaBroker, Exchange, Queue, QueueType
+
+broker = AioPikaBroker(
+    "amqp://guest:guest@localhost:5672/",
+    exchange=Exchange(
+        name="topic_exchange",
+        type=ExchangeType.TOPIC,
+    ),
+    delay_queue=Queue(
+        name="taskiq.delay",
+        routing_key="queue1",
+    ),  # send delayed messages to queue1
+).with_result_backend(RedisAsyncResultBackend("redis://localhost:6379/0"))
+
+
+@broker.task
+async def add_one(value: int) -> int:
+    return value + 1
+
+
+queue_1 = Queue(
+    name="queue1",
+    type=QueueType.CLASSIC,
+    durable=False,
+)
+queue_2 = Queue(
+    name="queue2",
+    type=QueueType.CLASSIC,
+    durable=False,
+)
+
+
+def get_broker_for_queue_1() -> AioPikaBroker:
+    print("This broker will listen to queue1")
+    return broker.with_queue(queue_1)
+
+
+def get_broker_for_queue_2() -> AioPikaBroker:
+    print("This broker will listen to queue2")
+    return broker.with_queue(queue_2)
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--queue",
+        choices=["1", "2"],
+        required=True,
+        help="Queue to send the task to.",
+    )
+    args = parser.parse_args()
+
+    queue_name = queue_1.name if args.queue == "1" else queue_2.name
+
+    broker.with_queues(
+        queue_1,
+        queue_2,
+    )  # declare both queues so the broker knows about them during publishing
+    await broker.startup()
+
+    task = (
+        await add_one.kicker()
+        .with_labels(queue_name=queue_name)  # or it can be the routing_key from queue_1
+        .with_task_id(uuid.uuid4().hex)
+        .kiq(2)
+    )
+    result = await task.wait_result(timeout=2)
+    print(f"Task execution took: {result.execution_time} seconds.")
+    if not result.is_err:
+        print(f"Returned value: {result.return_value}")
+    else:
+        print("Error found while executing task.")
+    await broker.shutdown()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index 66e9e0c..4520b90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,8 +30,9 @@ repository = "https://github.com/taskiq-python/taskiq-aio-pika"
 keywords = ["taskiq", "tasks", "distributed", "async", "aio-pika"]
 requires-python = ">=3.10,<4"
 dependencies = [
-    "taskiq>=0.11.20,<1",
+    "taskiq>=0.12.0,<1",
     "aio-pika>=9.0.0",
+    "aiostream>=0.7.1",
 ]
 
 [dependency-groups]
@@ -48,6 +49,10 @@ dev = [
     "coverage>=7.11.3",
     "pytest-xdist[psutil]>=3.8.0",
     "anyio>=4.11.0",
+    {include-group = "examples"},
+]
+examples = [
+    "taskiq-redis>=1.1.2",
 ]
 
 [tool.mypy]
@@ -130,6 +135,12 @@ line-length = 88
     "SLF001", # Private member accessed
     "S311", # Standard pseudo-random generators are not suitable for security/cryptographic purposes
     "D101", # Missing docstring in public class
+    "D102", # Missing docstring in public method
+    "E501", # Line too long
+]
+"examples/*" = [
+    "D", # missing docstrings
+    "T201", # print found
 ]
 
 [tool.ruff.lint.pydocstyle]
diff --git a/taskiq_aio_pika/__init__.py b/taskiq_aio_pika/__init__.py
index 6bab083..b05b517 100644
--- a/taskiq_aio_pika/__init__.py
+++ b/taskiq_aio_pika/__init__.py
@@ -3,7 +3,14 @@ from importlib.metadata import version
 
 from taskiq_aio_pika.broker import AioPikaBroker
+from taskiq_aio_pika.exchange import Exchange
+from taskiq_aio_pika.queue import Queue, QueueType
 
 __version__ = version("taskiq-aio-pika")
 
-__all__ = ["AioPikaBroker"]
+__all__ = [
+    "AioPikaBroker",
+    "Exchange",
+    "Queue",
"QueueType", +] diff --git a/taskiq_aio_pika/broker.py b/taskiq_aio_pika/broker.py index e940f46..f910c0a 100644 --- a/taskiq_aio_pika/broker.py +++ b/taskiq_aio_pika/broker.py @@ -4,9 +4,22 @@ from logging import getLogger from typing import Any, TypeVar +import aiormq from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust from aio_pika.abc import AbstractChannel, AbstractQueue, AbstractRobustConnection +from aiostream import stream +from pamqp.common import FieldTable from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage +from typing_extensions import Self + +from taskiq_aio_pika.exceptions import ( + ExchangeNotDeclaredError, + IncorrectRoutingKeyError, + NoStartupError, + QueueNotDeclaredError, +) +from taskiq_aio_pika.exchange import Exchange +from taskiq_aio_pika.queue import Queue _T = TypeVar("_T") @@ -43,18 +56,14 @@ def __init__( task_id_generator: Callable[[], str] | None = None, qos: int = 10, loop: asyncio.AbstractEventLoop | None = None, - exchange_name: str = "taskiq", - queue_name: str = "taskiq", - dead_letter_queue_name: str | None = None, - delay_queue_name: str | None = None, - declare_exchange: bool = True, - declare_queues: bool = True, - routing_key: str = "#", - exchange_type: ExchangeType = ExchangeType.TOPIC, - max_priority: int | None = None, + exchange: Exchange | None = None, + task_queues: list[Queue] | None = None, + dead_letter_queue: Queue | None = None, + delay_queue: Queue | None = None, delayed_message_exchange_plugin: bool = False, - declare_exchange_kwargs: dict[Any, Any] | None = None, - declare_queues_kwargs: dict[Any, Any] | None = None, + delayed_message_exchange: Exchange | None = None, + label_for_routing: str = "queue_name", + label_for_priority: str = "priority", **connection_kwargs: Any, ) -> None: """ @@ -63,28 +72,20 @@ def __init__( :param url: url to rabbitmq. If None, the default "amqp://guest:guest@localhost:5672" is used. :param result_backend: custom result backend. - :param task_id_generator: custom task_id genertaor. :param qos: number of messages that worker can prefetch. :param loop: specific even loop. - :param exchange_name: name of exchange that used to send messages. - :param queue_name: queue that used to get incoming messages. - :param dead_letter_queue_name: custom name for dead-letter queue. - by default it set to {queue_name}.dead_letter. - :param delay_queue_name: custom name for queue that used to - deliver messages with delays. - :param declare_exchange: whether you want to declare new exchange - if it doesn't exist. - :param declare_queues: whether you want to declare queues even on - client side. May be useful for message persistence. - :param routing_key: that used to bind that queue to the exchange. - :param exchange_type: type of the exchange. - Used only if `declare_exchange` is True. - :param max_priority: maximum priority value for messages. + :param exchange: parameters of exchange that used to send messages. + :param task_queues: parameters of queues + that will be used to get incoming messages. + :param dead_letter_queue: parameters of dead-letter queue. + :param delay_queue: parameters of queue for simple delay implementation. :param delayed_message_exchange_plugin: turn on or disable delayed-message-exchange rabbitmq plugin. 
-        :param declare_exchange_kwargs: additional from AbstractChannel.declare_exchange
-        :param declare_queues_kwargs: additional from AbstractChannel.declare_queue
+        :param delayed_message_exchange: parameters of the exchange
+            used to send messages with delay.
+        :param label_for_routing: label name to use for routing key selection.
+        :param label_for_priority: label name to use for message priority.
         :param connection_kwargs: additional keyword arguments,
             for connect_robust method of aio-pika.
         """
@@ -93,27 +94,24 @@
         self.url = url
         self._loop = loop
         self._conn_kwargs = connection_kwargs
-        self._exchange_name = exchange_name
-        self._exchange_type = exchange_type
+        self._exchange = exchange or Exchange()
         self._qos = qos
-        self._declare_exchange = declare_exchange
-        self._declare_exchange_kwargs = declare_exchange_kwargs or {}
-        self._declare_queues = declare_queues
-        self._declare_queues_kwargs = declare_queues_kwargs or {}
-        self._queue_name = queue_name
-        self._routing_key = routing_key
-        self._max_priority = max_priority
-        self._delayed_message_exchange_plugin = delayed_message_exchange_plugin
+        self._task_queues = task_queues or []
+        self._dead_letter_queue = dead_letter_queue or Queue(name="taskiq.dead_letter")
 
-        self._dead_letter_queue_name = f"{queue_name}.dead_letter"
-        if dead_letter_queue_name:
-            self._dead_letter_queue_name = dead_letter_queue_name
+        self._label_for_routing = label_for_routing
+        self._label_for_priority = label_for_priority
 
-        self._delay_queue_name = f"{queue_name}.delay"
-        if delay_queue_name:
-            self._delay_queue_name = delay_queue_name
+        self._delay_queue = delay_queue or Queue(name="taskiq.delay")
 
-        self._delay_plugin_exchange_name = f"{exchange_name}.plugin_delay"
+        self._delayed_message_exchange_plugin = delayed_message_exchange_plugin
+        if self._delayed_message_exchange_plugin:
+            self._delayed_message_exchange = delayed_message_exchange or Exchange(
+                name=f"{self._exchange.name}.plugin_delay",
+                type=ExchangeType.X_DELAYED_MESSAGE,
+                arguments={"x-delayed-type": "direct"},
+                declare=True,
+            )
 
         self.read_conn: AbstractRobustConnection | None = None
         self.write_conn: AbstractRobustConnection | None = None
@@ -138,43 +136,110 @@
         )
         self.read_channel = await self.read_conn.channel()
 
-        if self._declare_exchange:
-            await self.write_channel.declare_exchange(
-                self._exchange_name,
-                type=self._exchange_type,
-                **self._declare_exchange_kwargs,
+        await self._declare_exchanges()
+        await self._declare_queues(self.write_channel)
+
+    async def _declare_exchanges(
+        self,
+    ) -> None:
+        """
+        Declare all exchanges.
+
+        :raises NoStartupError: if startup wasn't called.
+        """
+        if self.write_channel is None:
+            raise NoStartupError(
+                "Write channel is not initialized. Please call startup first.",
+            )
 
-        if self._delayed_message_exchange_plugin:
+        if self._exchange.declare:
             await self.write_channel.declare_exchange(
-                self._delay_plugin_exchange_name,
-                type=ExchangeType.X_DELAYED_MESSAGE,
-                arguments={
-                    "x-delayed-type": "direct",
-                },
+                name=self._exchange.name,
+                type=self._exchange.type,
+                durable=self._exchange.durable,
+                auto_delete=self._exchange.auto_delete,
+                internal=self._exchange.internal,
+                passive=self._exchange.passive,
+                arguments=self._exchange.arguments,
+                timeout=self._exchange.timeout,
             )
+        else:
+            try:
+                await self.write_channel.get_exchange(
+                    name=self._exchange.name,
+                    ensure=True,
+                )
+            except aiormq.exceptions.ChannelNotFoundEntity as error:
+                raise ExchangeNotDeclaredError(
+                    f"Exchange '{self._exchange.name}' "
+                    f"was not declared and does not exist.",
+                ) from error
 
-        if self._declare_queues:
-            await self.declare_queues(self.write_channel)
-
-    async def shutdown(self) -> None:
-        """Close all connections on shutdown."""
-        await super().shutdown()
-        if self.write_channel:
-            await self.write_channel.close()
-        if self.read_channel:
-            await self.read_channel.close()
-        if self.write_conn:
-            await self.write_conn.close()
-        if self.read_conn:
-            await self.read_conn.close()
+        if self._delayed_message_exchange_plugin:
+            if self._delayed_message_exchange.declare:
+                await self.write_channel.declare_exchange(
+                    name=self._delayed_message_exchange.name,
+                    type=self._delayed_message_exchange.type,
+                    durable=self._delayed_message_exchange.durable,
+                    auto_delete=self._delayed_message_exchange.auto_delete,
+                    internal=self._delayed_message_exchange.internal,
+                    passive=self._delayed_message_exchange.passive,
+                    arguments=self._delayed_message_exchange.arguments,
+                    timeout=self._delayed_message_exchange.timeout,
+                )
+            else:
+                try:
+                    await self.write_channel.get_exchange(
+                        name=self._delayed_message_exchange.name,
+                        ensure=True,
+                    )
+                except aiormq.exceptions.ChannelNotFoundEntity as error:
+                    raise ExchangeNotDeclaredError(
+                        f"Exchange '{self._delayed_message_exchange.name}' "
+                        f"was not declared and does not exist.",
+                    ) from error
+
+    async def _declare_dead_letter_queue(
+        self,
+        channel: AbstractChannel,
+    ) -> None:
+        if self._dead_letter_queue.declare:
+            dead_letter_queue_arguments = self._dead_letter_queue.arguments.copy()
+            if self._dead_letter_queue.max_priority is not None:
+                dead_letter_queue_arguments["x-max-priority"] = (
+                    self._dead_letter_queue.max_priority
+                )
+            dead_letter_queue_arguments["x-queue-type"] = (
+                self._dead_letter_queue.type.value
+            )
+            await channel.declare_queue(
+                name=self._dead_letter_queue.name,
+                durable=self._dead_letter_queue.durable,
+                exclusive=self._dead_letter_queue.exclusive,
+                passive=self._dead_letter_queue.passive,
+                auto_delete=self._dead_letter_queue.auto_delete,
+                arguments=dead_letter_queue_arguments,
+                timeout=self._dead_letter_queue.timeout,
+            )
+        else:
+            try:
+                await channel.get_queue(
+                    name=self._dead_letter_queue.name,
+                    ensure=True,
+                )
+            except aiormq.exceptions.ChannelNotFoundEntity as error:
+                raise QueueNotDeclaredError(
+                    f"Dead-letter queue '{self._dead_letter_queue.name}' "
+                    f"was not declared and does not exist.",
+                ) from error
 
-    async def declare_queues(
+    async def _declare_queues(
         self,
         channel: AbstractChannel,
-    ) -> AbstractQueue:
+    ) -> list[tuple[AbstractQueue, dict[str, Any]]]:
         """
-        This function is used to declare queues.
+        Declare all queues.
 
         It's useful since aio-pika have automatic
         recover mechanism, which works only if
@@ -184,49 +249,95 @@
         :param channel: channel to used for declaration.
-        :return: main queue instance.
+        :return: declared queues with their consumer arguments.
         """
-        await channel.declare_queue(
-            self._dead_letter_queue_name,
-            **self._declare_queues_kwargs,
-        )
-        args: dict[str, Any] = {
+        await self._declare_dead_letter_queue(channel)
+        declared_queues = []
+        queue_default_arguments: FieldTable = {
             "x-dead-letter-exchange": "",
-            "x-dead-letter-routing-key": self._dead_letter_queue_name,
+            "x-dead-letter-routing-key": (
+                self._dead_letter_queue.routing_key or self._dead_letter_queue.name
+            ),
         }
-        if self._max_priority is not None:
-            args["x-max-priority"] = self._max_priority
-        queue = await channel.declare_queue(
-            self._queue_name,
-            **{
-                **self._declare_queues_kwargs,
-                "arguments": {
-                    **self._declare_queues_kwargs.get("arguments", {}),
-                    **args,
-                },
-            },
-        )
-        if self._delayed_message_exchange_plugin:
-            await queue.bind(
-                exchange=self._delay_plugin_exchange_name,
-                routing_key=self._routing_key,
+        if not self._task_queues:  # add default queue if user didn't provide any
+            self._task_queues.append(Queue())
+
+        queues = self._task_queues.copy()
+        if not self._delayed_message_exchange_plugin:
+            queues.append(self._delay_queue)
+
+        for queue in filter(lambda queue: queue.declare, queues):
+            per_queue_arguments: FieldTable = queue_default_arguments.copy()
+            if queue.max_priority is not None:
+                per_queue_arguments["x-max-priority"] = queue.max_priority
+            per_queue_arguments["x-queue-type"] = queue.type.value
+            if queue.name == self._delay_queue.name:
+                per_queue_arguments["x-dead-letter-exchange"] = self._exchange.name
+                per_queue_arguments["x-dead-letter-routing-key"] = (
+                    self._delay_queue.routing_key
+                    or queues[0].routing_key
+                    or queues[0].name
+                )
+            per_queue_arguments.update(
+                queue.arguments if queue.arguments is not None else {},
             )
-        else:
-            await channel.declare_queue(
-                self._delay_queue_name,
-                **{
-                    **self._declare_queues_kwargs,
-                    "arguments": {
-                        **self._declare_queues_kwargs.get("arguments", {}),
-                        "x-dead-letter-exchange": "",
-                        "x-dead-letter-routing-key": self._queue_name,
-                    },
-                },
+            declared_queue = await channel.declare_queue(
+                name=queue.name,
+                durable=queue.durable,
+                exclusive=queue.exclusive,
+                passive=queue.passive,
+                auto_delete=queue.auto_delete,
+                arguments=per_queue_arguments,
+                timeout=queue.timeout,
+            )
+            logger.debug(
+                "Bind queue to exchange with routing key '%s'",
+                queue.routing_key or queue.name,
             )
+            if queue.name != self._delay_queue.name:
+                await declared_queue.bind(
+                    exchange=self._exchange.name,
+                    routing_key=queue.routing_key or queue.name,
+                    arguments=queue.bind_arguments,
+                    timeout=queue.bind_timeout,
+                )
+                if self._delayed_message_exchange_plugin:
+                    await declared_queue.bind(
+                        exchange=self._delayed_message_exchange.name,
+                        routing_key=queue.routing_key or queue.name,
+                    )
+            declared_queues.append((declared_queue, queue.consumer_arguments))
 
-        await queue.bind(
-            exchange=self._exchange_name,
-            routing_key=self._routing_key,
-        )
-        return queue
+        for queue in filter(lambda queue: not queue.declare, queues):
+            try:
+                existing_queue = await channel.get_queue(
+                    name=queue.name,
+                    ensure=True,
+                )
+            except aiormq.exceptions.ChannelNotFoundEntity as error:
+                raise QueueNotDeclaredError(
+                    f"Queue '{queue.name}' was not declared and does not exist.",
+                ) from error
+            declared_queues.append((existing_queue, queue.consumer_arguments))
+        return declared_queues
+
+    def with_queue(self, queue: Queue) -> Self:
+        """
+        Add new queue to the broker.
+
+        :param queue: queue to add.
+        :return: self.
+        """
+        self._task_queues.append(queue)
+        return self
+
+    def with_queues(self, *queues: Queue) -> Self:
+        """
+        Replace existing queues with new ones.
+
+        :param queues: queues to use instead of the existing ones.
+        :return: self.
+        """
+        self._task_queues = list(queues)
+        return self
 
     async def kick(self, message: BrokerMessage) -> None:
         """
@@ -239,56 +350,62 @@
 
         This function constructs rmq message
         and sends it.
 
         The message has task_id and task_name and labels
-        in headers. And message's routing key is the same
-        as the task_name.
+        in headers. The routing key is taken from the queue
+        configuration or from the routing label.
 
-        :raises ValueError: if startup wasn't called.
+        :raises NoStartupError: if startup wasn't called.
+        :raises IncorrectRoutingKeyError: if routing key is incorrect.
 
         :param message: message to send.
         """
         if self.write_channel is None:
-            raise ValueError("Please run startup before kicking.")
-
-        message_base_params: dict[str, Any] = {
-            "body": message.message,
-            "headers": {
+            raise NoStartupError("Please run startup before kicking.")
+        priority = parse_val(int, message.labels.get(self._label_for_priority))
+        rmq_message = Message(
+            body=message.message,
+            headers={
                 "task_id": message.task_id,
                 "task_name": message.task_name,
                 **message.labels,
             },
-            "delivery_mode": DeliveryMode.PERSISTENT,
-            "priority": parse_val(
-                int,
-                message.labels.get("priority"),
-            ),
-        }
+            delivery_mode=DeliveryMode.PERSISTENT,
+            priority=priority,
+        )
+        delay = parse_val(float, message.labels.get("delay"))
 
-        delay: float | None = parse_val(float, message.labels.get("delay"))
-        rmq_message: Message = Message(**message_base_params)
+        if len(self._task_queues) == 1:
+            routing_key_name = (
+                self._task_queues[0].routing_key or self._task_queues[0].name
+            )
+        else:
+            routing_key_name = (
+                parse_val(
+                    str,
+                    message.labels.get(self._label_for_routing),
+                )
+                or ""
+            )
+        if self._exchange.type == ExchangeType.DIRECT and routing_key_name not in {
+            queue.routing_key or queue.name for queue in self._task_queues
+        }:
+            raise IncorrectRoutingKeyError(
+                f"Routing key '{routing_key_name}' is not valid. "
+                f"Check routing keys and queue names in broker queues.",
+            )
 
         if delay is None:
             exchange = await self.write_channel.get_exchange(
-                self._exchange_name,
+                self._exchange.name,
                 ensure=False,
             )
-
-            routing_key = message.task_name
-
-            # Because direct exchange uses exact routing key for routing
-            if self._exchange_type == ExchangeType.DIRECT:
-                routing_key = self._routing_key
-
-            await exchange.publish(rmq_message, routing_key=routing_key)
+            await exchange.publish(rmq_message, routing_key=routing_key_name)
         elif self._delayed_message_exchange_plugin:
             rmq_message.headers["x-delay"] = int(delay * 1000)
             exchange = await self.write_channel.get_exchange(
-                self._delay_plugin_exchange_name,
-            )
-            await exchange.publish(
-                rmq_message,
-                routing_key=self._routing_key,
+                self._delayed_message_exchange.name,
             )
+            await exchange.publish(rmq_message, routing_key=routing_key_name)
         else:
            rmq_message.expiration = timedelta(seconds=delay)
             await self.write_channel.default_exchange.publish(
                 rmq_message,
-                routing_key=self._delay_queue_name,
+                routing_key=self._delay_queue.routing_key or self._delay_queue.name,
             )
 
     async def listen(self) -> AsyncGenerator[AckableMessage, None]:
         """
@@ -298,16 +415,49 @@
 
         This function listens to queue
         and yields every new message.
 
+        :raises NoStartupError: if startup wasn't called.
         :yields: parsed broker message.
-        :raises ValueError: if startup wasn't called.
""" if self.read_channel is None: - raise ValueError("Call startup before starting listening.") + raise NoStartupError("Call startup before starting listening.") await self.read_channel.set_qos(prefetch_count=self._qos) - queue = await self.declare_queues(self.read_channel) - async with queue.iterator() as iterator: - async for message in iterator: - yield AckableMessage( - data=message.body, - ack=message.ack, - ) + queue_with_consumer_args_list = await self._declare_queues(self.read_channel) + + async def body( + queue: AbstractQueue, + consumer_args: dict[str, Any], + ) -> AsyncGenerator[AckableMessage, None]: + try: + async with queue.iterator(**consumer_args) as iterator: + async for message in iterator: + yield AckableMessage( + data=message.body, + ack=message.ack, + ) + except (RuntimeError, asyncio.CancelledError): + # Suppress errors during iterator cleanup if channel is being closed + logger.info("Queue iterator closed during shutdown") + + combine = stream.merge( + *[ + body(queue, consumer_args) + for queue, consumer_args in queue_with_consumer_args_list + if queue.name != self._delay_queue.name + ], + ) + + async with combine.stream() as streamer: + async for message in streamer: + yield message + + async def shutdown(self) -> None: + """Close all connections on shutdown.""" + await super().shutdown() + if self.write_channel: + await self.write_channel.close() + if self.read_channel: + await self.read_channel.close() + if self.write_conn: + await self.write_conn.close() + if self.read_conn: + await self.read_conn.close() diff --git a/taskiq_aio_pika/exceptions.py b/taskiq_aio_pika/exceptions.py new file mode 100644 index 0000000..c804a08 --- /dev/null +++ b/taskiq_aio_pika/exceptions.py @@ -0,0 +1,18 @@ +class BaseAioPikaBrokerError(Exception): + """Base exception for AioPika broker errors.""" + + +class NoStartupError(BaseAioPikaBrokerError): + """Raised when an operation is attempted before the broker has started.""" + + +class IncorrectRoutingKeyError(BaseAioPikaBrokerError): + """Raised when a message is received with an incorrect routing key.""" + + +class QueueNotDeclaredError(BaseAioPikaBrokerError): + """Raised when attempting to use a queue that has not been declared.""" + + +class ExchangeNotDeclaredError(BaseAioPikaBrokerError): + """Raised when attempting to use an exchange that has not been declared.""" diff --git a/taskiq_aio_pika/exchange.py b/taskiq_aio_pika/exchange.py new file mode 100644 index 0000000..b2f69ca --- /dev/null +++ b/taskiq_aio_pika/exchange.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field + +from aio_pika import ExchangeType +from pamqp.common import FieldTable + + +@dataclass(frozen=True) +class Exchange: + """ + Represents a RabbitMQ exchange configuration. + + Attributes: + name: The name of the exchange. + type: The type of the exchange (topic, direct, fanout, headers). + internal: Whether the exchange is internal. + passive: Whether to check if the exchange exists without creating it. + auto_delete: Whether the exchange should be auto-deleted. + declare: Whether to declare the exchange on startup. + durable: Whether the exchange should survive broker restarts. + arguments: Additional arguments for the exchange declaration. + timeout: Timeout for exchange declaration. 
+ """ + + name: str = "taskiq" + type: ExchangeType = ExchangeType.TOPIC + internal: bool = False + passive: bool = False + auto_delete: bool = False + declare: bool = True + durable: bool = True + arguments: FieldTable = field(default_factory=dict) + timeout: int | float | None = None diff --git a/taskiq_aio_pika/queue.py b/taskiq_aio_pika/queue.py new file mode 100644 index 0000000..e3beafc --- /dev/null +++ b/taskiq_aio_pika/queue.py @@ -0,0 +1,58 @@ +import enum +from dataclasses import dataclass, field + +from pamqp.common import FieldTable + + +class QueueType(str, enum.Enum): + """Enum representing different types of RabbitMQ queues.""" + + QUORUM = "quorum" + CLASSIC = "classic" + STREAM = "stream" + + +@dataclass(frozen=True) +class Queue: + """ + Represents a RabbitMQ queue configuration. + + Attributes: + name: The name of the queue. + type: The type of the queue (quorum, classic, stream). + declare: Whether to declare the queue on startup. + durable: Whether the queue should survive broker restarts. + exclusive: Whether the queue is exclusive to the connection. + passive: Whether to check if the queue exists without creating it. + auto_delete: Whether the queue should be auto-deleted. + max_priority: The maximum priority for the queue. + arguments: Additional arguments for the queue declaration. + timeout: Timeout for queue declaration. + routing_key: The routing key for the queue. + bind_arguments: Arguments for binding the queue. + bind_timeout: Timeout for binding the queue. + consumer_arguments: Arguments for the consumer. + """ + + declare: bool = True + + # will be passed as arguments + type: QueueType = QueueType.QUORUM + + # from declare_queue arguments + name: str = "taskiq" + durable: bool = True + exclusive: bool = False + passive: bool = False + auto_delete: bool = False + max_priority: int | None = None + arguments: FieldTable = field(default_factory=dict) + timeout: int | float | None = None + + # will be used during binding to tasks exchange + routing_key: str | None = None + bind_arguments: FieldTable = field(default_factory=dict) + bind_timeout: int | float | None = None + + # will be used during message consumption + consumer_arguments: FieldTable = field(default_factory=dict) diff --git a/tests/conftest.py b/tests/conftest.py index 78762e0..e164b75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,8 @@ from aio_pika.abc import AbstractChannel, AbstractRobustConnection from taskiq_aio_pika.broker import AioPikaBroker +from taskiq_aio_pika.exchange import Exchange +from taskiq_aio_pika.queue import Queue, QueueType @pytest.fixture(scope="session") @@ -40,17 +42,7 @@ def queue_name() -> str: :return: random queue name. """ - return uuid4().hex - - -@pytest.fixture -def routing_key() -> str: - """ - Generated routing key. - - :return: random routing key. - """ - return uuid4().hex + return uuid4().hex + "_queue" @pytest.fixture @@ -60,7 +52,7 @@ def delay_queue_name() -> str: :return: random exchange name. """ - return uuid4().hex + return uuid4().hex + "_delay_queue" @pytest.fixture @@ -70,7 +62,7 @@ def dead_queue_name() -> str: :return: random exchange name. """ - return uuid4().hex + return uuid4().hex + "_dlx_queue" @pytest.fixture @@ -80,7 +72,7 @@ def exchange_name() -> str: :return: random exchange name. """ - return uuid4().hex + return uuid4().hex + "_exchange" @pytest.fixture @@ -154,28 +146,29 @@ async def broker( exchange_name: str, test_channel: Channel, ) -> AsyncGenerator[AioPikaBroker, None]: - """ - Yields new broker instance. 
-
-    This function is used to
-    create broker, run startup,
-    and shutdown after test.
-
-    :param amqp_url: current rabbitmq connection string.
-    :param test_channel: amqp channel for tests.
-    :param queue_name: test queue name.
-    :param delay_queue_name: test delay queue name.
-    :param dead_queue_name: test dead letter queue name.
-    :param exchange_name: test exchange name.
-    :yield: broker.
-    """
     broker = AioPikaBroker(
         url=amqp_url,
-        declare_exchange=True,
-        exchange_name=exchange_name,
-        dead_letter_queue_name=dead_queue_name,
-        delay_queue_name=delay_queue_name,
-        queue_name=queue_name,
+        exchange=Exchange(
+            name=exchange_name,
+            declare=True,
+        ),
+        dead_letter_queue=Queue(
+            name=dead_queue_name,
+            declare=True,
+            type=QueueType.CLASSIC,
+        ),
+        delay_queue=Queue(
+            name=delay_queue_name,
+            declare=True,
+            type=QueueType.CLASSIC,
+        ),
+        task_queues=[
+            Queue(
+                name=queue_name,
+                declare=True,
+                type=QueueType.CLASSIC,
+            ),
+        ],
     )
 
     broker.is_worker_process = True
@@ -196,36 +189,29 @@ async def broker_with_delayed_message_plugin(
     amqp_url: str,
     queue_name: str,
-    delay_queue_name: str,
     dead_queue_name: str,
     exchange_name: str,
-    routing_key: str,
     test_channel: Channel,
 ) -> AsyncGenerator[AioPikaBroker, None]:
-    """
-    Yields new broker instance.
-
-    This function is used to
-    create broker, run startup,
-    and shutdown after test.
-
-    :param amqp_url: current rabbitmq connection string.
-    :param test_channel: amqp channel for tests.
-    :param queue_name: test queue name.
-    :param delay_queue_name: test delay queue name.
-    :param dead_queue_name: test dead letter queue name.
-    :param exchange_name: test exchange name.
-    :param routing_key: routing_key.
-    :yield: broker.
-    """
     broker = AioPikaBroker(
         url=amqp_url,
-        declare_exchange=True,
-        exchange_name=exchange_name,
-        dead_letter_queue_name=dead_queue_name,
-        queue_name=queue_name,
+        exchange=Exchange(
+            name=exchange_name,
+            declare=True,
+        ),
+        dead_letter_queue=Queue(
+            name=dead_queue_name,
+            declare=True,
+            type=QueueType.CLASSIC,
+        ),
+        task_queues=[
+            Queue(
+                name=queue_name,
+                declare=True,
+                type=QueueType.CLASSIC,
+            ),
+        ],
         delayed_message_exchange_plugin=True,
-        routing_key=routing_key,
     )
 
     broker.is_worker_process = True
@@ -237,6 +223,6 @@
     await _cleanup_amqp_resources(
         amqp_url,
-        [exchange_name, broker._delay_plugin_exchange_name],
-        [queue_name, delay_queue_name, dead_queue_name],
+        [exchange_name, broker._delayed_message_exchange.name],
+        [queue_name, dead_queue_name],
     )
diff --git a/tests/test_broker.py b/tests/test_broker.py
deleted file mode 100644
index ff79ddc..0000000
--- a/tests/test_broker.py
+++ /dev/null
@@ -1,262 +0,0 @@
-import asyncio
-import uuid
-
-import pytest
-from aio_pika import Channel, ExchangeType, Message
-from aio_pika.exceptions import QueueEmpty
-from taskiq import AckableMessage, BrokerMessage
-from taskiq.utils import maybe_awaitable
-
-from taskiq_aio_pika.broker import AioPikaBroker
-
-
-async def get_first_task(broker: AioPikaBroker) -> AckableMessage:
-    """
-    Get first message from the queue.
-
-    :param broker: async message broker.
-    :return: first message from listen method
-    """
-    async for message in broker.listen():
-        return message
-    return None  # type: ignore
-
-
-async def test_kick_success(broker: AioPikaBroker) -> None:
-    """
-    Test that messages are published and read correctly.
-
-    We kick the message and then try to listen to the queue,
-    and check that message we got is the same as we sent.
-
-    :param broker: current broker.
-    """
-    task_id = uuid.uuid4().hex
-    task_name = uuid.uuid4().hex
-
-    sent = BrokerMessage(
-        task_id=task_id,
-        task_name=task_name,
-        message=b"my_msg",
-        labels={
-            "label1": "val1",
-        },
-    )
-
-    await broker.kick(sent)
-
-    message = await asyncio.wait_for(get_first_task(broker), timeout=0.4)
-
-    assert message.data == sent.message
-    await maybe_awaitable(message.ack())
-
-
-async def test_startup(
-    broker: AioPikaBroker,
-    test_channel: Channel,
-    queue_name: str,
-    exchange_name: str,
-) -> None:
-    """
-    Test startup event.
-
-    In this test we delete the exchange and the queue,
-    call startup method, and ensure that queue and exchange
-    exist.
-
-    :param broker: current broker.
-    :param test_channel: test channel.
-    :param queue_name: name of the queue.
-    :param exchange_name: name of the test exchange.
-    """
-    queue = await test_channel.get_queue(queue_name)
-    exchange = await test_channel.get_exchange(exchange_name)
-    await queue.delete()
-    await exchange.delete()
-    broker._declare_exchange = True
-
-    await broker.startup()
-
-    await test_channel.get_queue(queue_name, ensure=True)
-    await test_channel.get_exchange(exchange_name, ensure=True)
-
-
-async def test_listen(
-    broker: AioPikaBroker,
-    test_channel: Channel,
-    exchange_name: str,
-) -> None:
-    """
-    Test that message are read correctly.
-
-    Tests that broker listens to the queue
-    correctly and listen can be iterated.
-
-    :param broker: current broker.
-    :param test_channel: amqp channel.
-    :param exchange_name: main exchange name.
-    """
-    exchange = await test_channel.get_exchange(exchange_name)
-    await exchange.publish(
-        Message(
-            b"test_message",
-            headers={
-                "task_id": "test_id",
-                "task_name": "task_name",
-                "label1": "label_val",
-            },
-        ),
-        routing_key="task_name",
-    )
-
-    message = await asyncio.wait_for(get_first_task(broker), timeout=0.4)
-
-    assert message.data == b"test_message"
-    await maybe_awaitable(message.ack())
-
-
-async def test_wrong_format(
-    broker: AioPikaBroker,
-    queue_name: str,
-    test_channel: Channel,
-) -> None:
-    """
-    Tests that messages with wrong format are still received.
-
-    :param broker: aio-pika broker.
-    :param queue_name: test queue name.
-    :param test_channel: test channel.
-    """
-    queue = await test_channel.get_queue(queue_name)
-    await test_channel.default_exchange.publish(
-        Message(b"wrong"),
-        routing_key=queue_name,
-    )
-
-    message = await asyncio.wait_for(get_first_task(broker), 0.4)
-
-    assert message.data == b"wrong"
-    await maybe_awaitable(message.ack())
-
-    with pytest.raises(QueueEmpty):
-        await queue.get()
-
-
-async def test_delayed_message(
-    broker: AioPikaBroker,
-    test_channel: Channel,
-    queue_name: str,
-    delay_queue_name: str,
-) -> None:
-    """
-    Test that delayed messages are delivered correctly.
-
-    This test send message with delay label,
-    checks that this message appears in delay queue.
-    After that it waits specified delay period and
-    checks that message was transferred to the main queue.
-
-    :param broker: current broker.
-    :param test_channel: amqp channel for tests.
-    :param queue_name: test queue name.
-    :param delay_queue_name: name of the test queue for delayed messages.
-    """
-    delay_queue = await test_channel.get_queue(delay_queue_name)
-    main_queue = await test_channel.get_queue(queue_name)
-    broker_msg = BrokerMessage(
-        task_id="1",
-        task_name="name",
-        message=b"message",
-        labels={"delay": "2"},
-    )
-    await broker.kick(broker_msg)
-
-    # We check that message appears in delay queue.
-    delay_msg = await delay_queue.get()
-    await delay_msg.nack(requeue=True)
-
-    # After we wait the delay message must appear in
-    # the main queue.
-    await asyncio.sleep(2)
-
-    # Check that it disappear.
-    with pytest.raises(QueueEmpty):
-        await delay_queue.get(no_ack=True)
-
-    # Check that we can get the message.
-    await main_queue.get()
-
-
-async def test_delayed_message_with_plugin(
-    broker_with_delayed_message_plugin: AioPikaBroker,
-    test_channel: Channel,
-    queue_name: str,
-) -> None:
-    """Test that we can send delayed messages with plugin.
-
-    :param broker_with_delayed_message_plugin: broker with
-        turned on plugin integration.
-    :param test_channel: amqp channel for tests.
-    :param queue_name: test queue name.
-    """
-    main_queue = await test_channel.get_queue(queue_name)
-    broker_msg = BrokerMessage(
-        task_id="1",
-        task_name="name",
-        message=b"message",
-        labels={"delay": "2"},
-    )
-
-    await broker_with_delayed_message_plugin.kick(broker_msg)
-    with pytest.raises(QueueEmpty):
-        await main_queue.get(no_ack=True)
-
-    await asyncio.sleep(2)
-
-    assert await main_queue.get()
-
-
-async def test_direct_kick(
-    broker: AioPikaBroker,
-    test_channel: Channel,
-    queue_name: str,
-    exchange_name: str,
-) -> None:
-    """
-    Test that messages are published and read correctly.
-
-    We kick the message and then try to listen to the queue,
-    and check that message we got is the same as we sent.
-    """
-    queue = await test_channel.get_queue(queue_name)
-    exchange = await test_channel.get_exchange(exchange_name)
-    await queue.delete()
-    await exchange.delete()
-
-    broker._declare_exchange = True
-    broker._exchange_type = ExchangeType.DIRECT
-    broker._routing_key = "direct_routing_key"
-
-    await broker.startup()
-
-    await test_channel.get_queue(queue_name, ensure=True)
-    await test_channel.get_exchange(exchange_name, ensure=True)
-
-    task_id = uuid.uuid4().hex
-    task_name = uuid.uuid4().hex
-
-    sent = BrokerMessage(
-        task_id=task_id,
-        task_name=task_name,
-        message=b"my_msg",
-        labels={
-            "label1": "val1",
-        },
-    )
-
-    await broker.kick(sent)
-
-    message = await asyncio.wait_for(get_first_task(broker), timeout=0.4)
-
-    assert message.data == sent.message
-    await maybe_awaitable(message.ack())
diff --git a/tests/test_delay.py b/tests/test_delay.py
new file mode 100644
index 0000000..2c8b62f
--- /dev/null
+++ b/tests/test_delay.py
@@ -0,0 +1,40 @@
+import asyncio
+
+import pytest
+from aio_pika import Channel
+from aio_pika.exceptions import QueueEmpty
+from taskiq import BrokerMessage
+
+from taskiq_aio_pika import AioPikaBroker
+
+
+async def test_when_delayed_message_queue_exists__then_send_with_delay_must_work(
+    broker: AioPikaBroker,
+    test_channel: Channel,
+    queue_name: str,
+    delay_queue_name: str,
+) -> None:
+    delay_queue = await test_channel.get_queue(delay_queue_name)
+    main_queue = await test_channel.get_queue(queue_name)
+    broker_msg = BrokerMessage(
+        task_id="1",
+        task_name="name",
+        message=b"message",
+        labels={"delay": "2"},
+    )
+    await broker.kick(broker_msg)
+
+    # Check that the message appears in the delay queue.
+    delay_msg = await delay_queue.get()
+    await delay_msg.nack(requeue=True)
+
+    # After the delay has passed, the message must appear in
+    # the main queue.
+    await asyncio.sleep(2)
+
+    # Check that it disappeared from the delay queue.
+    with pytest.raises(QueueEmpty):
+        await delay_queue.get(no_ack=True)
+
+    # Check that we can get the message.
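+    # The delay queue dead-letters expired messages into the main queue,
+    # which is why the message becomes available here.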
+    await main_queue.get()
diff --git a/tests/test_delay_with_plugin.py b/tests/test_delay_with_plugin.py
new file mode 100644
index 0000000..9206450
--- /dev/null
+++ b/tests/test_delay_with_plugin.py
@@ -0,0 +1,30 @@
+import asyncio
+
+import pytest
+from aio_pika import Channel
+from aio_pika.exceptions import QueueEmpty
+from taskiq import BrokerMessage
+
+from taskiq_aio_pika import AioPikaBroker
+
+
+async def test_when_delayed_message_plugin_enabled__then_send_with_delay_must_work(
+    broker_with_delayed_message_plugin: AioPikaBroker,
+    test_channel: Channel,
+    queue_name: str,
+) -> None:
+    # given
+    main_queue = await test_channel.get_queue(queue_name)
+    broker_msg = BrokerMessage(
+        task_id="1",
+        task_name="name",
+        message=b"message",
+        labels={"delay": "2"},
+    )
+
+    # when & then
+    await broker_with_delayed_message_plugin.kick(broker_msg)
+    with pytest.raises(QueueEmpty):
+        await main_queue.get(no_ack=True)
+    await asyncio.sleep(2)
+    assert await main_queue.get()
diff --git a/tests/test_routing.py b/tests/test_routing.py
new file mode 100644
index 0000000..0bbff9e
--- /dev/null
+++ b/tests/test_routing.py
@@ -0,0 +1,231 @@
+import asyncio
+import uuid
+from collections.abc import AsyncGenerator
+
+import aio_pika
+import pytest
+from aio_pika import Channel, Message
+from aio_pika.abc import ExchangeType
+from taskiq import BrokerMessage
+from taskiq.utils import maybe_awaitable
+
+from taskiq_aio_pika import AioPikaBroker, Queue
+from taskiq_aio_pika.exchange import Exchange
+from tests.conftest import _cleanup_amqp_resources
+from tests.utils import get_first_task
+
+
+class TestRouting:
+    broker: AioPikaBroker | None = None
+
+    @pytest.fixture(autouse=True)
+    async def cleanup_class_broker(self, amqp_url: str) -> AsyncGenerator[None, None]:
+        yield
+        if self.broker is not None:
+            await self.broker.shutdown()
+            await _cleanup_amqp_resources(
+                amqp_url,
+                [self.broker._exchange.name],
+                [queue.name for queue in self.broker._task_queues]
+                + [
+                    self.broker._dead_letter_queue.name,
+                    self.broker._delay_queue.name,
+                ],
+            )
+
+    async def test_when_message_has_wrong_format__then_message_still_can_be_received(
+        self,
+        broker: AioPikaBroker,
+        queue_name: str,
+        test_channel: Channel,
+    ) -> None:
+        # given & when
+        await test_channel.default_exchange.publish(
+            Message(b"wrong"),
+            routing_key=queue_name,
+        )
+
+        # then
+        message = await asyncio.wait_for(get_first_task(broker), 0.4)
+        assert message.data == b"wrong"
+
+    async def test_when_broker_has_only_default_settings__then_task_can_be_passed(
+        self,
+        amqp_url: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+        )
+        self.broker.is_worker_process = True
+        await self.broker.startup()
+        task_id = uuid.uuid4().hex
+        message = BrokerMessage(
+            task_id=task_id,
+            task_name="task_name",
+            message=b"my_msg",
+            labels={
+                "label1": "val1",
+            },
+        )
+
+        # when
+        await self.broker.kick(message)
+
+        # then
+        received_message = await asyncio.wait_for(
+            get_first_task(self.broker),
+            timeout=0.4,
+        )
+        assert received_message.data == message.message
+        await maybe_awaitable(received_message.ack())
+
+    async def test_when_broker_has_two_queues_and_default_exchange__then_task_should_be_put_only_to_right_queue(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            task_queues=[
+                Queue(
+                    name="queue1",
+                    declare=True,
+                ),
+                Queue(
+                    name="queue2",
+                    declare=True,
+                ),
+            ],
+        )
+        self.broker.is_worker_process = True
+        await self.broker.startup()
+
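+        # Each queue is bound to the exchange with its own name as the routing
+        # key, so the "queue_name" label selects the destination queue.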
+        queue_1 = await test_channel.get_queue("queue1")
+        queue_2 = await test_channel.get_queue("queue2")
+
+        message_to_queue_1 = BrokerMessage(
+            task_id=uuid.uuid4().hex,
+            task_name="task_name",
+            message=b"my_msg",
+            labels={
+                "queue_name": "queue1",
+            },
+        )
+
+        # when
+        await self.broker.kick(message_to_queue_1)
+
+        # then
+        received_message = await queue_1.get()
+        assert received_message is not None
+        await received_message.nack(requeue=True)
+
+        with pytest.raises(aio_pika.exceptions.QueueEmpty):
+            received_message = await queue_2.get()
+            assert received_message is None
+
+    async def test_when_queue_bind_with_pattern_and_exchange_type_topic__then_message_published_in_right_queue(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name="test_topic_exchange",
+                type=ExchangeType.TOPIC,
+                declare=True,
+            ),
+            task_queues=[
+                Queue(
+                    name="service.task.queue1",
+                    declare=True,
+                    routing_key="service.task.*",
+                ),
+                Queue(
+                    name="service.task.queue2",
+                    declare=True,
+                    routing_key="service.task.only_specific_task",
+                ),
+            ],
+        )
+        self.broker.is_worker_process = True
+        await self.broker.startup()
+
+        queue_1 = await test_channel.get_queue("service.task.queue1")
+        queue_2 = await test_channel.get_queue("service.task.queue2")
+
+        message_with_specific_task = BrokerMessage(
+            task_id=uuid.uuid4().hex,
+            task_name="task_name",
+            message=b"my_msg",
+            labels={"queue_name": "service.task.only_specific_task"},
+        )
+
+        # when
+        await self.broker.kick(message_with_specific_task)
+
+        # then
+        received_message_1 = await queue_1.get()
+        assert (
+            received_message_1 is not None
+        ), "Message was not routed to queue, but should be by pattern"
+        await received_message_1.ack()
+
+        received_message_2 = await queue_2.get()
+        assert (
+            received_message_2 is not None
+        ), "Message was not routed to queue, but should be by specific name"
+        await received_message_2.ack()
+
+    async def test_when_exchange_fanout__then_message_published_in_all_queues(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name="test_fanout_exchange",
+                type=ExchangeType.FANOUT,
+                declare=True,
+            ),
+            task_queues=[
+                Queue(
+                    name="service.task.queue1",
+                    declare=True,
+                ),
+                Queue(
+                    name="service.task.queue2",
+                    declare=True,
+                ),
+            ],
+        )
+        self.broker.is_worker_process = True
+        await self.broker.startup()
+
+        queue_1 = await test_channel.get_queue("service.task.queue1")
+        queue_2 = await test_channel.get_queue("service.task.queue2")
+
+        message_for_all_queues = BrokerMessage(
+            task_id=uuid.uuid4().hex,
+            task_name="task_name",
+            message=b"my_msg",
+            labels={"queue_name": "service.task.only_specific_task"},
+        )
+
+        # when
+        await self.broker.kick(message_for_all_queues)
+
+        # then
+        received_message_1 = await queue_1.get()
+        assert received_message_1 is not None
+        await received_message_1.ack()
+
+        received_message_2 = await queue_2.get()
+        assert received_message_2 is not None
+        await received_message_2.ack()
diff --git a/tests/test_startup.py b/tests/test_startup.py
new file mode 100644
index 0000000..d6d937d
--- /dev/null
+++ b/tests/test_startup.py
@@ -0,0 +1,229 @@
+import uuid
+from collections.abc import AsyncGenerator
+
+import aiormq
+import pytest
+from aio_pika import Channel
+
+from taskiq_aio_pika import AioPikaBroker
+from taskiq_aio_pika.exceptions import ExchangeNotDeclaredError, QueueNotDeclaredError
+from taskiq_aio_pika.exchange import Exchange
+from taskiq_aio_pika.queue import Queue, QueueType
+from tests.conftest import _cleanup_amqp_resources
+
+
+class TestStartup:
+    broker: AioPikaBroker | None = None
+
+    @pytest.fixture(autouse=True)
+    async def cleanup_class_broker(self, amqp_url: str) -> AsyncGenerator[None, None]:
+        yield
+        if self.broker is not None:
+            await self.broker.shutdown()
+            await _cleanup_amqp_resources(
+                amqp_url,
+                [self.broker._exchange.name],
+                [queue.name for queue in self.broker._task_queues]
+                + [
+                    self.broker._dead_letter_queue.name,
+                    self.broker._delay_queue.name,
+                ],
+            )
+
+    async def test_when_declare_flag_passed_to_queue__broker_declares_queue_on_startup(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+        exchange_name: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name=exchange_name,
+                declare=True,
+                durable=False,
+            ),
+        )
+
+        # when
+        self.broker.with_queue(
+            Queue(
+                name="declared_queue",
+                declare=True,
+            ),
+        )
+        await self.broker.startup()
+
+        # then
+        queue = await test_channel.get_queue("declared_queue", ensure=True)
+        assert queue.name == "declared_queue"
+        assert not queue.durable
+        assert not queue.exclusive
+        assert not queue.auto_delete
+        assert queue.passive
+        assert queue.arguments is None
+
+    async def test_when_declare_flag_not_passed_to_queue__broker_does_not_declare_queue_on_startup(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+        exchange_name: str,
+        delay_queue_name: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name=exchange_name,
+                declare=True,
+                durable=False,
+            ),
+            delay_queue=Queue(
+                name=delay_queue_name,
+                type=QueueType.CLASSIC,
+                declare=True,
+                durable=False,
+            ),
+        )
+        not_declared_queue_name = "not_declared_queue" + uuid.uuid4().hex
+
+        # when & then
+        self.broker.with_queues(
+            Queue(
+                name=not_declared_queue_name,
+                declare=False,
+            ),
+        )
+        with pytest.raises(
+            QueueNotDeclaredError,
+            match=f"Queue '{not_declared_queue_name}' was not declared and does not exist.",
+        ):
+            await self.broker.startup()
+
+    async def test_when_exchange_declare_flag_false__broker_does_not_declare_exchange_on_startup(
+        self,
+        amqp_url: str,
+        exchange_name: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name=exchange_name,
+                declare=False,
+            ),
+        )
+
+        # when & then
+        with pytest.raises(
+            ExchangeNotDeclaredError,
+            match=f"Exchange '{exchange_name}' was not declared and does not exist.",
+        ):
+            await self.broker.startup()
+
+    async def test_when_exchange_declare_flag_true__broker_declares_exchange_on_startup(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+        exchange_name: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name=exchange_name,
+                declare=True,
+            ),
+        )
+
+        # when
+        await self.broker.startup()
+
+        # then
+        exchange = await test_channel.get_exchange(exchange_name, ensure=True)
+        assert (
+            exchange.name == exchange_name
+        )  # should be more checks for arguments here
+
+    async def test_when_delayed_message_exchange_plugin_enabled__broker_declares_exchange_on_startup(
+        self,
+        amqp_url: str,
+        test_channel: Channel,
+        exchange_name: str,
+    ) -> None:
+        # given
+        self.broker = AioPikaBroker(
+            url=amqp_url,
+            exchange=Exchange(
+                name=exchange_name,
+                declare=True,
+            ),
+            delayed_message_exchange_plugin=True,
+        )
+
+        # when
+        await self.broker.startup()
+
+        # then
+        exchange = await test_channel.get_exchange(
f"{exchange_name}.plugin_delay", + ensure=True, + ) + assert ( + exchange.name == f"{exchange_name}.plugin_delay" + ) # should be more checks for arguments here + + async def test_when_delayed_message_exchange_plugin_disabled__broker_does_not_declare_exchange_on_startup( + self, + amqp_url: str, + exchange_name: str, + test_channel: Channel, + ) -> None: + # given + self.broker = AioPikaBroker( + url=amqp_url, + exchange=Exchange( + name=exchange_name, + declare=True, + ), + delayed_message_exchange_plugin=False, + ) + + # when + await self.broker.startup() + + # then + with pytest.raises( + aiormq.exceptions.ChannelNotFoundEntity, + ): + await test_channel.get_exchange( + f"{exchange_name}.plugin_delay", + ensure=True, + ) + + async def test_when_delayed_message_exchange_plugin_enabled_and_custom_exchange_not_declared__broker_raise_error( + self, + amqp_url: str, + exchange_name: str, + ) -> None: + # given + delayed_message_exchange_name = "custom_delay_exchange" + uuid.uuid4().hex + self.broker = AioPikaBroker( + url=amqp_url, + exchange=Exchange( + name=exchange_name, + declare=True, + ), + delayed_message_exchange_plugin=True, + delayed_message_exchange=Exchange( + name=delayed_message_exchange_name, + declare=False, + ), + ) + # when & then + with pytest.raises( + ExchangeNotDeclaredError, + match=f"Exchange '{delayed_message_exchange_name}' was not declared and does not exist.", + ): + await self.broker.startup() diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..3bbd64d --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,15 @@ +from taskiq import AckableMessage + +from taskiq_aio_pika.broker import AioPikaBroker + + +async def get_first_task(broker: AioPikaBroker) -> AckableMessage: + """ + Get first message from the queue. + + :param broker: async message broker. 
+    :return: the first message yielded by the broker's listen method.
+    """
+    async for message in broker.listen():
+        return message
+    # Unreachable in practice: listen() blocks until a message arrives.
+    return None  # type: ignore
diff --git a/uv.lock b/uv.lock
index ce95e8a..30218d4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -171,6 +171,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
 ]
 
+[[package]]
+name = "aiostream"
+version = "0.7.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/8b/65/b9b69695702b76a878c9879f2ee80cefce75bc5cb864fc100460bc1c5380/aiostream-0.7.1.tar.gz", hash = "sha256:272aaa0d8f83beb906f5aa9022bb59046bb7a103fa3770f807c31f918595acf6", size = 44059, upload-time = "2025-10-13T20:02:06.961Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/52/a0/d7c6ca304140f3f49987d710e15bc164248924a35d8cdfac2f6e87fca041/aiostream-0.7.1-py3-none-any.whl", hash = "sha256:ea8739e9158ee6a606b3feedf3762721c3507344e540d09a10984c5e88a13b37", size = 41416, upload-time = "2025-10-13T20:02:05.535Z" },
+]
+
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -564,18 +576,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" },
 ]
 
-[[package]]
-name = "importlib-metadata"
-version = "8.7.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "zipp" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" },
-]
-
 [[package]]
 name = "iniconfig"
 version = "2.3.0"
@@ -1206,15 +1206,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/84/25/d9db8be44e205a124f6c98bc0324b2bb149b7431c53877fc6d1038dddaf5/pytokens-0.3.0-py3-none-any.whl", hash = "sha256:95b2b5eaf832e469d141a378872480ede3f251a5a5041b8ec6e581d3ac71bbf3", size = 12195, upload-time = "2025-11-05T13:36:33.183Z" },
 ]
 
-[[package]]
-name = "pytz"
-version = "2025.2"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" },
-]
-
 [[package]]
 name = "pyyaml"
version = "6.0.3" @@ -1279,6 +1270,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" }, ] +[[package]] +name = "redis" +version = "6.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" }, +] + [[package]] name = "ruff" version = "0.14.5" @@ -1316,23 +1319,21 @@ wheels = [ [[package]] name = "taskiq" -version = "0.11.20" +version = "0.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, { name = "anyio" }, - { name = "importlib-metadata" }, { name = "izulu" }, { name = "packaging" }, { name = "pycron" }, { name = "pydantic" }, - { name = "pytz" }, { name = "taskiq-dependencies" }, - { name = "typing-extensions" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2c/80/58b1aee0934d88b5558bfe013cc6b34f684985ec44b6952a294e96048876/taskiq-0.11.20.tar.gz", hash = "sha256:598fa9b03eafd4cc9521158917382f184ba613e5c71544b9b60934c1cc055f51", size = 55964, upload-time = "2025-10-30T13:15:15.38Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/9b/bb9b3ab5024051e80013170950bf5acb7729636918bab0ceb91a3900c815/taskiq-0.12.0.tar.gz", hash = "sha256:722d64b8176affb146635c7ac356d3e44efa446a3dc7a373694c0eb8852672b6", size = 60099, upload-time = "2025-11-26T18:38:54.618Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/70/b5f91b22e7b6ca8b467f7a933fd7e1d4b5016186720729bb6a40bfa2251a/taskiq-0.11.20-py3-none-any.whl", hash = "sha256:29521bbd580af1b98052d4c24a33f6d4a6060bba73dcae4ce9c69c63b937cf87", size = 81873, upload-time = "2025-10-30T13:15:14.311Z" }, + { url = "https://files.pythonhosted.org/packages/37/88/e0bb05fcca198d313a50c5c461711a6f36d3a8d29010b65960796bdb03cd/taskiq-0.12.0-py3-none-any.whl", hash = "sha256:8fea577bbf72ceabd77338f643510c64787521f43e7907e605ebeead660c1a74", size = 90388, upload-time = "2025-11-26T18:38:53.4Z" }, ] [[package]] @@ -1341,6 +1342,7 @@ version = "0.0.0" source = { editable = "." 
 dependencies = [
     { name = "aio-pika" },
+    { name = "aiostream" },
     { name = "taskiq" },
 ]
 
@@ -1355,12 +1357,17 @@ dev = [
     { name = "pytest-cov" },
     { name = "pytest-xdist", extra = ["psutil"] },
     { name = "ruff" },
+    { name = "taskiq-redis" },
+]
+examples = [
+    { name = "taskiq-redis" },
 ]
 
 [package.metadata]
 requires-dist = [
     { name = "aio-pika", specifier = ">=9.0.0" },
-    { name = "taskiq", specifier = ">=0.11.20,<1" },
+    { name = "aiostream", specifier = ">=0.7.1" },
+    { name = "taskiq", specifier = ">=0.12.0,<1" },
 ]
 
 [package.metadata.requires-dev]
@@ -1374,7 +1381,9 @@ dev = [
     { name = "pytest-cov", specifier = ">=7.0.0" },
     { name = "pytest-xdist", extras = ["psutil"], specifier = ">=3.8.0" },
     { name = "ruff", specifier = ">=0.14.5" },
+    { name = "taskiq-redis", specifier = ">=1.1.2" },
 ]
+examples = [{ name = "taskiq-redis", specifier = ">=1.1.2" }]
 
 [[package]]
 name = "taskiq-dependencies"
@@ -1385,6 +1394,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/99/6d/4a012f2de002c2e93273f5e7d3e3feea02f7fdbb7b75ca2ca1dd10703091/taskiq_dependencies-1.5.7-py3-none-any.whl", hash = "sha256:6fcee5d159bdb035ef915d4d848826169b6f06fe57cc2297a39b62ea3e76036f", size = 13801, upload-time = "2025-02-26T22:07:38.622Z" },
 ]
 
+[[package]]
+name = "taskiq-redis"
+version = "1.1.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "redis" },
+    { name = "taskiq" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d2/0b/5006792cbdf6e78abaab45407968d96e2a1ffd0238cd41b13a62b33880cf/taskiq_redis-1.1.2.tar.gz", hash = "sha256:f878a047abc1a0fa0ddda6e4b063d742f030858d5a4409de96261a00a992c9c9", size = 16225, upload-time = "2025-10-07T10:46:13.666Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/73/20/9bf19b05eb7cb275a04871f3c38bf088c108579c4c19ea2a1358d2a81334/taskiq_redis-1.1.2-py3-none-any.whl", hash = "sha256:3b69468ff99da33243314c84acfa137a4a014abad10dd44e62195bc5e4ae9947", size = 20424, upload-time = "2025-10-07T10:46:12.779Z" },
+]
+
 [[package]]
 name = "tomli"
 version = "2.3.0"
@@ -1595,12 +1617,3 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/48/b7/503c98092fb3b344a179579f55814b613c1fbb1c23b3ec14a7b008a66a6e/yarl-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:9f6d73c1436b934e3f01df1e1b21ff765cd1d28c77dfb9ace207f746d4610ee1", size = 85171, upload-time = "2025-10-06T14:12:16.935Z" },
     { url = "https://files.pythonhosted.org/packages/73/ae/b48f95715333080afb75a4504487cbe142cae1268afc482d06692d605ae6/yarl-1.22.0-py3-none-any.whl", hash = "sha256:1380560bdba02b6b6c90de54133c81c9f2a453dee9912fe58c1dcced1edb7cff", size = 46814, upload-time = "2025-10-06T14:12:53.872Z" },
 ]
-
-[[package]]
-name = "zipp"
-version = "3.23.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" },
-]