Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions src/backend/apps/consumer/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
import asyncio
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
from apps.consumer.rabbitmq import RabbitMQTokenConnection
from apps.consumer.rabbitmq import RabbitMQTokenConnection
from .db import get_all_from_db, load_index_from_db
from timezonefinder import TimezoneFinder
from contextlib import asynccontextmanager
from aiormq.exceptions import ChannelInvalidStateError
from asgiref.sync import sync_to_async
from apps.webcam.models import Region, RegionHighway, Webcam
from apps.consumer.models import ImageIndex
from apps.shared.status import get_recent_timestamps, calculate_camera_status
from apps.shared.status import calculate_camera_status
from botocore.config import Config
from django.contrib.gis.geos import Point
from django.db import close_old_connections, connection
Expand Down Expand Up @@ -63,7 +65,6 @@
QUEUE_MAX_BYTES = int(os.getenv("RABBITMQ_QUEUE_MAX_BYTES", "209715200"))
EXCHANGE_NAME = os.getenv("RABBITMQ_EXCHANGE_NAME")
CAMERA_CACHE_REFRESH_SECONDS = int(os.getenv("CAMERA_CACHE_REFRESH_SECONDS", "60"))

RABBITMQ_HEARTBEAT = int(os.getenv("RABBITMQ_HEARTBEAT", "60"))
RABBITMQ_TIMEOUT = int(os.getenv("RABBITMQ_TIMEOUT", "30"))
RABBITMQ_RECONNECT_INTERVAL = int(os.getenv("RABBITMQ_RECONNECT_INTERVAL", "5"))
Expand Down Expand Up @@ -100,32 +101,23 @@

tz_pst = 'America/Vancouver'

async def on_reconnect(conn):
def on_reconnect():
logger.info("RabbitMQ connection re-established")
global last_activity
last_activity = time.time()

async def on_close(conn, exc=None):
logger.warning(f"RabbitMQ connection closed: {exc}")

async def on_channel_close(ch, exc=None):
logger.warning(f"RabbitMQ channel closed: {exc}")


async def setup_rabbitmq(rb_url: str, name: str):
connection = await aio_pika.connect_robust(
rb_url,
heartbeat=RABBITMQ_HEARTBEAT,
timeout=RABBITMQ_TIMEOUT,
reconnect_interval=RABBITMQ_RECONNECT_INTERVAL,
fail_fast=False,
)
logger.info(f"RabbitMQ connection established for {name}.")
connection.reconnect_callbacks.add(on_reconnect)
connection.close_callbacks.add(on_close)

async def setup_rabbitmq(host: str, port: int, name: str):
rabbitmq = RabbitMQTokenConnection()
connection = await rabbitmq.connect(host=host, port=port)
logger.info("RabbitMQ connection created.")
channel = await connection.channel()
logger.info(f"RabbitMQ channel created for {name}.")
logger.info("RabbitMQ channel created.")
channel.close_callbacks.add(on_channel_close)

exchange = await channel.declare_exchange(
Expand Down Expand Up @@ -161,12 +153,12 @@ async def consume_queue(queue, name: str):
except Exception as e:
logger.error(f"Error processing message from {name}: {e}")

async def consume_from(rb_url: str, name: str):
async def consume_from(host: str, port: str, name: str):
while not stop_event.is_set():
connection = None

try:
connection, queue = await setup_rabbitmq(rb_url, name)
connection, queue = await setup_rabbitmq(host, int(port), name)

logger.info(f"Starting message consumption from {name}...")
await consume_queue(queue, name)
Expand Down Expand Up @@ -223,22 +215,24 @@ async def run_consumer():
Launch consumers for Gold and GoldDR in parallel.
Each consumer listens to its own RabbitMQ instance.
"""
gold_url = os.getenv("RABBITMQ_URL_GOLD")
golddr_url = os.getenv("RABBITMQ_URL_GOLDDR")
gold_host = os.getenv("RABBITMQ_HOST_GOLD")
gold_port = os.getenv("RABBITMQ_PORT_GOLD")
golddr_host = os.getenv("RABBITMQ_HOST_GOLDDR")
golddr_port = os.getenv("RABBITMQ_PORT_GOLDDR")

if not gold_url and not golddr_url:
if not gold_host and not golddr_host:
raise RuntimeError("No RabbitMQ URLs configured. At least one is required.")

tasks = []

if gold_url:
if gold_host:
logger.info("Starting GOLD consumer...")
tasks.append(asyncio.create_task(consume_from(gold_url, "GOLD")))
tasks.append(asyncio.create_task(consume_from(gold_host, gold_port, "GOLD")))
# pass

if golddr_url:
if golddr_host:
logger.info("Starting GOLDDR consumer...")
tasks.append(asyncio.create_task(consume_from(golddr_url, "GOLDDR")))
tasks.append(asyncio.create_task(consume_from(golddr_host, golddr_port, "GOLDDR")))
# pass

logger.info("All configured RabbitMQ consumers started.")
Expand Down Expand Up @@ -754,4 +748,4 @@ class ConnectionMetrics:
connect_time: float
last_activity: float
reconnect_count: int = 0
messages_processed: int = 0
messages_processed: int = 0
77 changes: 77 additions & 0 deletions src/backend/apps/consumer/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import httpx
import os
import aio_pika
import logging
from datetime import datetime, timedelta, timezone
from aio_pika.connection import make_url

RABBITMQ_HEARTBEAT = int(os.getenv("RABBITMQ_HEARTBEAT", "60"))
RABBITMQ_TIMEOUT = int(os.getenv("RABBITMQ_TIMEOUT", "30"))
RABBITMQ_RECONNECT_INTERVAL = int(os.getenv("RABBITMQ_RECONNECT_INTERVAL", "5"))
RABBITMQ_HOST = os.getenv("RABBITMQ_HOST")
RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", "5672"))
RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST")
OAUTH2_TOKEN_URL = os.getenv("OAUTH2_TOKEN_URL")
OAUTH2_CLIENT_ID = os.getenv("OAUTH2_CLIENT_ID")
OAUTH2_SCOPE = os.getenv("OAUTH2_SCOPE", "")
OAUTH2_DRIVEBC_RABBITMQ_USERNAME = os.getenv("OAUTH2_DRIVEBC_RABBITMQ_USERNAME")
OAUTH2_DRIVEBC_RABBITMQ_PASSWORD = os.getenv("OAUTH2_DRIVEBC_RABBITMQ_PASSWORD")


logger = logging.getLogger(__name__)

class RabbitMQTokenConnection:
def __init__(self):
self._connection = None
self._token = None
self._token_expiry = None

async def _fetch_token(self) -> str:
now = datetime.now(timezone.utc)
if self._token and self._token_expiry and now < self._token_expiry:
return self._token

payload = {
"grant_type": "password",
"client_id": OAUTH2_CLIENT_ID,
"username": OAUTH2_DRIVEBC_RABBITMQ_USERNAME,
"password": OAUTH2_DRIVEBC_RABBITMQ_PASSWORD,
}

async with httpx.AsyncClient() as client:
response = await client.post(OAUTH2_TOKEN_URL, data=payload)
response.raise_for_status()
data = response.json()

self._token = data["access_token"]
expires_in = data.get("expires_in", 300)
self._token_expiry = now + timedelta(seconds=expires_in - 60)
return self._token

async def connect(self, host: str, port: int) -> aio_pika.RobustConnection:
token = await self._fetch_token()

self._connection = await aio_pika.connect_robust(
host=host,
port=port,
virtualhost=RABBITMQ_VHOST,
login="",
password=token,
heartbeat=RABBITMQ_HEARTBEAT,
timeout=RABBITMQ_TIMEOUT,
reconnect_interval=RABBITMQ_RECONNECT_INTERVAL,
)


# Re-fetch token on every reconnect attempt
self._connection.reconnect_callbacks.add(self._on_reconnect)
return self._connection

async def _on_reconnect(self, connection):
"""Called by aio_pika before each reconnect β€” refresh token."""
logger.info("RabbitMQ reconnecting β€” refreshing OAuth2 token...")
try:
token = await self._fetch_token()
connection.password = token # inject fresh token
except Exception as e:
logger.error(f"Failed to refresh RabbitMQ token: {e}")
166 changes: 166 additions & 0 deletions src/backend/apps/consumer/tests/test_rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import httpx
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timedelta, timezone


class TestRabbitMQTokenConnection(IsolatedAsyncioTestCase):

def setUp(self):
from apps.consumer.rabbitmq import RabbitMQTokenConnection
self.conn = RabbitMQTokenConnection()

# _fetch_token
async def test_fetch_token_returns_new_token(self):
"""Fetches and caches a fresh token when none exists."""
mock_response = MagicMock()
mock_response.json.return_value = {"access_token": "test-token-123", "expires_in": 300}
mock_response.raise_for_status = MagicMock()

with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
token = await self.conn._fetch_token()

self.assertEqual(token, "test-token-123")
self.assertEqual(self.conn._token, "test-token-123")

async def test_fetch_token_uses_cached_token(self):
"""Returns cached token when it has not expired yet."""
self.conn._token = "cached-token"
self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120)

with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
token = await self.conn._fetch_token()
mock_client.assert_not_called()

self.assertEqual(token, "cached-token")

async def test_fetch_token_refreshes_expired_token(self):
"""Fetches a new token when the cached one has expired."""
self.conn._token = "old-token"
self.conn._token_expiry = datetime.now(timezone.utc) - timedelta(seconds=10)

mock_response = MagicMock()
mock_response.json.return_value = {"access_token": "new-token-456", "expires_in": 300}
mock_response.raise_for_status = MagicMock()

with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
token = await self.conn._fetch_token()

self.assertEqual(token, "new-token-456")

async def test_fetch_token_expiry_set_correctly(self):
"""Token expiry is set 60 seconds before actual expiry."""
mock_response = MagicMock()
mock_response.json.return_value = {"access_token": "token", "expires_in": 300}
mock_response.raise_for_status = MagicMock()

before = datetime.now(timezone.utc)
with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
await self.conn._fetch_token()
after = datetime.now(timezone.utc)

expected_min = before + timedelta(seconds=240) # 300 - 60
expected_max = after + timedelta(seconds=240)
self.assertGreaterEqual(self.conn._token_expiry, expected_min)
self.assertLessEqual(self.conn._token_expiry, expected_max)

async def test_fetch_token_uses_default_expires_in(self):
"""Falls back to 300s expiry when expires_in is missing from response."""
mock_response = MagicMock()
mock_response.json.return_value = {"access_token": "token"} # no expires_in
mock_response.raise_for_status = MagicMock()

before = datetime.now(timezone.utc)
with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
await self.conn._fetch_token()

expected = before + timedelta(seconds=240) # default 300 - 60
self.assertGreaterEqual(self.conn._token_expiry, expected)

async def test_fetch_token_raises_on_http_error(self):
"""Propagates HTTP errors from the token endpoint."""
with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_post = AsyncMock(side_effect=httpx.HTTPStatusError(
"401", request=MagicMock(), response=MagicMock()
))
mock_client.return_value.__aenter__.return_value.post = mock_post

with self.assertRaises(httpx.HTTPStatusError):
await self.conn._fetch_token()

# connect
async def test_connect_returns_connection(self):
"""connect() fetches token and returns a RobustConnection."""
self.conn._token = "valid-token"
self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120)

mock_connection = MagicMock()
mock_connection.reconnect_callbacks = MagicMock()
mock_connection.reconnect_callbacks.add = MagicMock()

with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)):
result = await self.conn.connect()

self.assertEqual(result, mock_connection)
self.assertEqual(self.conn._connection, mock_connection)

async def test_connect_registers_reconnect_callback(self):
"""connect() registers _on_reconnect as a reconnect callback."""
self.conn._token = "valid-token"
self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120)

mock_connection = MagicMock()
mock_connection.reconnect_callbacks = MagicMock()
mock_connection.reconnect_callbacks.add = MagicMock()

with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)):
await self.conn.connect()

mock_connection.reconnect_callbacks.add.assert_called_once_with(self.conn._on_reconnect)

async def test_connect_passes_correct_credentials(self):
"""connect() uses empty login and token as password."""
self.conn._token = "my-oauth-token"
self.conn._token_expiry = datetime.now(timezone.utc) + timedelta(seconds=120)

mock_connection = MagicMock()
mock_connection.reconnect_callbacks.add = MagicMock()

with patch("apps.consumer.rabbitmq.aio_pika.connect_robust", new=AsyncMock(return_value=mock_connection)) as mock_connect:
await self.conn.connect()
_, kwargs = mock_connect.call_args
self.assertEqual(kwargs["login"], "")
self.assertEqual(kwargs["password"], "my-oauth-token")

# _on_reconnect
async def test_on_reconnect_refreshes_token(self):
"""_on_reconnect injects a fresh token into the connection."""
mock_connection = MagicMock()
self.conn._token = "old-token"
self.conn._token_expiry = datetime.now(timezone.utc) - timedelta(seconds=10)

mock_response = MagicMock()
mock_response.json.return_value = {"access_token": "refreshed-token", "expires_in": 300}
mock_response.raise_for_status = MagicMock()

with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
await self.conn._on_reconnect(mock_connection)

self.assertEqual(mock_connection.password, "refreshed-token")

async def test_on_reconnect_logs_error_on_failure(self):
"""_on_reconnect logs an error if token refresh fails."""
mock_connection = MagicMock()

with patch("apps.consumer.rabbitmq.httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__.return_value.post = AsyncMock(
side_effect=Exception("network error")
)
with patch("apps.consumer.rabbitmq.logger") as mock_logger:
await self.conn._on_reconnect(mock_connection)
mock_logger.error.assert_called_once()
Loading