Skip to content

Commit 9f218d2

Browse files
committed
fix: use QpidRobustConnection only when using AMQPMixin.is_qpid
1 parent 7a63f06 commit 9f218d2

4 files changed

Lines changed: 54 additions & 35 deletions

File tree

icij-worker/icij_worker/event_publisher/amqp.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,27 @@
44
from contextlib import AsyncExitStack
55
from functools import cached_property
66

7-
from aio_pika import (
8-
Exchange as AioPikaExchange,
9-
RobustChannel,
10-
connect_robust,
11-
)
7+
from aio_pika import Exchange as AioPikaExchange, RobustChannel
128
from aio_pika.abc import AbstractRobustConnection
139

1410
from icij_common.logging_utils import LogWithNameMixin
1511
from icij_worker import ManagerEvent
1612
from . import EventPublisher
1713
from ..routing_strategy import Routing
18-
from ..utils.amqp import AMQPMixin, RobustConnection
14+
from ..utils.amqp import AMQPMixin
1915

2016

2117
class AMQPPublisher(AMQPMixin, EventPublisher, LogWithNameMixin):
2218
def __init__(
2319
self,
24-
logger: Optional[logging.Logger] = None,
20+
logger: logging.Logger | None = None,
2521
*,
2622
broker_url: str,
2723
connection_timeout_s: float = 1.0,
2824
reconnection_wait_s: float = 5.0,
2925
is_qpid: bool = False,
3026
app_id: str | None = None,
31-
connection: Optional[AbstractRobustConnection] = None,
27+
connection: AbstractRobustConnection | None = None,
3228
):
3329
super().__init__(
3430
broker_url,
@@ -42,8 +38,8 @@ def __init__(
4238
self._app_id = app_id
4339
self._broker_url = broker_url
4440
self._connection_ = connection
45-
self._channel_: Optional[RobustChannel] = None
46-
self._manager_evt_x: Optional[AioPikaExchange] = None
41+
self._channel_: RobustChannel | None = None
42+
self._manager_evt_x: AioPikaExchange | None = None
4743
self._connection_timeout_s = connection_timeout_s
4844
self._reconnection_wait_s = reconnection_wait_s
4945
self._exit_stack = AsyncExitStack()
@@ -73,13 +69,8 @@ async def _publish_event(self, event: ManagerEvent):
7369

7470
async def _connection_workflow(self):
7571
self.debug("creating connection...")
76-
if self._connection_ is None:
77-
self._connection_ = await connect_robust(
78-
self._broker_url,
79-
timeout=self._connection_timeout_s,
80-
reconnect_interval=self._reconnection_wait_s,
81-
connection_class=RobustConnection,
82-
)
72+
if self._connection is None:
73+
await self._connect()
8374
await self._exit_stack.enter_async_context(self._connection)
8475
self.debug("creating channel...")
8576
self._channel_ = await self._connection.channel(

icij-worker/icij_worker/task_manager/amqp.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import cached_property
66
from typing import TypeVar, cast
77

8-
from aio_pika import connect_robust
98
from aio_pika.abc import AbstractExchange, AbstractQueueIterator
109
from aiormq import DeliveryError
1110
from pydantic import Field
@@ -32,7 +31,6 @@
3231
AMQPConfigMixin,
3332
AMQPManagementClient,
3433
AMQPMixin,
35-
RobustConnection,
3634
amqp_task_group_policy,
3735
health_policy,
3836
)
@@ -217,12 +215,8 @@ async def shutdown_workers(self):
217215
async def _connection_workflow(self):
218216
await self._exit_stack.enter_async_context(self._management_client)
219217
logger.debug("creating connection...")
220-
self._connection_ = await connect_robust(
221-
self._broker_url,
222-
timeout=self._connection_timeout_s,
223-
reconnect_interval=self._reconnection_wait_s,
224-
connection_class=RobustConnection,
225-
)
218+
if self._connection is None:
219+
await self._connect()
226220
await self._exit_stack.enter_async_context(self._connection)
227221
logger.debug("creating channel...")
228222
self._channel_ = await self._connection.channel(

icij-worker/icij_worker/utils/amqp.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from aio_pika import (
1515
DeliveryMode,
1616
Message as AioPikaMessage,
17-
RobustChannel as RobustChannel_,
18-
RobustConnection as RobustConnection_,
17+
RobustChannel,
18+
RobustConnection,
19+
connect_robust,
1920
)
2021
from aio_pika.abc import (
2122
AbstractExchange,
@@ -117,8 +118,7 @@ def broker_url(self) -> str:
117118
if amqp_userinfo:
118119
amqp_userinfo += "@"
119120
amqp_authority = (
120-
f"{amqp_userinfo or ''}{self.rabbitmq_host}"
121-
f"{f':{self.rabbitmq_port}' or ''}"
121+
f"{amqp_userinfo or ''}{self.rabbitmq_host}{f':{self.rabbitmq_port}' or ''}"
122122
)
123123
amqp_uri = f"amqp://{amqp_authority}"
124124
if self.rabbitmq_vhost is not None:
@@ -216,7 +216,7 @@ def channel(self) -> AbstractRobustChannel:
216216
return self._channel
217217

218218
@property
219-
def connection(self) -> AbstractRobustChannel:
219+
def connection(self) -> AbstractRobustConnection:
220220
return self._connection
221221

222222
@classmethod
@@ -268,6 +268,15 @@ def health_routing(cls) -> Routing:
268268
queue_name=AMQP_HEALTH_QUEUE,
269269
)
270270

271+
async def _connect(self):
272+
connection_class = QpidRobustConnection if self._is_qpid else RobustConnection
273+
self._connection_ = await connect_robust(
274+
self._broker_url,
275+
timeout=self._connection_timeout_s,
276+
reconnect_interval=self._reconnection_wait_s,
277+
connection_class=connection_class,
278+
)
279+
271280
async def _get_queue_iterator(
272281
self,
273282
routing: Routing,
@@ -431,7 +440,7 @@ def health_policy(routing: Routing) -> AMQPPolicy:
431440
)
432441

433442

434-
class RobustChannel(RobustChannel_):
443+
class QpidRobustChannel(RobustChannel):
435444
async def __close_callback(self, _: Any, exc: BaseException) -> None:
436445
# pylint: disable=unused-private-member
437446
timeout_exc = parse_consumer_timeout(exc)
@@ -440,8 +449,8 @@ async def __close_callback(self, _: Any, exc: BaseException) -> None:
440449
raise timeout_exc from exc
441450

442451

443-
class RobustConnection(RobustConnection_):
444-
CHANNEL_CLASS: type[RobustChannel] = RobustChannel
452+
class QpidRobustConnection(RobustConnection):
453+
CHANNEL_CLASS: type[RobustChannel] = QpidRobustChannel
445454

446455
# Defined async context manager attributes to be able to enter and exit this
447456
# in ExitStack

icij-worker/tests/utils/test_amqp.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44

55
import pytest
66
from aio_pika import Message, connect_robust
7-
from aio_pika.abc import AbstractRobustQueue
7+
from aio_pika.abc import AbstractRobustConnection, AbstractRobustQueue
88

99
from icij_worker.task_storage.postgres.postgres import logger
1010
from icij_worker.utils.amqp import (
1111
AMQPManagementClient,
1212
AMQPMixin,
1313
AMQPPolicy,
1414
ApplyTo,
15-
parse_consumer_timeout,
15+
QpidRobustConnection,
1616
RobustConnection,
17+
parse_consumer_timeout,
1718
worker_events_policy,
1819
)
1920

@@ -104,3 +105,27 @@ async def test_worker_events_policy():
104105
assert policy == expected
105106
worker_queue_name = "WORKER_EVENT-some-service"
106107
assert re.match(policy.pattern, worker_queue_name)
108+
109+
110+
@pytest.mark.parametrize(
111+
"is_qpid_,expected_type", [(True, QpidRobustConnection), (False, RobustConnection)]
112+
)
113+
async def test_should_handle_qpid_when_creating_connection(
114+
is_qpid_, expected_type: type[AbstractRobustConnection], rabbit_mq: str
115+
):
116+
# Given
117+
class SomeClass(AMQPMixin):
118+
def __init__(self, broker_url: str, *, is_qpid: bool):
119+
super().__init__(broker_url=broker_url, is_qpid=is_qpid)
120+
121+
async def connect(self):
122+
await self._connect()
123+
124+
# When
125+
instance = SomeClass(rabbit_mq, is_qpid=is_qpid_)
126+
await instance.connect()
127+
128+
# Then
129+
assert (
130+
type(instance.connection) is expected_type # pylint: disable=unidiomatic-typecheck
131+
)

0 commit comments

Comments
 (0)