diff --git a/src/common/core/utils.py b/src/common/core/utils.py index a7e1167d..301edcd6 100644 --- a/src/common/core/utils.py +++ b/src/common/core/utils.py @@ -4,7 +4,7 @@ import random from functools import lru_cache from itertools import cycle -from typing import Iterator, Literal, NotRequired, TypedDict, TypeVar +from typing import Iterator, Literal, NotRequired, TypedDict, TypeVar, get_args from django.conf import settings from django.contrib.auth import get_user_model @@ -130,6 +130,14 @@ def get_file_contents(file_path: str) -> str | None: return None +@lru_cache() +def is_database_replica_setup() -> bool: + """Checks if any database replica is set up""" + return any( + name for name in connections if name.startswith(get_args(ReplicaNamePrefix)) + ) + + def using_database_replica( manager: ManagerType, replica_prefix: ReplicaNamePrefix = "replica_", diff --git a/tests/unit/common/core/test_utils.py b/tests/unit/common/core/test_utils.py index 5aefc105..777c2862 100644 --- a/tests/unit/common/core/test_utils.py +++ b/tests/unit/common/core/test_utils.py @@ -15,6 +15,7 @@ get_version_info, get_versions_from_manifest, has_email_provider, + is_database_replica_setup, is_enterprise, is_oss, is_saas, @@ -204,6 +205,37 @@ def test_get_version__invalid_file_contents__returns_unknown( assert result == "unknown" +@pytest.mark.parametrize( + ["database_names", "expected"], + [ + ({"default"}, False), + ({"default", "another_database_with_'replica'_in_its_name"}, False), + ({"default", "task_processor"}, False), + ({"default", "replica_1"}, True), + ({"default", "replica_1", "replica_2"}, True), + ({"default", "cross_region_replica_1"}, True), + ({"default", "replica_1", "cross_region_replica_1"}, True), + ], +) +def test_is_database_replica_setup__tells_whether_any_replica_is_present( + database_names: list[str], + expected: bool, + mocker: MockerFixture, +) -> None: + # Given + is_database_replica_setup.cache_clear() + mocker.patch( + "common.core.utils.connections", + {name: connections["default"] for name in database_names}, + ) + + # When + result = is_database_replica_setup() + + # Then + assert result is expected + + @pytest.mark.django_db(databases="__all__") def test_using_database_replica__no_replicas__points_to_default( django_assert_num_queries: DjangoAssertNumQueries,