diff --git a/src/backend/apps/consumer/processor.py b/src/backend/apps/consumer/processor.py index 64b9c212f..3781f8086 100644 --- a/src/backend/apps/consumer/processor.py +++ b/src/backend/apps/consumer/processor.py @@ -16,6 +16,8 @@ import asyncio from pathlib import Path from PIL import Image, ImageDraw, ImageFont +from apps.consumer.rabbitmq import RabbitMQTokenConnection +from apps.consumer.rabbitmq import RabbitMQTokenConnection from .db import get_all_from_db, load_index_from_db from timezonefinder import TimezoneFinder from contextlib import asynccontextmanager @@ -23,7 +25,7 @@ from asgiref.sync import sync_to_async from apps.webcam.models import Region, RegionHighway, Webcam from apps.consumer.models import ImageIndex -from apps.shared.status import get_recent_timestamps, calculate_camera_status +from apps.shared.status import calculate_camera_status from botocore.config import Config from django.contrib.gis.geos import Point from django.db import close_old_connections, connection @@ -63,7 +65,6 @@ QUEUE_MAX_BYTES = int(os.getenv("RABBITMQ_QUEUE_MAX_BYTES", "209715200")) EXCHANGE_NAME = os.getenv("RABBITMQ_EXCHANGE_NAME") CAMERA_CACHE_REFRESH_SECONDS = int(os.getenv("CAMERA_CACHE_REFRESH_SECONDS", "60")) - RABBITMQ_HEARTBEAT = int(os.getenv("RABBITMQ_HEARTBEAT", "60")) RABBITMQ_TIMEOUT = int(os.getenv("RABBITMQ_TIMEOUT", "30")) RABBITMQ_RECONNECT_INTERVAL = int(os.getenv("RABBITMQ_RECONNECT_INTERVAL", "5")) @@ -100,32 +101,23 @@ tz_pst = 'America/Vancouver' -async def on_reconnect(conn): +def on_reconnect(): logger.info("RabbitMQ connection re-established") global last_activity last_activity = time.time() - + async def on_close(conn, exc=None): logger.warning(f"RabbitMQ connection closed: {exc}") async def on_channel_close(ch, exc=None): logger.warning(f"RabbitMQ channel closed: {exc}") - -async def setup_rabbitmq(rb_url: str, name: str): - connection = await aio_pika.connect_robust( - rb_url, - heartbeat=RABBITMQ_HEARTBEAT, - timeout=RABBITMQ_TIMEOUT, - reconnect_interval=RABBITMQ_RECONNECT_INTERVAL, - fail_fast=False, - ) - logger.info(f"RabbitMQ connection established for {name}.") - connection.reconnect_callbacks.add(on_reconnect) - connection.close_callbacks.add(on_close) - +async def setup_rabbitmq(host: str, port: int, name: str): + rabbitmq = RabbitMQTokenConnection() + connection = await rabbitmq.connect(host=host, port=port) + logger.info("RabbitMQ connection created.") channel = await connection.channel() - logger.info(f"RabbitMQ channel created for {name}.") + logger.info("RabbitMQ channel created.") channel.close_callbacks.add(on_channel_close) exchange = await channel.declare_exchange( @@ -161,12 +153,12 @@ async def consume_queue(queue, name: str): except Exception as e: logger.error(f"Error processing message from {name}: {e}") -async def consume_from(rb_url: str, name: str): +async def consume_from(host: str, port: str, name: str): while not stop_event.is_set(): connection = None try: - connection, queue = await setup_rabbitmq(rb_url, name) + connection, queue = await setup_rabbitmq(host, int(port), name) logger.info(f"Starting message consumption from {name}...") await consume_queue(queue, name) @@ -223,22 +215,24 @@ async def run_consumer(): Launch consumers for Gold and GoldDR in parallel. Each consumer listens to its own RabbitMQ instance. """ - gold_url = os.getenv("RABBITMQ_URL_GOLD") - golddr_url = os.getenv("RABBITMQ_URL_GOLDDR") + gold_host = os.getenv("RABBITMQ_HOST_GOLD") + gold_port = os.getenv("RABBITMQ_PORT_GOLD") + golddr_host = os.getenv("RABBITMQ_HOST_GOLDDR") + golddr_port = os.getenv("RABBITMQ_PORT_GOLDDR") - if not gold_url and not golddr_url: + if not gold_host and not golddr_host: raise RuntimeError("No RabbitMQ URLs configured. At least one is required.") tasks = [] - if gold_url: + if gold_host: logger.info("Starting GOLD consumer...") - tasks.append(asyncio.create_task(consume_from(gold_url, "GOLD"))) + tasks.append(asyncio.create_task(consume_from(gold_host, gold_port, "GOLD"))) # pass - if golddr_url: + if golddr_host: logger.info("Starting GOLDDR consumer...") - tasks.append(asyncio.create_task(consume_from(golddr_url, "GOLDDR"))) + tasks.append(asyncio.create_task(consume_from(golddr_host, golddr_port, "GOLDDR"))) # pass logger.info("All configured RabbitMQ consumers started.") @@ -754,4 +748,4 @@ class ConnectionMetrics: connect_time: float last_activity: float reconnect_count: int = 0 - messages_processed: int = 0 + messages_processed: int = 0 \ No newline at end of file diff --git a/src/backend/apps/consumer/rabbitmq.py b/src/backend/apps/consumer/rabbitmq.py new file mode 100644 index 000000000..30a2880f4 --- /dev/null +++ b/src/backend/apps/consumer/rabbitmq.py @@ -0,0 +1,77 @@ +import httpx +import os +import aio_pika +import logging +from datetime import datetime, timedelta, timezone +from aio_pika.connection import make_url + +RABBITMQ_HEARTBEAT = int(os.getenv("RABBITMQ_HEARTBEAT", "60")) +RABBITMQ_TIMEOUT = int(os.getenv("RABBITMQ_TIMEOUT", "30")) +RABBITMQ_RECONNECT_INTERVAL = int(os.getenv("RABBITMQ_RECONNECT_INTERVAL", "5")) +RABBITMQ_HOST = os.getenv("RABBITMQ_HOST") +RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", "5672")) +RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST") +OAUTH2_TOKEN_URL = os.getenv("OAUTH2_TOKEN_URL") +OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID") +OAUTH2_SCOPE = os.getenv("OAUTH2_SCOPE", "") +OAUTH2_DRIVEBC_RABBITMQ_USERNAME = os.getenv("OAUTH2_DRIVEBC_RABBITMQ_USERNAME") +OAUTH2_DRIVEBC_RABBITMQ_PASSWORD = os.getenv("OAUTH2_DRIVEBC_RABBITMQ_PASSWORD") + + +logger = logging.getLogger(__name__) + +class RabbitMQTokenConnection: + def __init__(self): + self._connection = None + self._token = None + self._token_expiry = None + + async def _fetch_token(self) -> str: + now = datetime.now(timezone.utc) + if self._token and self._token_expiry and now < self._token_expiry: + return self._token + + payload = { + "grant_type": "password", + "client_id": OAUTH2_CLIENT_ID, + "username": OAUTH2_DRIVEBC_RABBITMQ_USERNAME, + "password": OAUTH2_DRIVEBC_RABBITMQ_PASSWORD, + } + + async with httpx.AsyncClient() as client: + response = await client.post(OAUTH2_TOKEN_URL, data=payload) + response.raise_for_status() + data = response.json() + + self._token = data["access_token"] + expires_in = data.get("expires_in", 300) + self._token_expiry = now + timedelta(seconds=expires_in - 60) + return self._token + + async def connect(self, host: str, port: int) -> aio_pika.RobustConnection: + token = await self._fetch_token() + + self._connection = await aio_pika.connect_robust( + host=host, + port=port, + virtualhost=RABBITMQ_VHOST, + login="", + password=token, + heartbeat=RABBITMQ_HEARTBEAT, + timeout=RABBITMQ_TIMEOUT, + reconnect_interval=RABBITMQ_RECONNECT_INTERVAL, + ) + + + # Re-fetch token on every reconnect attempt + self._connection.reconnect_callbacks.add(self._on_reconnect) + return self._connection + + async def _on_reconnect(self, connection): + """Called by aio_pika before each reconnect — refresh token.""" + logger.info("RabbitMQ reconnecting — refreshing OAuth2 token...") + try: + token = await self._fetch_token() + connection.password = token # inject fresh token + except Exception as e: + logger.error(f"Failed to refresh RabbitMQ token: {e}") diff --git a/src/backend/apps/consumer/tests/test_rabbitmq.py b/src/backend/apps/consumer/tests/test_rabbitmq.py new file mode 100644 index 000000000..9ced9fbf0 --- /dev/null +++ b/src/backend/apps/consumer/tests/test_rabbitmq.py @@ -0,0 +1,166 @@ +import httpx +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime, timedelta, timezone + + +class TestRabbitMQTokenConnection(IsolatedAsyncioTestCase): + + def setUp(self): + from apps.consumer.rabbitmq import RabbitMQTokenConnection + self.conn = RabbitMQTokenConnection() + + # _fetch_token + async def test_fetch_token_returns_new_token(self): + """Fetches and caches a fresh token when none exists.""" + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "test-token-123", "expires_in": 300} + mock_response.raise_for_status = MagicMock() + + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + token = await self.conn._fetch_token() + + self.assertEqual(token, "test-token-123") + self.assertEqual(self.conn._token, "test-token-123") + + async def test_fetch_token_uses_cached_token(self): + """Returns cached token when it has not expired yet.""" + self.conn._token = "cached-token" + self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120) + + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + token = await self.conn._fetch_token() + mock_client.assert_not_called() + + self.assertEqual(token, "cached-token") + + async def test_fetch_token_refreshes_expired_token(self): + """Fetches a new token when the cached one has expired.""" + self.conn._token = "old-token" + self.conn._token_expiry = datetime.now(timezone.utc) - timedelta(seconds=10) + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "new-token-456", "expires_in": 300} + mock_response.raise_for_status = MagicMock() + + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + token = await self.conn._fetch_token() + + self.assertEqual(token, "new-token-456") + + async def test_fetch_token_expiry_set_correctly(self): + """Token expiry is set 60 seconds before actual expiry.""" + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "token", "expires_in": 300} + mock_response.raise_for_status = MagicMock() + + before = datetime.now(timezone.utc) + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + await self.conn._fetch_token() + after = datetime.now(timezone.utc) + + expected_min = before + timedelta(seconds=240) # 300 - 60 + expected_max = after + timedelta(seconds=240) + self.assertGreaterEqual(self.conn._token_expiry, expected_min) + self.assertLessEqual(self.conn._token_expiry, expected_max) + + async def test_fetch_token_uses_default_expires_in(self): + """Falls back to 300s expiry when expires_in is missing from response.""" + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "token"} # no expires_in + mock_response.raise_for_status = MagicMock() + + before = datetime.now(timezone.utc) + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + await self.conn._fetch_token() + + expected = before + timedelta(seconds=240) # default 300 - 60 + self.assertGreaterEqual(self.conn._token_expiry, expected) + + async def test_fetch_token_raises_on_http_error(self): + """Propagates HTTP errors from the token endpoint.""" + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_post = AsyncMock(side_effect=httpx.HTTPStatusError( + "401", request=MagicMock(), response=MagicMock() + )) + mock_client.return_value.__aenter__.return_value.post = mock_post + + with self.assertRaises(httpx.HTTPStatusError): + await self.conn._fetch_token() + + # connect + async def test_connect_returns_connection(self): + """connect() fetches token and returns a RobustConnection.""" + self.conn._token = "valid-token" + self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120) + + mock_connection = MagicMock() + mock_connection.reconnect_callbacks = MagicMock() + mock_connection.reconnect_callbacks.add = MagicMock() + + with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)): + result = await self.conn.connect() + + self.assertEqual(result, mock_connection) + self.assertEqual(self.conn._connection, mock_connection) + + async def test_connect_registers_reconnect_callback(self): + """connect() registers _on_reconnect as a reconnect callback.""" + self.conn._token = "valid-token" + self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120) + + mock_connection = MagicMock() + mock_connection.reconnect_callbacks = MagicMock() + mock_connection.reconnect_callbacks.add = MagicMock() + + with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)): + await self.conn.connect() + + mock_connection.reconnect_callbacks.add.assert_called_once_with(self.conn._on_reconnect) + + async def test_connect_passes_correct_credentials(self): + """connect() uses empty login and token as password.""" + self.conn._token = "my-oauth-token" + self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120) + + mock_connection = MagicMock() + mock_connection.reconnect_callbacks.add = MagicMock() + + with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)) as mock_connect: + await self.conn.connect() + _, kwargs = mock_connect.call_args + self.assertEqual(kwargs["login"], "") + self.assertEqual(kwargs["password"], "my-oauth-token") + + # _on_reconnect + async def test_on_reconnect_refreshes_token(self): + """_on_reconnect injects a fresh token into the connection.""" + mock_connection = MagicMock() + self.conn._token = "old-token" + self.conn._token_expiry = datetime.now(timezone.utc) - timedelta(seconds=10) + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "refreshed-token", "expires_in": 300} + mock_response.raise_for_status = MagicMock() + + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response) + await self.conn._on_reconnect(mock_connection) + + self.assertEqual(mock_connection.password, "refreshed-token") + + async def test_on_reconnect_logs_error_on_failure(self): + """_on_reconnect logs an error if token refresh fails.""" + mock_connection = MagicMock() + + with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client: + mock_client.return_value.__aenter__.return_value.post = AsyncMock( + side_effect=Exception("network error") + ) + with patch("apps.consumer.rabbitmq.logger") as mock_logger: + await self.conn._on_reconnect(mock_connection) + mock_logger.error.assert_called_once() \ No newline at end of file