diff --git a/auth_content/models/__init__.py b/auth_content/models/__init__.py
deleted file mode 100644
index 4d59541cb..000000000
--- a/auth_content/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from auth_content.models import permission_sets, users
diff --git a/auth_content/__init__.py b/cms/auth_content/__init__.py
similarity index 100%
rename from auth_content/__init__.py
rename to cms/auth_content/__init__.py
diff --git a/cms/auth_content/auth_utils.py b/cms/auth_content/auth_utils.py
new file mode 100644
index 000000000..cb54ffd23
--- /dev/null
+++ b/cms/auth_content/auth_utils.py
@@ -0,0 +1,298 @@
+import logging
+from collections.abc import Callable
+
+from django import forms
+
+from cms.auth_content.constants import WILDCARD_ID_VALUE
+from cms.dynamic_content import help_texts
+from metrics.data.models.core_models.supporting import (
+ Geography,
+ GeographyType,
+ Metric,
+ Topic,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def check_permissions_by_name(
+ permission_sets,
+ theme_name,
+ sub_theme_name,
+ topic_name,
+ metric_name,
+ geography_type,
+ geography_name,
+) -> bool:
+ """
+ This is a wrapper that converts permission resource names
+ into ids. It is only used to check CHART permissions.
+ """
+
+ logger.info("Entered check_permissions_by_name()")
+
+ theme_id, sub_theme_id, topic_id = Topic.objects.get_id_by_name(
+ theme_name, sub_theme_name, topic_name
+ )
+ metric_id = Metric.objects.get_id_by_name(metric_name)
+ geography_type_id = GeographyType.objects.get_id_by_name(geography_type)
+ geography_id = Geography.objects.get_id_by_name(geography_name)
+
+ # Be safe, just in case a NAME doesn't have an ID
+ if any(
+ value == -2
+ for value in (
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id,
+ geography_type_id,
+ geography_id,
+ )
+ ):
+ return False
+
+ return check_permission_set(
+ permission_sets,
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id,
+ geography_type_id,
+ geography_id,
+ )
+
+
+def check_permission_set(
+ permission_sets,
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id,
+ geography_type,
+ geography_id,
+) -> bool:
+ """
+ This is a wrapper that only checks for global permissions, and
+ delegates further checks to our core permission checking function.
+ It is only used to check CHART permissions.
+
+ @param {dict} permission_sets which contains a permission_sets list, eg:
+ {
+ "permission_sets": [
+ {
+ "theme": {"id": "100", "name": "immunisation"},
+ "sub_theme": {"id": "200", "name": "childhood-vaccines"},
+ "topic": {"id": "-1", "name": "* (All)"},
+ "metric": {"id": "-1", "name": "* (All)"},
+ "geography_type": {"id": "300", "name": "Nation"},
+ "geography": {"id": "-1", "name": "* (All)"},
+ }
+ ],
+ "summary": {"has_global_access": False},
+ }
+ """
+
+ logger.info("Entered check_permission_set()")
+
+ if not isinstance(permission_sets, dict):
+ return False
+ if not isinstance(permission_sets.get("permission_sets"), list):
+ return False
+ if not isinstance(permission_sets.get("summary"), dict):
+ return False
+ if not isinstance(permission_sets.get("summary").get("has_global_access"), bool):
+ return False
+
+ if permission_sets.get("summary").get("has_global_access"):
+ return True
+
+ return check_permissions(
+ permission_sets.get("permission_sets"),
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id,
+ geography_type,
+ geography_id,
+ )
+
+
+def check_permissions(
+ permission_sets,
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id=None,
+ geography_type=None,
+ geography_id=None,
+) -> bool:
+ """
+ This is our core permission-checking function It is
+ used to check both PAGE & CHART permissions.
+
+ Metric- and geography-related permissions must be
+ evaluated separately (spec says).
+
+ @param {list} permission_sets, eg:
+ [
+ {
+ "theme": {"id": "100", "name": "immunisation"},
+ "sub_theme": {"id": "200", "name": "childhood-vaccines"},
+ "topic": {"id": "-1", "name": "* (All)"},
+ "metric": {"id": "-1", "name": "* (All)"},
+ "geography_type": {"id": "300", "name": "Nation"},
+ "geography": {"id": "-1", "name": "* (All)"},
+ }
+ ]
+ """
+
+ logger.info("Entered check_permissions()")
+
+ if not isinstance(permission_sets, list):
+ return False
+
+ for permission_set in permission_sets:
+ if geography_type and geography_id:
+ # CHART permissions
+ if check_metric_related_permissions(
+ permission_set, theme_id, sub_theme_id, topic_id, metric_id
+ ) and check_geography_permissions(
+ permission_set, geography_type, geography_id
+ ):
+ return True
+ else:
+ # PAGE permissions
+ if check_metric_related_permissions(
+ permission_set, theme_id, sub_theme_id, topic_id, metric_id
+ ):
+ return True
+
+ return False
+
+
+def check_metric_related_permissions(
+ permission_set,
+ theme_id,
+ sub_theme_id,
+ topic_id,
+ metric_id=None,
+) -> bool:
+ """
+ Make sure that every theme, sub_theme, topic and metric
+ match or have a wildcard at the end (only look at the
+ first 4 attributes of permission_set).
+
+ @param {dict} permission_set, eg:
+ {
+ "theme": {"id": "100", "name": "immunisation"},
+ "sub_theme": {"id": "200", "name": "childhood-vaccines"},
+ "topic": {"id": "-1", "name": "* (All)"},
+ "metric": {"id": "-1", "name": "* (All)"},
+ "geography_type": {"id": "300", "name": "Nation"},
+ "geography": {"id": "-1", "name": "* (All)"},
+ }
+ """
+
+ logger.info("Entered check_metric_related_permissions()")
+
+ if not isinstance(permission_set, dict):
+ return False
+
+ theme_id = str(theme_id)
+ sub_theme_id = str(sub_theme_id)
+ topic_id = str(topic_id)
+ metric_id = str(metric_id)
+
+ permission_theme_id = str(permission_set.get("theme", {}).get("id"))
+ permission_sub_theme_id = str(permission_set.get("sub_theme", {}).get("id"))
+ permission_topic_id = str(permission_set.get("topic", {}).get("id"))
+ permission_metric_id = str(permission_set.get("metric", {}).get("id"))
+
+ if permission_theme_id == WILDCARD_ID_VALUE:
+ return True
+
+ if permission_theme_id == theme_id and permission_sub_theme_id == WILDCARD_ID_VALUE:
+ return True
+
+ if (
+ permission_theme_id == theme_id
+ and permission_sub_theme_id == sub_theme_id
+ and permission_topic_id in {WILDCARD_ID_VALUE, topic_id}
+ ):
+ return True
+
+ if (
+ permission_theme_id == theme_id
+ and permission_sub_theme_id == sub_theme_id
+ and permission_topic_id == topic_id
+ and permission_metric_id in {WILDCARD_ID_VALUE, metric_id}
+ ):
+ return True
+
+ return False
+
+
+def check_geography_permissions(
+ permission_set,
+ geography_type=None,
+ geography_id=None,
+) -> bool:
+ """
+ Make sure that both geography_type and geography
+ match or have a wildcard at the end (only look at the
+ first 2 attributes of permission_set).
+
+ @param {dict} permission_set, eg:
+ {
+ "theme": {"id": "100", "name": "immunisation"},
+ "sub_theme": {"id": "200", "name": "childhood-vaccines"},
+ "topic": {"id": "-1", "name": "* (All)"},
+ "metric": {"id": "-1", "name": "* (All)"},
+ "geography_type": {"id": "300", "name": "Nation"},
+ "geography": {"id": "-1", "name": "* (All)"},
+ }
+ """
+
+ logger.info("Entered check_geography_permissions()")
+
+ if not isinstance(permission_set, dict):
+ return False
+
+ geography_type = str(geography_type)
+ geography_id = str(geography_id)
+
+ permission_geography_type = str(permission_set.get("geography_type", {}).get("id"))
+ permission_geography_id = str(permission_set.get("geography", {}).get("id"))
+
+ if permission_geography_type == WILDCARD_ID_VALUE:
+ return True
+
+ if permission_geography_type == geography_type and permission_geography_id in {
+ WILDCARD_ID_VALUE,
+ geography_id,
+ }:
+ return True
+
+ return False
+
+
+def _create_form_field(
+ field: dict[str, str | Callable | None], wildcard_id_value=None
+) -> forms.CharField:
+ choices = [
+ ("", field["field_choice_default"]),
+ ]
+
+ if field["field_choice_wildcard"]:
+ choices += [(wildcard_id_value, field["field_choice_wildcard"])]
+
+ if field["field_choice_callable"]:
+ choices += field["field_choice_callable"]()
+
+ return forms.CharField(
+ required=False,
+ label=field["field_label"],
+ widget=forms.Select(choices=choices),
+ help_text=help_texts.NON_PUBLIC_PAGE_REQUIRED,
+ )
diff --git a/auth_content/constants.py b/cms/auth_content/constants.py
similarity index 100%
rename from auth_content/constants.py
rename to cms/auth_content/constants.py
diff --git a/auth_content/migrations/0001_initial.py b/cms/auth_content/migrations/0001_initial.py
similarity index 100%
rename from auth_content/migrations/0001_initial.py
rename to cms/auth_content/migrations/0001_initial.py
diff --git a/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py b/cms/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py
similarity index 100%
rename from auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py
rename to cms/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py
diff --git a/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py b/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py
new file mode 100644
index 000000000..1b66d2218
--- /dev/null
+++ b/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py
@@ -0,0 +1,31 @@
+# Generated by Django 5.2.13 on 2026-05-22 08:25
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("auth_content", "0002_alter_permissionset_geography_type_and_more"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="permissionset",
+ name="display_name",
+ field=models.CharField(
+ blank=True,
+ help_text="\nThis is an (optional) user readable name for the permission set. If not set, a default autogenerated name will be used.\n",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AddConstraint(
+ model_name="permissionset",
+ constraint=models.UniqueConstraint(
+ condition=models.Q(("display_name__isnull", False)),
+ fields=("display_name",),
+ name="unique_non_null_display_name",
+ ),
+ ),
+ ]
diff --git a/auth_content/migrations/__init__.py b/cms/auth_content/migrations/__init__.py
similarity index 100%
rename from auth_content/migrations/__init__.py
rename to cms/auth_content/migrations/__init__.py
diff --git a/cms/auth_content/models/__init__.py b/cms/auth_content/models/__init__.py
new file mode 100644
index 000000000..93449c3ce
--- /dev/null
+++ b/cms/auth_content/models/__init__.py
@@ -0,0 +1,2 @@
+from cms.auth_content.models import users
+from cms.auth_content.models import permission_sets
diff --git a/auth_content/models/permission_sets.py b/cms/auth_content/models/permission_sets.py
similarity index 87%
rename from auth_content/models/permission_sets.py
rename to cms/auth_content/models/permission_sets.py
index b3f70937e..ba5e2ca8b 100644
--- a/auth_content/models/permission_sets.py
+++ b/cms/auth_content/models/permission_sets.py
@@ -1,13 +1,13 @@
-from collections.abc import Callable
from itertools import starmap
-from django import forms
from django.core.exceptions import ValidationError
from django.db import models
-from wagtail.admin.forms import WagtailAdminModelForm
+from wagtail.admin.forms import WagtailAdminPageForm
from wagtail.admin.panels import FieldPanel, mark_safe
-from auth_content.constants import PERMISSION_SET_FIELDS, WILDCARD_ID_VALUE
+from cms.auth_content.auth_utils import _create_form_field
+from cms.auth_content.constants import PERMISSION_SET_FIELDS, WILDCARD_ID_VALUE
+from cms.dynamic_content import help_texts
from cms.metrics_interface.field_choices_callables import (
get_all_geography_names_and_codes,
get_all_geography_type_names_and_ids,
@@ -18,41 +18,14 @@
)
-def get_theme_child_map():
- """Returns an object of all parent to child mappings
- e.g.
- {
- infectious_disease: [vaccine_preventable, respiratory ....],
- extreme_event: [weather_alert, mortality_report...]
- ...
- }
-
- """
- return {}
-
-
-def _create_form_field(field: dict[str, str | Callable | None]) -> forms.CharField:
- choices = [
- ("", field["field_choice_default"]),
- ]
-
- if field["field_choice_wildcard"]:
- choices += [(WILDCARD_ID_VALUE, field["field_choice_wildcard"])]
-
- if field["field_choice_callable"]:
- choices += field["field_choice_callable"]()
-
- return forms.CharField(
- required=True, label=field["field_label"], widget=forms.Select(choices=choices)
- )
-
-
-class PermissionSetForm(WagtailAdminModelForm):
+class PermissionSetForm(WagtailAdminPageForm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for field in PERMISSION_SET_FIELDS:
- self.fields[field["field_name"]] = _create_form_field(field)
+ self.fields[field["field_name"]] = _create_form_field(
+ field, WILDCARD_ID_VALUE
+ )
if self.instance and self.instance.pk:
self._initialize_dependent_fields()
@@ -110,6 +83,9 @@ def clean(self):
return cleaned_data
+ class Media:
+ js = ["js/permission_set.js"]
+
class PermissionSet(models.Model):
name = models.CharField(
@@ -118,6 +94,12 @@ class PermissionSet(models.Model):
editable=False,
help_text="Auto-generated display name",
)
+ display_name = models.CharField(
+ max_length=255,
+ blank=True,
+ null=True,
+ help_text=help_texts.PERMISSION_SET_DISPLAY_NAME,
+ )
theme = models.CharField(max_length=255, blank=False, default="")
sub_theme = models.CharField(max_length=255, blank=False, default="")
topic = models.CharField(max_length=255, blank=False, default="")
@@ -133,6 +115,7 @@ def permission_set_details(self):
return mark_safe("
".join(parts))
panels = [
+ FieldPanel("display_name"),
FieldPanel("theme"),
FieldPanel("sub_theme"),
FieldPanel("topic"),
@@ -153,7 +136,12 @@ class Meta:
"geography",
],
name="unique_permission_set",
- )
+ ),
+ models.UniqueConstraint(
+ fields=["display_name"],
+ condition=models.Q(display_name__isnull=False),
+ name="unique_non_null_display_name",
+ ),
]
def save(self, *args, **kwargs):
@@ -241,4 +229,4 @@ def _find_label_in_choices(choices: list[tuple], value: str) -> str:
)
def __str__(self):
- return self.name or f"Permission Set {self.id}"
+ return self.display_name or self.name or f"Permission Set {self.id}"
diff --git a/auth_content/models/users.py b/cms/auth_content/models/users.py
similarity index 100%
rename from auth_content/models/users.py
rename to cms/auth_content/models/users.py
diff --git a/auth_content/static/js/permission_set.js b/cms/auth_content/static/js/permission_set.js
similarity index 100%
rename from auth_content/static/js/permission_set.js
rename to cms/auth_content/static/js/permission_set.js
diff --git a/auth_content/wagtail_hooks.py b/cms/auth_content/wagtail_hooks.py
similarity index 77%
rename from auth_content/wagtail_hooks.py
rename to cms/auth_content/wagtail_hooks.py
index 9a60e2fd2..77852a1d1 100644
--- a/auth_content/wagtail_hooks.py
+++ b/cms/auth_content/wagtail_hooks.py
@@ -1,5 +1,3 @@
-from django.templatetags.static import static
-from django.utils.html import format_html
from wagtail import hooks
from wagtail.admin.viewsets.model import (
ModelPermissionPolicy,
@@ -7,8 +5,8 @@
ModelViewSetGroup,
)
-from auth_content.models.permission_sets import PermissionSet
-from auth_content.models.users import User
+from cms.auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.models.users import User
class NoEditPermissionPolicy(ModelPermissionPolicy):
@@ -48,8 +46,3 @@ class AuthGroup(ModelViewSetGroup):
@hooks.register("register_admin_viewset")
def register_auth_viewset():
return AuthGroup()
-
-
-@hooks.register("insert_editor_js")
-def permission_set_js():
- return format_html('', static("js/permission_set.js"))
diff --git a/cms/dashboard/constants.py b/cms/dashboard/constants.py
new file mode 100644
index 000000000..6eed766d7
--- /dev/null
+++ b/cms/dashboard/constants.py
@@ -0,0 +1,35 @@
+from cms.metrics_interface.field_choices_callables import (
+ get_all_metric_names_and_ids,
+ get_all_theme_names_and_ids,
+)
+
+THEME_FIELDS = [
+ {
+ "field_name": "theme",
+ "field_label": "Theme",
+ "field_choice_default": "----------",
+ "field_choice_wildcard": None,
+ "field_choice_callable": get_all_theme_names_and_ids,
+ },
+ {
+ "field_name": "sub_theme",
+ "field_label": "Sub Theme",
+ "field_choice_default": "Select theme first",
+ "field_choice_wildcard": None,
+ "field_choice_callable": None,
+ },
+ {
+ "field_name": "topic",
+ "field_label": "Topic",
+ "field_choice_default": "Select sub-theme first",
+ "field_choice_wildcard": None,
+ "field_choice_callable": None,
+ },
+ {
+ "field_name": "metric",
+ "field_label": "Metric",
+ "field_choice_default": "Select topic first",
+ "field_choice_wildcard": None,
+ "field_choice_callable": get_all_metric_names_and_ids,
+ },
+]
diff --git a/cms/dashboard/static/js/classification_toggle.js b/cms/dashboard/static/js/classification_toggle.js
deleted file mode 100644
index 909e4a3af..000000000
--- a/cms/dashboard/static/js/classification_toggle.js
+++ /dev/null
@@ -1,28 +0,0 @@
-;(function () {
- function toggleClassification() {
- /*
- When the is_public box is checked, this will clear any selected page_classification,
- and disable the field. If the is_public box is then unchecked, it will re-enable the field
- */
- const isPublicCheckbox = document.querySelector('input[name="is_public"]')
- const classificationField = document.querySelector(
- 'select[name="page_classification"]',
- )
-
- if (!isPublicCheckbox || !classificationField) return
-
- if (isPublicCheckbox.checked) {
- classificationField.value = ""
- classificationField.disabled = true
- } else {
- classificationField.disabled = false
- }
- }
-
- document.addEventListener("DOMContentLoaded", toggleClassification)
- document.addEventListener("change", function (e) {
- if (e.target.name === "is_public") {
- toggleClassification()
- }
- })
-})()
diff --git a/cms/dashboard/static/js/toggle_available_fields_on_is_public.js b/cms/dashboard/static/js/toggle_available_fields_on_is_public.js
new file mode 100644
index 000000000..8ad1f004e
--- /dev/null
+++ b/cms/dashboard/static/js/toggle_available_fields_on_is_public.js
@@ -0,0 +1,276 @@
+;(function () {
+ "use strict"
+ let theme, subTheme, topic, metric, isPublicCheckbox;
+ let originalMetricOptions;
+
+ function toggleAvailableFields() {
+ /*
+ When the is_public box is checked, this will clear any selected page_classification,
+ and disable the field. If the is_public box is then unchecked, it will re-enable the field
+ */
+
+ const fields = {
+ classification: document.querySelector(
+ 'select[name="page_classification"]',
+ ),
+ theme: theme,
+ subTheme: subTheme,
+ topic: topic,
+ // metric: metric,
+ }
+
+ if (isPublicCheckbox.checked) {
+ Object.values(fields).forEach(disableField)
+ clearDropdown(fields.subTheme, "Select theme first")
+ clearDropdown(fields.topic, "Select sub-theme first")
+ // clearDropdown(fields.metric, "Select topic first")
+ restoreMetricOptions()
+ fields.theme.value = ""
+ } else {
+ if (!theme.value && !subTheme.value && !topic.value) {
+ clearDropdown(metric, "Select topic first")
+ }
+ Object.values(fields).forEach(enableField)
+ fields.classification.value="official_sensitive"
+ }
+ }
+
+ function restoreMetricOptions() {
+ clearDropdown(metric, "----------")
+ originalMetricOptions.forEach(option => {
+ if (option.text !== "Select topic first") {
+ metric.appendChild(option.cloneNode(true));
+ }
+ });
+ }
+
+ function disableField(field) {
+ field.disabled = true
+ }
+
+ function enableField(field) {
+ field.disabled = false
+ }
+
+ /**
+ * Generic function to fetch choices from the API
+ * @param {string} endpoint - The API endpoint (e.g., 'subthemes', 'topics')
+ * @param {string} dataItemId - The ID value to pass
+ * @returns {Promise} Array of choices [[id, name], ...]
+ */
+ async function fetchChoices(endpoint, dataItemId) {
+ try {
+ const url = `/api/data-hierarchy/${endpoint}/${dataItemId}`
+ const response = await fetch(url)
+
+ if (!response.ok) {
+ const errorData = await response.json()
+ console.error(`API error: ${errorData.error || "Unknown error"}`)
+ return []
+ }
+
+ const data = await response.json()
+ return data.choices || []
+ } catch (error) {
+ console.error(`Error fetching ${endpoint}:`, error)
+ return []
+ }
+ }
+
+ /**
+ * Generic function to populate a dropdown with choices
+ * @param {HTMLSelectElement} dropdown - The select element to populate
+ * @param {Array} choices - Array of [id, name] tuples
+ */
+ function populateDropdown(dropdown, choices, metrics = null) {
+ const currentValue = dropdown.value
+ dropdown.disabled = false
+ dropdown.innerHTML = ""
+
+ //dropdown empty
+ const nullOption = document.createElement("option")
+ nullOption.value = ""
+ nullOption.textContent = "--------"
+ dropdown.appendChild(nullOption)
+
+ choices.forEach(([id, name]) => {
+ const option = document.createElement("option")
+ option.value = id
+ option.textContent = name
+ dropdown.appendChild(option)
+ })
+
+ if (currentValue) {
+ dropdown.value = currentValue
+ }
+ }
+
+ function clearDropdown(dropdown, message = "Select parent first") {
+ dropdown.innerHTML = ""
+
+ const option = document.createElement("option")
+ option.value = ""
+ option.textContent = message
+ dropdown.appendChild(option)
+
+ dropdown.value = ""
+ }
+
+ /**
+ * Handle theme selection change
+ */
+ async function handleThemeChange() {
+ const themeValue = theme.value
+
+ // Clear all dependent dropdowns
+ if (!themeValue || themeValue === "") {
+ clearDropdown(subTheme, "Select theme first")
+ clearDropdown(topic, "Select sub-theme first")
+ clearDropdown(metric, "Select topic first");
+ return
+ }
+
+ clearDropdown(subTheme, "Select theme")
+ clearDropdown(topic, "Select sub-theme")
+ clearDropdown(metric, "Select topic first");
+
+ // Fetch and populate sub-themes
+ const choices = await fetchChoices("subthemes", themeValue)
+
+ if (choices.length > 0) {
+ populateDropdown(subTheme, choices)
+ } else {
+ clearDropdown(subTheme, "No sub-themes available")
+ }
+ }
+
+ /**
+ * Handle sub-theme selection change
+ */
+ async function handleSubThemeChange() {
+ const subThemeValue = subTheme.value
+
+ if (!subThemeValue || subThemeValue === "") {
+ // No sub-theme selected - clear children
+ clearDropdown(topic, "Select sub-theme first")
+ return
+ }
+
+ // Clear dependent dropdowns
+ clearDropdown(topic, "Select sub-theme")
+ clearDropdown(metric, "Select topic first");
+
+ // Fetch and populate topics
+ const choices = await fetchChoices("topics", subThemeValue)
+
+ if (choices.length > 0) {
+ populateDropdown(topic, choices)
+ } else {
+ clearDropdown(topic, "No topics available")
+ }
+ }
+
+ /**
+ * Handle topic selection change
+ */
+ async function handleTopicChange() {
+ const topicValue = topic.value;
+
+ if (!topicValue || topicValue === "") {
+ // No topic selected - clear metrics
+ clearDropdown(metric, "Select topic first");
+ return;
+ }
+
+ clearDropdown(metric, "--------");
+
+ // Fetch and populate metrics
+ const choices = await fetchChoices("metrics", topicValue);
+
+ if (choices.length > 0) {
+ populateDropdown(metric, choices, "* All metrics");
+ } else {
+ clearDropdown(metric, "No metrics available");
+ }
+ }
+
+ /**
+ * Initialize dropdowns for edit mode
+ * Loads the dropdown options based on saved values
+ */
+ async function initializeEditMode() {
+ // Store original values before we start manipulating dropdowns
+ const savedTheme = theme.value
+ const savedSubTheme = subTheme.value
+ const savedTopic = topic.value
+ const savedMetric = metric ? metric.value : undefined
+
+ // If theme has a value (not empty), load sub-themes
+ if (savedTheme && savedTheme !== "") {
+ const subThemeChoices = await fetchChoices("subthemes", savedTheme)
+ if (subThemeChoices.length > 0) {
+ populateDropdown(subTheme, subThemeChoices)
+ subTheme.value = savedSubTheme // Restore selection
+ }
+
+ // If sub-theme has a value, load topics
+ if (savedSubTheme && savedSubTheme !== "") {
+ const topicChoices = await fetchChoices("topics", savedSubTheme)
+ if (topicChoices.length > 0) {
+ populateDropdown(topic, topicChoices)
+ topic.value = savedTopic // Restore selection
+ }
+
+ if (savedTopic && savedTopic !== "") {
+ const metricChoices = await fetchChoices("metrics", savedTopic)
+ if (metricChoices.length > 0) {
+ populateDropdown(metric, metricChoices)
+ metric.value = savedMetric // Restore selection
+ }
+ }
+ }
+ }
+ }
+
+ function initialize() {
+ // Get dropdown elements
+
+ isPublicCheckbox = document.querySelector('input[name="is_public"]')
+ theme = document.querySelector('select[name="theme"]')
+ subTheme = document.querySelector('select[name="sub_theme"]')
+ topic = document.querySelector('select[name="topic"]')
+ metric = document.querySelector('select[name="metric"]')
+ // Take a copy of all available metrics so they can be restored if this becomes a public page
+ originalMetricOptions = Array.from(metric.options).map(option => option.cloneNode(true));
+
+ // Exit if not on page with themes and is_public toggle
+ if (!theme || !subTheme || !topic || !isPublicCheckbox) {
+ console.error("No theme dropdowns found on this page")
+ return
+ }
+
+ toggleAvailableFields()
+
+ // Add event listeners
+ theme.addEventListener("change", handleThemeChange)
+ subTheme.addEventListener("change", handleSubThemeChange)
+ topic.addEventListener("change", handleTopicChange);
+
+ const isEditMode = theme.value || subTheme.value || topic.value || metric.value
+
+ if (isEditMode) {
+ initializeEditMode()
+ } else {
+ clearDropdown(subTheme, "Select theme first")
+ clearDropdown(topic, "Select sub-theme first")
+ clearDropdown(metric, "Select topic first");
+ }
+ }
+
+ document.addEventListener("DOMContentLoaded", initialize)
+ document.addEventListener("change", function (e) {
+ if (e.target.name === "is_public") {
+ toggleAvailableFields()
+ }
+ })
+})()
diff --git a/cms/dashboard/viewsets.py b/cms/dashboard/viewsets.py
index 949c50d18..89b2c1ddc 100644
--- a/cms/dashboard/viewsets.py
+++ b/cms/dashboard/viewsets.py
@@ -1,3 +1,6 @@
+from itertools import chain
+
+from django.db.models import Exists, OuterRef, Q
from django.urls import path
from django.urls.resolvers import RoutePattern
from drf_spectacular.utils import extend_schema
@@ -6,11 +9,15 @@
from wagtail.api.v2.views import PagesAPIViewSet
from caching.private_api.decorators import cache_response
+from cms.auth_content.auth_utils import check_permissions
from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer
+from cms.metrics_documentation.models.child import MetricsDocumentationChildEntry
+from cms.topic.models import TopicPage
@extend_schema(tags=["cms"])
class CMSPagesAPIViewSet(PagesAPIViewSet):
+ # This is the /pages (or proxy/pages env dependent endpoint)
permission_classes = []
base_serializer_class = ListablePageSerializer
listing_default_fields = PagesAPIViewSet.listing_default_fields + ["show_in_menus"]
@@ -39,7 +46,69 @@ def get_queryset(self):
"""
queryset = super().get_queryset()
- return queryset.specific()
+
+ req = self.request
+
+ if req.auth is None:
+ filtered_queryset = queryset.annotate(
+ is_public_topic_page=Exists(
+ TopicPage.objects.filter(
+ page_ptr_id=OuterRef("pk"),
+ is_public=True,
+ )
+ ),
+ is_public_metrics_doc_child_page=Exists(
+ MetricsDocumentationChildEntry.objects.filter(
+ page_ptr_id=OuterRef("pk"),
+ is_public=True,
+ )
+ ),
+ ).filter(
+ Q(is_public_topic_page=True)
+ | Q(is_public_metrics_doc_child_page=True)
+ | ~Q(
+ content_type__model__in=[
+ "topicpage",
+ "metricsdocumentationchildentry",
+ ]
+ )
+ )
+
+ else:
+ has_global_access = req.user.permission_sets["summary"]["has_global_access"]
+
+ if has_global_access:
+ filtered_queryset = queryset
+
+ else:
+ user_permissions = req.user.permission_sets
+ pages_to_check = chain(
+ ((page.id, page.topicpage) for page in queryset.type(TopicPage)),
+ (
+ (page.id, page.metricsdocumentationchildentry)
+ for page in queryset.type(MetricsDocumentationChildEntry)
+ ),
+ )
+ allowed_page_ids = [
+ page_id
+ for page_id, page in pages_to_check
+ if page.is_public
+ or check_permissions(
+ user_permissions,
+ page.theme,
+ page.sub_theme,
+ page.topic,
+ )
+ ]
+
+ public_pages = queryset.not_type(
+ TopicPage, MetricsDocumentationChildEntry
+ )
+ permitted_private_pages = queryset.filter(id__in=allowed_page_ids)
+
+ filtered_queryset = public_pages | permitted_private_pages
+
+ return filtered_queryset.specific()
@cache_response()
def listing_view(self, request: Request) -> Response:
diff --git a/cms/dynamic_content/help_texts.py b/cms/dynamic_content/help_texts.py
index 389f8ac2c..6d959eefb 100644
--- a/cms/dynamic_content/help_texts.py
+++ b/cms/dynamic_content/help_texts.py
@@ -626,6 +626,10 @@
The classification level of all data on this page (only applies to non-public pages). Defaults to `Official-Sensitive`.
"""
+NON_PUBLIC_PAGE_REQUIRED: str = """
+This field is required for a non-public page.
+"""
+
SECTION_FOOTER_BLOCKS: str = """
This is an optional footer for content sections to allow additional supporting information to be linked too. (E.g. a link to furhter information about how we define an outbreak)
"""
@@ -641,3 +645,6 @@
SECTION_FOOTER_LINK: str = """
This is a link component that allows the user to setup an internal or external link along with a short description of the link's content.
"""
+PERMISSION_SET_DISPLAY_NAME: str = """
+This is an (optional) user readable name for the permission set. If not set, a default autogenerated name will be used.
+"""
diff --git a/cms/metrics_documentation/data_migration/child_entries.py b/cms/metrics_documentation/data_migration/child_entries.py
index 8aa39c00f..3d08bc33a 100644
--- a/cms/metrics_documentation/data_migration/child_entries.py
+++ b/cms/metrics_documentation/data_migration/child_entries.py
@@ -37,10 +37,14 @@ def build_entry_from_row_data(*, row: tuple[str, ...]) -> dict[str, str | list[d
"""
title: str = row[0]
page_description: str = row[4]
- metric: str = row[1]
+ metric = row[7]
+ topic = row[1].split("_")[0]
sections: list[tuple[str, str]] = gather_sections_and_omit_if_needed(row=row)
return {
"title": title,
+ "topic": topic,
+ "theme": "test",
+ "sub_theme": "test",
"seo_title": f"{title} | UKHSA data dashboard",
"search_description": page_description,
"page_description": page_description,
diff --git a/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx b/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx
index 3efd453c9..6d71013d1 100644
Binary files a/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx and b/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx differ
diff --git a/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py b/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py
new file mode 100644
index 000000000..7c663a527
--- /dev/null
+++ b/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py
@@ -0,0 +1,26 @@
+# Generated by Django 5.2.13 on 2026-04-29 10:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ (
+ "metrics_documentation",
+ "0015_alter_metricsdocumentationchildentry_page_classification",
+ ),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="metricsdocumentationchildentry",
+ name="sub_theme",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ migrations.AddField(
+ model_name="metricsdocumentationchildentry",
+ name="theme",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ ]
diff --git a/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py b/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py
new file mode 100644
index 000000000..5d23bdbd8
--- /dev/null
+++ b/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py
@@ -0,0 +1,21 @@
+# Generated by Django 5.2.13 on 2026-05-05 09:30
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ (
+ "metrics_documentation",
+ "0016_metricsdocumentationchildentry_sub_theme_and_more",
+ ),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="metricsdocumentationchildentry",
+ name="topic",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ ]
diff --git a/cms/metrics_documentation/models/child.py b/cms/metrics_documentation/models/child.py
index 9dada0e7d..618eff3fd 100644
--- a/cms/metrics_documentation/models/child.py
+++ b/cms/metrics_documentation/models/child.py
@@ -12,13 +12,14 @@
from wagtail.api import APIField
from wagtail.search import index
+from cms.auth_content.auth_utils import _create_form_field
+from cms.dashboard.constants import THEME_FIELDS
from cms.dashboard.models import DataClassificationLevels, UKHSAPage
from cms.dynamic_content import help_texts
from cms.dynamic_content.access import ALLOWABLE_BODY_CONTENT_TEXT_SECTION
from cms.dynamic_content.announcements import Announcement
from cms.metrics_interface.field_choices_callables import (
- get_a_list_of_all_topic_names,
- get_all_unique_metric_names,
+ get_all_metric_names_and_ids,
)
logger = logging.getLogger(__name__)
@@ -31,8 +32,35 @@ def __init__(self, topic: str, metric: str):
class MetricsDocumentationChildEntryAdminForm(WagtailAdminPageForm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for field in THEME_FIELDS:
+ self.fields[field["field_name"]] = _create_form_field(field)
+
+ if self.instance and self.instance.pk:
+ self._initialize_dependent_fields()
+
+ def _initialize_dependent_fields(self):
+ """Initialize choices for cascading dependent fields"""
+ dependent_fields = {
+ "sub_theme": ("Select theme first"),
+ "topic": ("Select sub-theme first"),
+ }
+
+ for field_name, (placeholder) in dependent_fields.items():
+ value = getattr(self.instance, field_name, None)
+ if value:
+ choices = self._get_field_choices(value, placeholder)
+ self.fields[field_name].widget.choices = choices
+
+ @staticmethod
+ def _get_field_choices(value, placeholder):
+ """Generate choices list based on field value"""
+ return [("", placeholder), (value, f"Loading... (ID: {value})")]
+
class Media:
- js = ["js/classification_toggle.js"]
+ js = ["js/toggle_available_fields_on_is_public.js"]
class MetricsDocumentationChildEntry(UKHSAPage):
@@ -51,24 +79,36 @@ class MetricsDocumentationChildEntry(UKHSAPage):
null=True,
blank=True,
)
- topic = models.CharField(
+
+ theme = models.CharField(
max_length=255,
+ blank=True,
default="",
+ null=True,
)
+ sub_theme = models.CharField(
+ max_length=255,
+ blank=True,
+ default="",
+ null=True,
+ )
+ topic = models.CharField(max_length=255, blank=True, default="", null=True)
body = ALLOWABLE_BODY_CONTENT_TEXT_SECTION
# Fields to index for searching within the CMS application.
search_fields = UKHSAPage.search_fields + [
- index.SearchField("metric"),
index.SearchField("body"),
]
# Content panels to render for editing within the CMS application.
content_panels = UKHSAPage.content_panels + [
FieldPanel("page_description"),
- FieldPanel("metric"),
FieldPanel("is_public"),
FieldPanel("page_classification"),
+ FieldPanel("theme"),
+ FieldPanel("sub_theme"),
+ FieldPanel("topic"),
+ FieldPanel("metric"),
FieldPanel("body"),
]
@@ -113,42 +153,7 @@ def __init__(self, *args, **kwargs):
load in the names dynamically from the metrics interface.
"""
super().__init__(*args, **kwargs)
- self._meta.get_field("metric").choices = get_all_unique_metric_names()
-
- def find_topic(self, *, topics: list[str]) -> str:
- """Finds the required topic from a list of strings based on the metric name.
-
- Args:
- topics: list of strings representing topic names.
-
- Returns:
- A string of the topic checked against the models metric value.
-
- Raises:
- `InvalidTopicForChosenMetricForChildEntry`: If the
- selected metric cannot be matched to a `Topic`
-
- """
- extracted_topic = self.metric.split("_")[0].lower()
- try:
- return next(topic for topic in topics if extracted_topic == topic.lower())
- except StopIteration as error:
- logger.info(
- "StopIteration Error: extracted topic not present in the topics list. %s",
- extracted_topic,
- )
- raise InvalidTopicForChosenMetricForChildEntryError(
- topic=extracted_topic, metric=self.metric
- ) from error
-
- def get_topic(self) -> str:
- """Finds the required topic name based on the selected metric name.
-
- Returns:
- a topic name as a string
- """
- topics = get_a_list_of_all_topic_names()
- return self.find_topic(topics=topics)
+ self._meta.get_field("metric").choices = get_all_metric_names_and_ids()
def save(self, *args, **kwargs):
"""Retrieves a topic based on the selected metric
@@ -156,12 +161,22 @@ def save(self, *args, **kwargs):
Notes:
This method will not be called when using `bulk_create()`
"""
- self.topic = self.get_topic()
super().save(*args, **kwargs)
@property
def metric_group(self) -> str:
- return self.metric.split("_")[1]
+ field = self._meta.get_field("metric")
+ choices = getattr(field, "choices", []) or []
+
+ display_name = next(
+ (item[1] for item in choices if item[0] == self.metric), None
+ )
+
+ if not display_name or "_" not in display_name:
+ return ""
+
+ parts = display_name.split("_")
+ return parts[1] if len(parts) > 1 else ""
def clean(self):
super().clean()
@@ -169,13 +184,29 @@ def clean(self):
# If is_public is true, automatically clear classification
if self.is_public:
self.page_classification = None
- # If not public page, classification must be chosen
+ self.theme = None
+ self.sub_theme = None
+ self.topic = None
+
+ # If not public page, non-public fields must be set
elif not self.page_classification:
raise ValidationError(
{
"page_classification": "Please select a classification level for this non-public page"
}
)
+ elif not self.theme:
+ raise ValidationError(
+ {"theme": "Please select a theme for this non-public page"}
+ )
+ elif not self.sub_theme:
+ raise ValidationError(
+ {"sub_theme": "Please select a subtheme for this non-public page"}
+ )
+ elif not self.topic:
+ raise ValidationError(
+ {"topic": "Please select a topic for this non-public page"}
+ )
class MetricsDocumentationChildPageAnnouncement(Announcement):
diff --git a/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py b/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py
new file mode 100644
index 000000000..ff9b0ef09
--- /dev/null
+++ b/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py
@@ -0,0 +1,46 @@
+# Generated by Django 5.2.13 on 2026-04-29 10:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("topic", "0031_alter_topicpage_page_classification"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="topicpage",
+ name="sub_theme",
+ field=models.CharField(
+ blank=True,
+ default="",
+ help_text="\nThe subtheme must be provided for a non-public page.\n",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AddField(
+ model_name="topicpage",
+ name="theme",
+ field=models.CharField(
+ blank=True,
+ default="",
+ help_text="\nThe theme must be provided for a non-public page.\n",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AddField(
+ model_name="topicpage",
+ name="topic",
+ field=models.CharField(
+ blank=True,
+ default="",
+ help_text="\nThe topic must be provided for a non-public page.\n",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ ]
diff --git a/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py b/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py
new file mode 100644
index 000000000..bf95c455e
--- /dev/null
+++ b/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py
@@ -0,0 +1,28 @@
+# Generated by Django 5.2.13 on 2026-04-29 10:54
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("topic", "0032_topicpage_sub_theme_topicpage_theme_topicpage_topic"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="topicpage",
+ name="sub_theme",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ migrations.AlterField(
+ model_name="topicpage",
+ name="theme",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ migrations.AlterField(
+ model_name="topicpage",
+ name="topic",
+ field=models.CharField(blank=True, default="", max_length=255, null=True),
+ ),
+ ]
diff --git a/cms/topic/models.py b/cms/topic/models.py
index 9e687cc38..f39ecd3d9 100644
--- a/cms/topic/models.py
+++ b/cms/topic/models.py
@@ -14,6 +14,8 @@
from wagtail.fields import RichTextField
from wagtail.search import index
+from cms.auth_content.auth_utils import _create_form_field
+from cms.dashboard.constants import THEME_FIELDS
from cms.dashboard.enums import (
DEFAULT_RELATED_LINKS_LAYOUT_FIELD_LENGTH,
RelatedLinksLayoutEnum,
@@ -36,8 +38,35 @@
class TopicPageAdminForm(WagtailAdminPageForm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for field in THEME_FIELDS:
+ self.fields[field["field_name"]] = _create_form_field(field)
+
+ if self.instance and self.instance.pk:
+ self._initialize_dependent_fields()
+
+ def _initialize_dependent_fields(self):
+ """Initialize choices for cascading dependent fields"""
+ dependent_fields = {
+ "sub_theme": ("Select theme first"),
+ "topic": ("Select sub-theme first"),
+ }
+
+ for field_name, (placeholder) in dependent_fields.items():
+ value = getattr(self.instance, field_name, None)
+ if value:
+ choices = self._get_field_choices(value, placeholder)
+ self.fields[field_name].widget.choices = choices
+
+ @staticmethod
+ def _get_field_choices(value, placeholder):
+ """Generate choices list based on field value"""
+ return [("", placeholder), (value, f"Loading... (ID: {value})")]
+
class Media:
- js = ["js/classification_toggle.js"]
+ js = ["js/toggle_available_fields_on_is_public.js"]
class TopicPage(UKHSAPage):
@@ -66,6 +95,10 @@ class TopicPage(UKHSAPage):
null=True,
)
+ theme = models.CharField(max_length=255, blank=True, default="", null=True)
+ sub_theme = models.CharField(max_length=255, blank=True, default="", null=True)
+ topic = models.CharField(max_length=255, blank=True, default="", null=True)
+
related_links_layout = models.CharField(
verbose_name="Layout",
help_text=help_texts.RELATED_LINKS_LAYOUT_FIELD,
@@ -87,6 +120,9 @@ class TopicPage(UKHSAPage):
FieldPanel("enable_area_selector"),
FieldPanel("is_public"),
FieldPanel("page_classification"),
+ FieldPanel("theme"),
+ FieldPanel("sub_theme"),
+ FieldPanel("topic"),
FieldPanel("page_description"),
FieldPanel("body"),
]
@@ -230,16 +266,32 @@ def last_updated_at(self) -> datetime.datetime:
def clean(self):
super().clean()
- # If is_public is true, automatically clear classification
+ # If is_public is true, automatically clear non-public fields
if self.is_public:
self.page_classification = None
- # If not public page, classification must be chosen
+ self.theme = None
+ self.sub_theme = None
+ self.topic = None
+
+ # If not public page, non-public fields must be set
elif not self.page_classification:
raise ValidationError(
{
"page_classification": "Please select a classification level for this non-public page"
}
)
+ elif not self.theme:
+ raise ValidationError(
+ {"theme": "Please select a theme for this non-public page"}
+ )
+ elif not self.sub_theme:
+ raise ValidationError(
+ {"sub_theme": "Please select a sub theme for this non-public page"}
+ )
+ elif not self.topic:
+ raise ValidationError(
+ {"topic": "Please select a topic for this non-public page"}
+ )
class TopicPageRelatedLink(UKHSAPageRelatedLink):
diff --git a/common/auth/cognito_jwt/user_manager.py b/common/auth/cognito_jwt/user_manager.py
index 3a69470a4..a101168e0 100644
--- a/common/auth/cognito_jwt/user_manager.py
+++ b/common/auth/cognito_jwt/user_manager.py
@@ -17,6 +17,23 @@ def get_or_create_for_cognito(jwt_payload):
try:
username = jwt_payload["entraObjectId"]
permission_sets = jwt_payload["permissionSets"]
+
+ # DEBUGGING: Manual testing (just for now)
+ # username = "{YOUR_ENTRA_OBJECT_ID}"
+ # permission_sets = {
+ # "permission_sets": [
+ # {
+ # "theme": {"id": "100", "name": "immunisation"},
+ # "sub_theme": {"id": "200", "name": "childhood-vaccines"},
+ # "topic": {"id": "-1", "name": "* (All)"},
+ # "metric": {"id": "-1", "name": "* (All)"},
+ # "geography_type": {"id": "300", "name": "Nation"},
+ # "geography": {"id": "400", "name": "England"},
+ # }
+ # ],
+ # "summary": {"has_global_access": False},
+ # }
+
if not permission_sets:
logger.debug(
"Empty permissionSets in token for user: '%s'",
diff --git a/tests/unit/auth_content/__init__.py b/metrics/api/middleware/__init__.py
similarity index 100%
rename from tests/unit/auth_content/__init__.py
rename to metrics/api/middleware/__init__.py
diff --git a/metrics/api/middleware/sql_debug.py b/metrics/api/middleware/sql_debug.py
new file mode 100644
index 000000000..dc00d0762
--- /dev/null
+++ b/metrics/api/middleware/sql_debug.py
@@ -0,0 +1,24 @@
+from django.db import connection
+
+
+def _print_sql(execute, sql, params, many, context):
+ print(f"\n[SQL] {sql}")
+ if params:
+ print(f"[PARAMS] {params}")
+ return execute(sql, params, many, context)
+
+
+class SQLDebugMiddleware:
+ """
+ Middleware that prints the raw SQL and params for every DB query made
+ during a request/response cycle.
+
+ Only intended for local development — add to MIDDLEWARE in local.py.
+ """
+
+ def __init__(self, get_response):
+ self.get_response = get_response
+
+ def __call__(self, request):
+ with connection.execute_wrapper(_print_sql):
+ return self.get_response(request)
diff --git a/metrics/api/serializers/geographies.py b/metrics/api/serializers/geographies.py
index f19bdf6a3..d2c07d8d3 100644
--- a/metrics/api/serializers/geographies.py
+++ b/metrics/api/serializers/geographies.py
@@ -3,12 +3,12 @@
from django.db.models import QuerySet
from rest_framework import serializers
-from auth_content.constants import WILDCARD_ID_VALUE
from metrics.api.serializers import help_texts
from metrics.data.in_memory_models.geography_relationships.handlers import (
get_upstream_relationships_for_geography,
)
from metrics.data.managers.core_models.time_series import CoreTimeSeriesQuerySet
+from metrics.data.models.constants import PERMISSION_SET_WILDCARD_ID_VALUE
from metrics.data.models.core_models import (
CoreTimeSeries,
Geography,
@@ -234,7 +234,7 @@ def geography_manager(self):
@staticmethod
def validate_geography_type_id(value: str) -> str | int:
"""Validate geography_type_id is either wildcard or a valid integer"""
- if value == WILDCARD_ID_VALUE:
+ if value == PERMISSION_SET_WILDCARD_ID_VALUE:
return value
try:
@@ -253,8 +253,10 @@ def data(self) -> dict[str, list[list[str, str]]]:
geography_type_id = self.validated_data["geography_type_id"]
# Handle wildcard
- if geography_type_id == WILDCARD_ID_VALUE:
- return {"choices": [[WILDCARD_ID_VALUE, "* (All geographies)"]]}
+ if geography_type_id == PERMISSION_SET_WILDCARD_ID_VALUE:
+ return {
+ "choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All geographies)"]]
+ }
parent_geography_type_id = int(geography_type_id)
geographies = (
diff --git a/metrics/api/serializers/permission_sets.py b/metrics/api/serializers/permission_sets.py
index 03c7e4413..f75cb6923 100644
--- a/metrics/api/serializers/permission_sets.py
+++ b/metrics/api/serializers/permission_sets.py
@@ -1,13 +1,13 @@
from django.db.models import QuerySet
from rest_framework import serializers
-from auth_content.constants import WILDCARD_ID_VALUE
+from metrics.data.models.constants import PERMISSION_SET_WILDCARD_ID_VALUE
from metrics.data.models.core_models.supporting import Metric, SubTheme, Topic
def _validate_input_id(value, field_name):
"""Validate theme_id is either wildcard or a valid integer"""
- if value == WILDCARD_ID_VALUE:
+ if value == PERMISSION_SET_WILDCARD_ID_VALUE:
return value
try:
@@ -46,8 +46,10 @@ def data(self) -> dict:
"""
theme_id = self.validated_data["theme_id"]
- if theme_id == WILDCARD_ID_VALUE:
- return {"choices": [[WILDCARD_ID_VALUE, "* (All sub-themes)"]]}
+ if theme_id == PERMISSION_SET_WILDCARD_ID_VALUE:
+ return {
+ "choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All sub-themes)"]]
+ }
parent_theme_id = int(theme_id)
sub_theme_tuples = _queryset_to_id_name_tuples(
@@ -87,8 +89,8 @@ def data(self) -> dict:
"""
sub_theme_id = self.validated_data["sub_theme_id"]
- if sub_theme_id == WILDCARD_ID_VALUE:
- return {"choices": [[WILDCARD_ID_VALUE, "* (All topics)"]]}
+ if sub_theme_id == PERMISSION_SET_WILDCARD_ID_VALUE:
+ return {"choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All topics)"]]}
parent_sub_theme_id = int(sub_theme_id)
topic_tuples = _queryset_to_id_name_tuples(
@@ -129,8 +131,8 @@ def data(self) -> dict:
"""
topic_id = self.validated_data["topic_id"]
- if topic_id == WILDCARD_ID_VALUE:
- return {"choices": [[WILDCARD_ID_VALUE, "* (All metrics)"]]}
+ if topic_id == PERMISSION_SET_WILDCARD_ID_VALUE:
+ return {"choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All metrics)"]]}
parent_topic_id = int(topic_id)
metric_tuples = _queryset_to_id_name_tuples(
diff --git a/metrics/api/serializers/user.py b/metrics/api/serializers/user.py
index 66263373a..24f0dbcff 100644
--- a/metrics/api/serializers/user.py
+++ b/metrics/api/serializers/user.py
@@ -3,7 +3,7 @@
from django.db.models import QuerySet
from rest_framework import serializers
-from auth_content.models.users import User
+from cms.auth_content.models.users import User
from metrics.utils.permission_grouping import group_by_geography_type, group_by_theme
from metrics.utils.permission_hierarchy import (
build_permission_hierarchy,
diff --git a/metrics/api/settings/default.py b/metrics/api/settings/default.py
index 78af6c97d..32dbdd2b3 100644
--- a/metrics/api/settings/default.py
+++ b/metrics/api/settings/default.py
@@ -52,6 +52,7 @@
"metrics.api",
"cms.acknowledgement",
"cms.home",
+ "cms.auth_content",
"cms.topic",
"cms.topics_list",
"cms.dashboard",
@@ -78,7 +79,6 @@
"wagtail_trash",
"modelcluster",
"taggit",
- "auth_content",
]
MIDDLEWARE = [
diff --git a/metrics/api/settings/local.py b/metrics/api/settings/local.py
index 8d0cf9636..2a3f10dfa 100644
--- a/metrics/api/settings/local.py
+++ b/metrics/api/settings/local.py
@@ -29,6 +29,7 @@
MIDDLEWARE += [
"debug_toolbar.middleware.DebugToolbarMiddleware",
+ "metrics.api.middleware.sql_debug.SQLDebugMiddleware",
]
INTERNAL_IPS = ["127.0.0.1"]
diff --git a/metrics/api/views/permission_sets.py b/metrics/api/views/permission_sets.py
index 7a76eaf8f..91f94712e 100644
--- a/metrics/api/views/permission_sets.py
+++ b/metrics/api/views/permission_sets.py
@@ -42,7 +42,7 @@ class TopicsBySubThemeView(APIView):
permission_classes = []
def get(self, request, sub_theme_id, *args, **kwargs): # noqa: PLR6301
- """API endpoint to fetch sub-themes based on selected theme."""
+ """API endpoint to fetch topics based on selected sub-theme."""
serializer = TopicRequestSerializer(data={"sub_theme_id": sub_theme_id})
serializer.is_valid(raise_exception=True)
return Response(serializer.data())
@@ -59,7 +59,7 @@ class MetricsByTopicView(APIView):
permission_classes = []
def get(self, request, topic_id, *args, **kwargs): # noqa: PLR6301
- """API endpoint to fetch sub-themes based on selected theme."""
+ """API endpoint to fetch metrics based on selected topic."""
serializer = MetricRequestSerializer(data={"topic_id": topic_id})
serializer.is_valid(raise_exception=True)
return Response(serializer.data())
diff --git a/metrics/data/managers/core_models/geography.py b/metrics/data/managers/core_models/geography.py
index 5ff2ce8cb..046721983 100644
--- a/metrics/data/managers/core_models/geography.py
+++ b/metrics/data/managers/core_models/geography.py
@@ -47,6 +47,19 @@ def get_name_by_code(self, geography_code: str) -> str | None:
.first()
)
+ def get_id_by_name(self, geography_name: str) -> int:
+ """
+ Gets the geography ID for a given geography name.
+
+ Args:
+ geography_name: The name of the geography to look up
+
+ Returns:
+ The geography ID if found, or -2 otherwise
+ """
+ record = self.filter(name=geography_name).first()
+ return int(record.id) if record else -2
+
def get_all_geography_codes_by_geography_type(
self, geography_type_name: str
) -> Self:
@@ -167,6 +180,23 @@ def get_name_by_code(self, geography_code: int) -> str | None:
"""
return self.get_queryset().get_name_by_code(geography_code)
+ def get_id_by_name(self, geography_name: str) -> int:
+ """Gets the geography ID which matches the given geography name.
+
+ Args:
+ geography_name: The name of the geography to look up
+
+ Returns:
+ The geography ID if found, -2 otherwise
+
+ Examples:
+ >>> GeographyManager.get_id_by_name("England")
+ 6
+ >>> GeographyManager.get_id_by_name("Unknown geography")
+ -2
+ """
+ return self.get_queryset().get_id_by_name(geography_name)
+
def get_all_names(self) -> GeographyQuerySet:
"""Gets all available deduplicated geography names as a flat list queryset.
diff --git a/metrics/data/managers/core_models/geography_type.py b/metrics/data/managers/core_models/geography_type.py
index ae45b36b7..df85ccd5b 100644
--- a/metrics/data/managers/core_models/geography_type.py
+++ b/metrics/data/managers/core_models/geography_type.py
@@ -41,6 +41,19 @@ def get_name_by_id(self, geography_type_id: int) -> str | None:
"""
return self.filter(id=geography_type_id).values_list("name", flat=True).first()
+ def get_id_by_name(self, geography_type_name: str) -> int:
+ """
+ Gets the geography type ID for a given geography type name.
+
+ Args:
+ geography_type_name: The name of the geography type to look up
+
+ Returns:
+ The geography type ID if found, or -2 otherwise
+ """
+ record = self.filter(name=geography_type_name).first()
+ return int(record.id) if record else -2
+
def get_all_names_and_ids(self) -> models.QuerySet:
"""Gets all available geography_type names as a flat list queryset.
@@ -77,6 +90,23 @@ def get_name_by_id(self, geography_type_id: int) -> str | None:
"""
return self.get_queryset().get_name_by_id(geography_type_id)
+ def get_id_by_name(self, geography_type_name: str) -> int:
+ """Gets the geography type ID which matches the given geography type name.
+
+ Args:
+ geography_type_name: The name of the geography type to look up
+
+ Returns:
+ The geography type ID if found, -2 otherwise
+
+ Examples:
+ >>> GeographyTypeManager.get_id_by_name("Nation")
+ 5
+ >>> GeographyTypeManager.get_id_by_name("Unknown type")
+ -2
+ """
+ return self.get_queryset().get_id_by_name(geography_type_name)
+
def get_all_names(self) -> GeographyTypeQuerySet:
"""Gets all available geography_type names as a flat list queryset.
diff --git a/metrics/data/managers/core_models/metric.py b/metrics/data/managers/core_models/metric.py
index e9fcb961d..66a3ec60c 100644
--- a/metrics/data/managers/core_models/metric.py
+++ b/metrics/data/managers/core_models/metric.py
@@ -29,6 +29,19 @@ def get_name_by_id(self, metric_id: int) -> str | None:
"""
return self.filter(id=metric_id).values_list("name", flat=True).first()
+ def get_id_by_name(self, metric_name: str) -> int:
+ """
+ Gets the metric ID for a given metric name.
+
+ Args:
+ metric_name: The name of the metric to look up
+
+ Returns:
+ The metric ID if found, or -2 otherwise
+ """
+ record = self.filter(name=metric_name).first()
+ return int(record.id) if record else -2
+
def get_all_names(self) -> models.QuerySet:
"""Gets all available metric names as a flat list queryset.
@@ -146,6 +159,23 @@ def get_name_by_id(self, metric_id: int) -> str | None:
"""
return self.get_queryset().get_name_by_id(metric_id)
+ def get_id_by_name(self, metric_name: str) -> int:
+ """Gets the metric ID which matches the given metric name.
+
+ Args:
+ metric_name: The name of the metric to look up
+
+ Returns:
+ The metric ID if found, -2 otherwise
+
+ Examples:
+ >>> MetricManager.get_id_by_name("COVID-19_cases_countRollingMean")
+ 4
+ >>> MetricManager.get_id_by_name("Unknown metric")
+ -2
+ """
+ return self.get_queryset().get_id_by_name(metric_name)
+
def get_all_names(self) -> MetricQuerySet:
"""Gets all available metric names as a flat list queryset.
diff --git a/metrics/data/managers/core_models/time_series.py b/metrics/data/managers/core_models/time_series.py
index aed8df7f6..5b63aaca2 100644
--- a/metrics/data/managers/core_models/time_series.py
+++ b/metrics/data/managers/core_models/time_series.py
@@ -6,6 +6,7 @@
"""
import datetime
+import logging
from collections.abc import Iterable
from typing import Self
@@ -15,12 +16,13 @@
from metrics.api.permissions.fluent_permissions import (
is_public_data_only_enforced,
- validate_permissions_for_non_public,
)
from metrics.data.models import RBACPermission
ALLOWABLE_METRIC_VALUE_RANGE_TYPE = tuple[str | float | int, str | float | int]
+logger = logging.getLogger(__name__)
+
class CoreTimeSeriesQuerySet(models.QuerySet):
"""Custom queryset which can be used by the `CoreTimeSeriesManager`"""
@@ -171,8 +173,10 @@ def query_for_data(
stratum: str | None = None,
sex: str | None = None,
age: str | None = None,
+ theme: str,
+ sub_theme: str,
metric_value_ranges: list[tuple[str | float | int]] | None = None,
- restrict_to_public: bool = True,
+ permission_sets: dict,
) -> models.QuerySet:
"""Filters for a N-item list of dicts by the given params if `fields_to_export` is used.
@@ -212,14 +216,18 @@ def query_for_data(
Note that options are `M`, `F`, or `ALL`.
age: The age range to apply additional filtering to.
E.g. `0_4` would be used to capture the age of 0-4 years old
+ theme: The name of the theme being queried.
+ This is only used to determine permissions for
+ the non-public portion of the requested dataset.
+ sub_theme: The name of the sub theme being queried.
+ This is only used to determine permissions for
+ the non-public portion of the requested dataset.
metric_value_ranges: List of tuples whereby each
tuple represents a permissible metric value range.
i.e. to filter for all record with values
between 0 -> 80 AND 90 -> 100,
this can be provided as `[(0, 80), (90, 100)]`.
- restrict_to_public: Boolean switch to restrict the query
- to only return public records.
- If False, then non-public records will be included.
+ permission_sets: The JWT permissions extracted from the Cognito token.
Returns:
QuerySet: An ordered queryset from lowest -> highest
@@ -231,6 +239,9 @@ def query_for_data(
]>`
"""
+
+ logger.info("Entered query_for_data()")
+
queryset = self.filter(
metric__topic__name=topic,
metric__name=metric,
@@ -245,9 +256,30 @@ def query_for_data(
sex=sex,
age=age,
)
+ public_queryset = queryset.filter(is_public=True)
- if restrict_to_public:
- queryset = queryset.filter(is_public=True)
+ if permission_sets:
+ logger.info("Entered if permission_sets clause")
+
+ # WORKAROUND: Cos circular import error when at the top of the file
+ from cms.auth_content.auth_utils import check_permissions_by_name
+
+ if check_permissions_by_name(
+ permission_sets,
+ theme,
+ sub_theme,
+ topic,
+ metric,
+ geography_type,
+ geography,
+ ):
+ logger.info("Entered check_permissions_by_name() if clause")
+
+ queryset = public_queryset + queryset.filter(is_public=False)
+ else:
+ logger.info("Entered else permission_sets clause")
+
+ queryset = public_queryset
queryset = self._exclude_data_under_embargo(queryset=queryset)
queryset = self._filter_for_metric_value_ranges(
@@ -533,6 +565,7 @@ def query_for_data(
sub_theme: str = "",
metric_value_ranges: list[str | float | int] | None = None,
rbac_permissions: Iterable[RBACPermission] | None = None,
+ permission_sets: dict,
) -> CoreTimeSeriesQuerySet:
"""Filters for a 2-item object by the given params. Slices all values older than the `date_from`.
@@ -582,6 +615,7 @@ def query_for_data(
rbac_permissions: The RBAC permissions available
to the given request. This dictates whether the given
request is permitted access to non-public data or not.
+ permission_sets: The JWT permissions extracted from the Cognito token.
Notes:
If we have the following input `queryset`:
@@ -611,20 +645,12 @@ def query_for_data(
]>`
"""
- rbac_permissions: Iterable[RBACPermission] = rbac_permissions or []
- has_access_to_non_public_data: bool = validate_permissions_for_non_public(
- theme=theme,
- sub_theme=sub_theme,
- topic=topic,
- metric=metric,
- geography_type=geography_type,
- geography=geography,
- rbac_permissions=rbac_permissions,
- )
return self.get_queryset().query_for_data(
fields_to_export=fields_to_export,
field_to_order_by=field_to_order_by,
+ theme=theme,
+ sub_theme=sub_theme,
topic=topic,
metric=metric,
date_from=date_from,
@@ -635,7 +661,7 @@ def query_for_data(
sex=sex,
age=age,
metric_value_ranges=metric_value_ranges,
- restrict_to_public=not has_access_to_non_public_data,
+ permission_sets=permission_sets,
)
def query_for_superseded_data(
diff --git a/metrics/data/managers/core_models/topic.py b/metrics/data/managers/core_models/topic.py
index 00b6f0632..cd67e9f1a 100644
--- a/metrics/data/managers/core_models/topic.py
+++ b/metrics/data/managers/core_models/topic.py
@@ -40,6 +40,40 @@ def get_name_by_id(self, topic_id: int) -> str | None:
"""
return self.filter(id=topic_id).values_list("name", flat=True).first()
+ def get_id_by_name(
+ self, theme_name: str, sub_theme_name: str, topic_name: str
+ ) -> tuple[int, int, int]:
+ """
+ Gets the theme, sub-theme and topic IDs matching the given names.
+
+ Args:
+ theme_name: The name of the parent theme
+ sub_theme_name: The name of the parent sub-theme
+ topic_name: The name of the topic to look up
+
+ Returns:
+ A tuple of (theme_id, sub_theme_id, topic_id) if found,
+ or the tuple (-2, -2, -2) otherwise
+ """
+ record = self.filter(
+ sub_theme__theme__name=theme_name,
+ sub_theme__name=sub_theme_name,
+ name=topic_name,
+ ).first()
+
+ if record:
+ return (
+ int(record.sub_theme.theme_id),
+ int(record.sub_theme_id),
+ int(record.id),
+ )
+
+ return (
+ -2,
+ -2,
+ -2,
+ )
+
def get_all_unique_names(self) -> models.QuerySet:
"""Gets all available unique topic names as a flat list queryset.
@@ -113,6 +147,30 @@ def get_name_by_id(self, topic_id: int) -> str | None:
"""
return self.get_queryset().get_name_by_id(topic_id)
+ def get_id_by_name(
+ self, theme_name: str, sub_theme_name: str, topic_name: str
+ ) -> tuple[int, int, int]:
+ """Gets the theme, sub-theme and topic IDs matching the given names.
+
+ Args:
+ theme_name: The name of the parent theme
+ sub_theme_name: The name of the parent sub-theme
+ topic_name: The name of the topic to look up
+
+ Returns:
+ A tuple of (theme_id, sub_theme_id, topic_id) if found,
+ or (-2, -2, -2) if not found.
+
+ Examples:
+ >>> TopicManager.get_id_by_name("Infectious disease", "Respiratory", "COVID-19")
+ (1, 2, 3)
+ >>> TopicManager.get_id_by_name("Unknown", "Unknown", "Unknown")
+ (-2, -2, -2)
+ """
+ return self.get_queryset().get_id_by_name(
+ theme_name, sub_theme_name, topic_name
+ )
+
def get_all_names(self) -> TopicQuerySet:
"""Gets all available topic names as a flat list queryset.
diff --git a/metrics/data/managers/rbac_models/user.py b/metrics/data/managers/rbac_models/user.py
index f4a3194d1..b2967685f 100644
--- a/metrics/data/managers/rbac_models/user.py
+++ b/metrics/data/managers/rbac_models/user.py
@@ -9,7 +9,7 @@
from django.db import models
-from auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.models.permission_sets import PermissionSet
class UserQuerySet(models.QuerySet):
diff --git a/metrics/data/models/constants.py b/metrics/data/models/constants.py
index d5fcfba29..2388368e7 100644
--- a/metrics/data/models/constants.py
+++ b/metrics/data/models/constants.py
@@ -5,3 +5,4 @@
METRIC_FREQUENCY_MAX_CHAR_CONSTRAINT: int = 1
METRIC_VALUE_MAX_DIGITS: int = 11
METRIC_VALUE_DECIMAL_PLACES: int = 4
+PERMISSION_SET_WILDCARD_ID_VALUE: str = "-1"
diff --git a/metrics/domain/models/charts/common.py b/metrics/domain/models/charts/common.py
index 317301450..7608749d3 100644
--- a/metrics/domain/models/charts/common.py
+++ b/metrics/domain/models/charts/common.py
@@ -1,3 +1,4 @@
+import logging
from collections.abc import Iterable
from decimal import Decimal
from typing import Literal
@@ -5,6 +6,8 @@
from pydantic.main import BaseModel
from rest_framework.request import Request
+logger = logging.getLogger(__name__)
+
class BaseChartRequestParams(BaseModel):
file_format: Literal["png", "svg", "jpg", "jpeg", "json", "csv"]
@@ -24,6 +27,16 @@ class BaseChartRequestParams(BaseModel):
class Config:
arbitrary_types_allowed = True
+ @property
+ def permission_sets(self) -> dict:
+ """Extract JWT permissions from the authenticated request"""
+
+ logger.info("Entered BaseChartRequestParams.permission_sets")
+
+ return getattr(self.request.user, "permission_sets", {})
+
@property
def rbac_permissions(self) -> Iterable["RBACPermission"]:
+ """TODO: RBAC-based permissions are legacy and will be removed in a future release"""
+
return getattr(self.request, "rbac_permissions", [])
diff --git a/metrics/interfaces/plots/access.py b/metrics/interfaces/plots/access.py
index 6e4eac34d..e8bcb0406 100644
--- a/metrics/interfaces/plots/access.py
+++ b/metrics/interfaces/plots/access.py
@@ -161,8 +161,12 @@ def get_queryset_from_core_model_manager(
plot_params["fields_to_export"].append("upper_confidence")
plot_params["fields_to_export"].append("lower_confidence")
+ logger.info("Entered access.py")
+
return self.core_model_manager.query_for_data(
- **plot_params, rbac_permissions=self.chart_request_params.rbac_permissions
+ **plot_params,
+ rbac_permissions=self.chart_request_params.rbac_permissions, # old permissions (remove)
+ permission_sets=self.chart_request_params.permission_sets, # new permissions
)
def build_plot_data_from_parameters_with_complete_queryset(
diff --git a/metrics/utils/permission_hierarchy.py b/metrics/utils/permission_hierarchy.py
index 53c9861a3..c5a182204 100644
--- a/metrics/utils/permission_hierarchy.py
+++ b/metrics/utils/permission_hierarchy.py
@@ -10,7 +10,7 @@
from django.db.models import QuerySet
-from auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.models.permission_sets import PermissionSet
from metrics.data.models.core_models.supporting import (
Geography,
GeographyType,
diff --git a/pyproject.toml b/pyproject.toml
index 170a4e864..2ce7e729a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -225,6 +225,10 @@ ignore_imports = [
"metrics.api.urls_construction -> cms.snippets.views",
"metrics.api.urls_construction -> feedback.api.urls",
"cms.dynamic_content.elements -> validation.data_transfer_models.base",
+ "cms.auth_content.models.users -> metrics.data.managers.rbac_models.user", # Allow auth_content to be moved into CMS, consider refactor
+ "metrics.api.serializers.user -> cms.auth_content.models.users", # Allow auth_content to be moved into CMS, consider refactor
+ "metrics.data.managers.rbac_models.user -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor
+ "metrics.utils.permission_hierarchy -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor
]
[[tool.importlinter.contracts]]
@@ -237,7 +241,14 @@ layers = [
]
ignore_imports = [
"metrics.data.managers.core_models.time_series -> metrics.api.permissions.fluent_permissions",
- "metrics.data.managers.core_models.headline -> metrics.api.permissions.fluent_permissions"
+ "metrics.data.managers.core_models.headline -> metrics.api.permissions.fluent_permissions",
+ "metrics.data.managers.rbac_models.user -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.auth_content.models.permission_sets -> cms.metrics_interface.field_choices_callables", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.metrics_interface.field_choices_callables -> cms.metrics_interface", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.metrics_interface -> cms.metrics_interface.interface", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.metrics_interface.interface -> metrics.domain.charts.colour_scheme", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.metrics_interface.interface -> metrics.domain.charts.common_charts.plots.line_multi_coloured.properties", # Allow auth_content to be moved into CMS, consider refactor
+ "cms.metrics_interface.interface -> metrics.domain.common.utils", # Allow auth_content to be moved into CMS, consider refactor
]
[[tool.importlinter.contracts]]
diff --git a/tests/factories/auth_content/models/permission_sets.py b/tests/factories/auth_content/models/permission_sets.py
index 7f73e093f..633765392 100644
--- a/tests/factories/auth_content/models/permission_sets.py
+++ b/tests/factories/auth_content/models/permission_sets.py
@@ -1,6 +1,6 @@
import factory
-from auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.models.permission_sets import PermissionSet
class PermissionSetFactory(factory.django.DjangoModelFactory):
diff --git a/tests/factories/auth_content/models/users.py b/tests/factories/auth_content/models/users.py
index df74effbd..ff76242f1 100644
--- a/tests/factories/auth_content/models/users.py
+++ b/tests/factories/auth_content/models/users.py
@@ -1,7 +1,7 @@
import factory
-from auth_content.models.permission_sets import PermissionSet
-from auth_content.models.users import User
+from cms.auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.models.users import User
class UserFactory(factory.django.DjangoModelFactory):
diff --git a/tests/integration/cms/dashboard/__init__.py b/tests/integration/cms/dashboard/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration/cms/dashboard/test_viewsets.py b/tests/integration/cms/dashboard/test_viewsets.py
new file mode 100644
index 000000000..5a9ef4e6a
--- /dev/null
+++ b/tests/integration/cms/dashboard/test_viewsets.py
@@ -0,0 +1,219 @@
+import pytest
+from unittest.mock import MagicMock
+from django.test import RequestFactory
+from rest_framework.request import Request
+from wagtail.models import Page
+
+from cms.common.models import CommonPage
+from cms.dashboard.viewsets import CMSPagesAPIViewSet
+from cms.metrics_documentation.models.child import MetricsDocumentationChildEntry
+from cms.topic.models import TopicPage
+from metrics.data.models.core_models import Metric, Topic
+
+
+class MockPermissionSets(list):
+ def __init__(self, permissions, has_global_access=False):
+ super().__init__(permissions)
+ self._summary = {"has_global_access": has_global_access}
+
+ def __getitem__(self, key):
+ if key == "permission_sets":
+ return list(self)
+ if key == "summary":
+ return self._summary
+ return super().__getitem__(key)
+
+
+@pytest.mark.django_db
+class TestCMSPagesAPIViewSetPermissions:
+
+ @pytest.fixture
+ def setup_pages(self):
+ influenza_topic = Topic.objects.create(name="Influenza")
+ metric = Metric.objects.create(
+ name="influenza_headline_positivityLatest", topic=influenza_topic
+ )
+
+ covid_topic = Topic.objects.create(name="COVID-19")
+ private_metric = Metric.objects.create(
+ name="COVID-19_headline_cases_7DayTotals", topic=covid_topic
+ )
+ private_metric_two = Metric.objects.create(
+ name="COVID-19_headline_7DayAdmissionsChange", topic=covid_topic
+ )
+
+ home = Page.objects.get(id=2)
+
+ public_topic = TopicPage(
+ title="Public Topic",
+ page_description="test",
+ slug="public-topic",
+ is_public=True,
+ theme="1",
+ seo_title="public-topic",
+ )
+ home.add_child(instance=public_topic)
+
+ private_topic = TopicPage(
+ title="Private Topic",
+ page_description="test",
+ slug="private-topic",
+ is_public=False,
+ theme="1",
+ sub_theme="test",
+ topic="test",
+ page_classification="official_sensitive",
+ seo_title="private-topic",
+ )
+ home.add_child(instance=private_topic)
+
+ public_metrics = MetricsDocumentationChildEntry(
+ title="Public Metric",
+ page_description="test",
+ slug="public-metric",
+ metric=metric.pk,
+ is_public=True,
+ seo_title="public-metrics",
+ )
+ home.add_child(instance=public_metrics)
+
+ private_metrics = MetricsDocumentationChildEntry(
+ title="Private Metric",
+ page_description="test",
+ slug="private-metric",
+ theme="2",
+ sub_theme="test",
+ topic="test",
+ metric=private_metric.pk,
+ is_public=False,
+ seo_title="private-metrics",
+ )
+ home.add_child(instance=private_metrics)
+
+ private_metrics_two = MetricsDocumentationChildEntry(
+ title="Private Metric 2",
+ page_description="test",
+ slug="private-metric-two",
+ theme="1",
+ sub_theme="test",
+ topic="test",
+ metric=private_metric_two.pk,
+ is_public=False,
+ seo_title="private-metrics-two",
+ )
+ home.add_child(instance=private_metrics_two)
+
+ standard_page = CommonPage(
+ title="Standard", body="test", slug="standard", seo_title="standard-page"
+ )
+ home.add_child(instance=standard_page)
+
+ return {
+ "public_topic": public_topic,
+ "private_topic": private_topic,
+ "public_metrics": public_metrics,
+ "private_metrics": private_metrics,
+ "standard_page": standard_page,
+ }
+
+ def test_anonymous_user_access(self, setup_pages):
+ """
+ Given a request is made by an unauthenticated user
+ When the queryset is retrieved
+ Then only public pages are returned
+ """
+ # Given
+ rf = RequestFactory()
+ url = "/api/v2/pages/"
+ django_request = rf.get(url)
+
+ request = Request(django_request)
+
+ mock_user = MagicMock()
+ request.user = mock_user
+
+ view = CMSPagesAPIViewSet()
+ view.request = request
+
+ # When
+ result = view.get_queryset()
+
+ # Then
+ titles = [p.title for p in result]
+ assert "Public Topic" in titles
+ assert "Public Metric" in titles
+ assert "Standard" in titles
+ assert "Private Topic" not in titles
+ assert "Private Metric" not in titles
+ assert "Private Metric 2" not in titles
+
+ def test_global_access_user(self, setup_pages):
+ """
+ Given a request is made by an authenticated user with global access
+ When the queryset is retrieved
+ Then all pages are returned
+ """
+ # Given
+ rf = RequestFactory()
+ url = "/api/v2/pages/"
+ django_request = rf.get(url)
+
+ request = Request(django_request)
+
+ mock_user = MagicMock()
+ mock_user.permission_sets = MockPermissionSets(
+ [],
+ has_global_access=True,
+ )
+
+ request.user = mock_user
+ request.auth = "token"
+
+ view = CMSPagesAPIViewSet()
+ view.request = request
+
+ # When
+ result = view.get_queryset()
+
+ # Then
+ titles = [p.title for p in result]
+ assert "Public Topic" in titles
+ assert "Public Metric" in titles
+ assert "Standard" in titles
+ assert "Private Topic" in titles
+ assert "Private Metric" in titles
+ assert "Private Metric 2" in titles
+
+ def test_restricted_user_with_permission(self, setup_pages):
+ """
+ Given a request is made by an authenticated user with access to some private pages
+ When the queryset is retrieved
+ Then only the pages the user has access to are returned
+ """
+ # Given
+ rf = RequestFactory()
+ url = "/api/v2/pages/"
+ django_request = rf.get(url)
+
+ request = Request(django_request)
+
+ mock_user = MagicMock()
+ mock_user.permission_sets = MockPermissionSets(
+ [{"theme": {"id": "1"}, "sub_theme": {"id": "-1"}}],
+ has_global_access=False,
+ )
+
+ request.user = mock_user
+ request.auth = "token"
+
+ view = CMSPagesAPIViewSet()
+ view.request = request
+
+ # When
+ result = view.get_queryset()
+
+ # Then
+ titles = [p.title for p in result]
+ assert "Private Topic" in titles
+ assert "Private Metric 2" in titles
+ assert "Private Metric" not in titles
diff --git a/tests/integration/cms/dynamic_content/test_page_link_chooser.py b/tests/integration/cms/dynamic_content/test_page_link_chooser.py
index 3853a0067..12fb0681d 100644
--- a/tests/integration/cms/dynamic_content/test_page_link_chooser.py
+++ b/tests/integration/cms/dynamic_content/test_page_link_chooser.py
@@ -28,6 +28,9 @@ def test_page_chooser_returns_full_url(
path="abc",
depth=1,
title="abc",
+ theme="test",
+ sub_theme="test",
+ topic=1,
live=True,
seo_title="ABC",
)
diff --git a/tests/integration/cms/metrics_documentation/data_migration/test_operations.py b/tests/integration/cms/metrics_documentation/data_migration/test_operations.py
index 4e678a351..b8f464414 100644
--- a/tests/integration/cms/metrics_documentation/data_migration/test_operations.py
+++ b/tests/integration/cms/metrics_documentation/data_migration/test_operations.py
@@ -67,9 +67,12 @@ def test_removes_all_child_entries(self):
path="abc",
depth=1,
title="Test",
+ theme="test",
+ sub_theme="test",
slug="test",
page_description="xyz",
- metric=metric.name,
+ metric=metric.pk,
+ topic=metric.topic,
seo_title="Test",
)
assert MetricsDocumentationChildEntry.objects.exists()
@@ -146,45 +149,40 @@ def test_creates_correct_child_entries(self, dashboard_root_page: UKHSARootPage)
Then the correct child entries are created
for the corresponding `Metric` records
"""
+
# Given
- _seed_truncated_test_data_with_split_auth()
- healthcare_admission_metric = Metric.objects.get(
- name="RSV_healthcare_admissionRateByWeek"
+
+ entries = get_metrics_definitions()
+ assert entries, "No metric definitions found"
+
+ test_entry = entries[0]
+
+ topic = Topic.objects.create(name=test_entry["topic"])
+
+ metric = Metric.objects.create(
+ id=test_entry["metric"],
+ name=f"metric-{test_entry['metric']}",
+ topic=topic,
)
# When
create_metrics_documentation_parent_page_and_child_entries()
# Then
- healthcare_admission_rate_child_entry = (
- MetricsDocumentationChildEntry.objects.get(
- metric=healthcare_admission_metric.name
- )
- )
- assert healthcare_admission_rate_child_entry.metric_group == "healthcare"
- expected_title = "RSV healthcare admission rate by week"
- assert (
- healthcare_admission_rate_child_entry.slug
- == expected_title.lower().replace(" ", "-")
- )
- assert healthcare_admission_rate_child_entry.topic == "RSV"
- assert healthcare_admission_rate_child_entry.title == expected_title
- assert (
- healthcare_admission_rate_child_entry.seo_title
- == f"{expected_title} | UKHSA data dashboard"
- )
+ child_entry = MetricsDocumentationChildEntry.objects.get(metric=metric.pk)
+
+ expected_title = test_entry["title"]
+
+ assert child_entry.title == expected_title
+ assert child_entry.slug == expected_title.lower().replace(" ", "-")
+ assert child_entry.topic == test_entry["topic"]
+ assert child_entry.seo_title == test_entry["seo_title"]
- expected_page_description = (
- "This metric shows the rate per 100,000 people of the total number of people "
- "with confirmed RSV admitted to hospital "
- "(general admissions plus admissions to ICU and HDU) "
- "in the 7 days up to and including the date shown."
- )
assert (
- healthcare_admission_rate_child_entry.search_description
- == healthcare_admission_rate_child_entry.page_description
- == expected_page_description
+ child_entry.search_description
+ == child_entry.page_description
+ == test_entry["page_description"]
)
@pytest.mark.django_db
diff --git a/tests/integration/cms/metrics_documentation/models/test_child.py b/tests/integration/cms/metrics_documentation/models/test_child.py
index 024dc6b53..da59bd2a9 100644
--- a/tests/integration/cms/metrics_documentation/models/test_child.py
+++ b/tests/integration/cms/metrics_documentation/models/test_child.py
@@ -17,25 +17,31 @@ def test_metric_is_unique(self):
"""
# Given
metric_name = "influenza_headline_positivityLatest"
- Metric.objects.create(name=metric_name)
+ created_metric = Metric.objects.create(name=metric_name)
Topic.objects.create(name=metric_name.split("_")[0].title())
- _create_metrics_documentation_child_entry(metric_name=metric_name, path="doc_1")
+ _create_metrics_documentation_child_entry(
+ metric_name=metric_name, metric_id=created_metric.pk, path="doc_1"
+ )
# When / Then
with pytest.raises(ValidationError):
_create_metrics_documentation_child_entry(
- metric_name=metric_name, path="doc_2"
+ metric_name=metric_name, metric_id=created_metric.pk, path="doc_2"
)
def _create_metrics_documentation_child_entry(
metric_name: str,
+ metric_id: int,
path: str,
) -> MetricsDocumentationChildEntry:
MetricsDocumentationChildEntry.objects.create(
- metric=metric_name,
+ metric=metric_id,
title=metric_name,
+ theme="test",
+ sub_theme="test",
+ topic=1,
path=path,
depth=1,
slug=metric_name,
diff --git a/tests/integration/cms/topic/test_managers.py b/tests/integration/cms/topic/test_managers.py
index 0776a59a2..af36e4753 100644
--- a/tests/integration/cms/topic/test_managers.py
+++ b/tests/integration/cms/topic/test_managers.py
@@ -19,6 +19,9 @@ def test_get_live_pages(self):
path="abc",
depth=1,
title="abc",
+ theme="test",
+ topic=1,
+ sub_theme="test",
live=True,
seo_title="ABC",
)
@@ -26,6 +29,9 @@ def test_get_live_pages(self):
path="def",
depth=1,
title="def",
+ theme="test2",
+ topic=2,
+ sub_theme="test2",
live=False,
seo_title="DEF",
)
diff --git a/tests/integration/metrics/api/views/test_geographies.py b/tests/integration/metrics/api/views/test_geographies.py
index 31b15c627..685595f63 100644
--- a/tests/integration/metrics/api/views/test_geographies.py
+++ b/tests/integration/metrics/api/views/test_geographies.py
@@ -4,7 +4,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient
-from auth_content.constants import WILDCARD_ID_VALUE
+from cms.auth_content.constants import WILDCARD_ID_VALUE
from tests.factories.metrics.geography import GeographyFactory
from tests.factories.metrics.time_series import CoreTimeSeriesFactory
from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE
diff --git a/tests/integration/metrics/api/views/test_permission_sets.py b/tests/integration/metrics/api/views/test_permission_sets.py
index 7acc8d51a..1bb4a9096 100644
--- a/tests/integration/metrics/api/views/test_permission_sets.py
+++ b/tests/integration/metrics/api/views/test_permission_sets.py
@@ -4,7 +4,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient
-from auth_content.constants import WILDCARD_ID_VALUE
+from cms.auth_content.constants import WILDCARD_ID_VALUE
from tests.factories.metrics.metric import MetricFactory
from tests.factories.metrics.sub_theme import SubThemeFactory
from tests.factories.metrics.topic import TopicFactory
diff --git a/tests/integration/metrics/api/views/test_user.py b/tests/integration/metrics/api/views/test_user.py
index 0742215d3..1264f058e 100644
--- a/tests/integration/metrics/api/views/test_user.py
+++ b/tests/integration/metrics/api/views/test_user.py
@@ -5,7 +5,7 @@
from rest_framework.response import Response
from rest_framework.test import APIClient
-from auth_content.constants import WILDCARD_ID_VALUE
+from cms.auth_content.constants import WILDCARD_ID_VALUE
from tests.factories.auth_content.models.permission_sets import PermissionSetFactory
from tests.factories.auth_content.models.users import UserFactory
from tests.factories.metrics.metric import MetricFactory
diff --git a/tests/integration/metrics/data/managers/rbac_models/test_user.py b/tests/integration/metrics/data/managers/rbac_models/test_user.py
index 8e085950f..7c75ec8ac 100644
--- a/tests/integration/metrics/data/managers/rbac_models/test_user.py
+++ b/tests/integration/metrics/data/managers/rbac_models/test_user.py
@@ -1,6 +1,6 @@
import pytest
-from auth_content.models.users import User
+from cms.auth_content.models.users import User
from tests.factories.auth_content.models.permission_sets import PermissionSetFactory
from tests.factories.auth_content.models.users import UserFactory
diff --git a/tests/integration/metrics/utils/test_permission_hierarchy.py b/tests/integration/metrics/utils/test_permission_hierarchy.py
index 62ee32dba..851860656 100644
--- a/tests/integration/metrics/utils/test_permission_hierarchy.py
+++ b/tests/integration/metrics/utils/test_permission_hierarchy.py
@@ -491,7 +491,7 @@ def test_single_permission_no_deduplication(self):
Then no deduplication occurs
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
perm = PermissionSetFactory.create_permission_set(
@@ -521,7 +521,7 @@ def test_removes_fully_subsumed_permission(self):
Then only permission A is returned
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
@@ -564,7 +564,7 @@ def test_keeps_independent_permissions(self):
Then both are kept
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
@@ -603,7 +603,7 @@ def test_complex_multi_level_deduplication(self):
Then correct deduplication occurs
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
@@ -658,7 +658,7 @@ def test_summary_contains_correct_statistics(self):
Then summary contains accurate statistics
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
@@ -696,7 +696,7 @@ def test_hierarchy_structure_is_correct(self):
Then each permission has correct structure
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
perm = PermissionSetFactory.create_permission_set(
@@ -743,7 +743,7 @@ def test_returns_normalized_permission_list(self):
Then returns list of NormalizedPermission objects
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
perm = PermissionSetFactory.create_permission_set(
@@ -772,7 +772,7 @@ def test_deduplicates_permissions(self):
Then subsumed permissions are removed
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
@@ -810,7 +810,7 @@ def test_empty_queryset_returns_empty_hierarchy(self):
Then returns empty hierarchy with zero counts
"""
# Given
- from auth_content.models.permission_sets import PermissionSet
+ from cms.auth_content.models.permission_sets import PermissionSet
# When
result = build_permission_hierarchy(PermissionSet.objects.none())
@@ -829,7 +829,7 @@ def test_handles_null_fields_gracefully(self):
Then handles gracefully without errors
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
perm = PermissionSetFactory.create(
@@ -859,7 +859,7 @@ def test_all_wildcards_at_different_levels(self):
Then correctly handles all wildcard combinations
"""
# Given
- from auth_content.models.users import User
+ from cms.auth_content.models.users import User
user = User.objects.create(user_id=uuid4())
perm1 = PermissionSetFactory.create_permission_set(
diff --git a/tests/unit/cms/auth_content/__init__.py b/tests/unit/cms/auth_content/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/unit/cms/auth_content/models/test_permission_sets.py b/tests/unit/cms/auth_content/models/test_permission_sets.py
new file mode 100644
index 000000000..1ca0eb2d0
--- /dev/null
+++ b/tests/unit/cms/auth_content/models/test_permission_sets.py
@@ -0,0 +1,165 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from django.core.exceptions import ValidationError
+from cms.auth_content.models.permission_sets import PermissionSet, PermissionSetForm
+
+
+class TestPermissionSetForm:
+ MOCK_PERMISSION_SET_FIELDS = [
+ {"field_name": "theme", "field_label": "Theme"},
+ {"field_name": "sub_theme", "field_label": "Sub Theme"},
+ {"field_name": "topic", "field_label": "Topic"},
+ {"field_name": "metric", "field_label": "Metric"},
+ {"field_name": "geography", "field_label": "Geography"},
+ ]
+
+ def _make_form(self, instance=None, queryset_exists=False):
+ """
+ Instantiate PermissionSetForm with all Wagtail
+ internals patched.
+ """
+ with (
+ patch(
+ "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None
+ ),
+ patch(
+ "cms.auth_content.models.permission_sets.PERMISSION_SET_FIELDS",
+ self.MOCK_PERMISSION_SET_FIELDS,
+ ),
+ patch(
+ "cms.auth_content.models.permission_sets._create_form_field",
+ side_effect=lambda field, wildcard: MagicMock(name=field["field_name"]),
+ ),
+ ):
+ form = PermissionSetForm.__new__(PermissionSetForm)
+ form.fields = {}
+ form.instance = instance or MagicMock(pk=None)
+ default_data = {
+ "theme": 1,
+ "sub_theme": 2,
+ "topic": 3,
+ "metric": 4,
+ "geography_type": 5,
+ "geography": 6,
+ }
+ form.cleaned_data = default_data
+ form.__init__()
+
+ mock_qs = MagicMock()
+ mock_qs.exists.return_value = queryset_exists
+ mock_qs.exclude.return_value = mock_qs
+
+ form._mock_qs = mock_qs
+ return form
+
+ def test_init_sets_up_fields(self):
+ """
+ When a new form is instantiated
+ Then a form field is added to `fields` for each entry in `PERMISSION_SET_FIELDS`
+ """
+ form = self._make_form()
+
+ assert len(form.fields) == 5
+ assert "theme" in form.fields
+ assert "sub_theme" in form.fields
+ assert "topic" in form.fields
+ assert "metric" in form.fields
+ assert "geography" in form.fields
+
+ def test_initialize_dependent_fields(self):
+ """
+ Given a new form
+ When an instance has a pk value set
+ Then `_initialize_dependent_fields` is called
+ """
+ instance = MagicMock(pk=1, sub_theme=1, topic=2, metric=3, geography=4)
+ form = self._make_form(instance=instance)
+
+ assert form.fields["sub_theme"].widget.choices == [
+ ("", "Select theme first"),
+ (1, "Loading... (ID: 1)"),
+ ]
+ assert form.fields["topic"].widget.choices == [
+ ("", "Select sub-theme first"),
+ (2, "Loading... (ID: 2)"),
+ ]
+ assert form.fields["metric"].widget.choices == [
+ ("", "Select topic first"),
+ (3, "Loading... (ID: 3)"),
+ ]
+ assert form.fields["geography"].widget.choices == [
+ ("", "Select geography type first"),
+ (4, "Loading... (ID: 4)"),
+ ]
+
+ def test_get_field_choices(self):
+ """
+ When the static function `_get_field_choices` is called without a wildcard
+ Then the result returns the placeholder value
+ """
+ result = PermissionSetForm._get_field_choices("test", "placeholder", None)
+ assert result == [("", "placeholder"), ("test", "Loading... (ID: test)")]
+
+ def test_get_field_choices_wildcard_match(self):
+ """
+ When the static function `_get_field_choices` is called with a wildcard
+ Then the result returns the wildcard match
+ """
+ result = PermissionSetForm._get_field_choices("-1", "placeholder", "wildcard")
+ assert result == [("-1", "wildcard")]
+
+ @patch("cms.auth_content.models.permission_sets.PermissionSet.objects.filter")
+ def test_validation_error_raised_if_queryset_duplicated(
+ self, mock_query_filter: MagicMock
+ ):
+ """
+ Given a form is created with an existing queryset match
+ When `clean` is called
+ Then a `ValidationError` is raised
+ """
+ form = self._make_form(queryset_exists=True)
+ mock_query_filter.return_value = form._mock_qs
+
+ with pytest.raises(ValidationError) as e:
+ form.clean()
+
+ assert (
+ "A permission set with this exact combination already exists. Please modify your selection to create a unique permission set."
+ in str(e.value)
+ )
+
+ @patch("cms.auth_content.models.permission_sets.PermissionSet.objects.filter")
+ def test_returns_cleaned_data_when_no_duplicate_exists(
+ self, mock_query_filter: MagicMock
+ ):
+ """
+ Given a form is created without an existing queryset match
+ When `clean` is called
+ Then the cleaned data is returned
+ """
+ instance = MagicMock(pk=1)
+ form = self._make_form(instance=instance, queryset_exists=False)
+ mock_query_filter.return_value = form._mock_qs
+
+ result = form.clean()
+
+ assert result == form.cleaned_data
+
+
+class TestPermissionSet:
+ def test_get_choice_label(self):
+ """
+ Given a blank `PermissionSet`
+ When `_get_choice_label` is called with an unknown field and value
+ Then the unknown value is returned
+ """
+ test_permission_set = PermissionSet()
+ unknown_field = "unknown_field_type"
+ test_value = "12345"
+
+ # When
+ result = test_permission_set._get_choice_label(unknown_field, test_value)
+
+ # Then
+ assert result == test_value
diff --git a/tests/unit/cms/auth_content/test_auth_utils.py b/tests/unit/cms/auth_content/test_auth_utils.py
new file mode 100644
index 000000000..ba64ad56c
--- /dev/null
+++ b/tests/unit/cms/auth_content/test_auth_utils.py
@@ -0,0 +1,103 @@
+import pytest
+from unittest.mock import MagicMock
+from django import forms
+
+from cms.auth_content.auth_utils import _create_form_field
+
+
+class TestCreateFormField:
+ def test_create_form_field_basic(self):
+ """
+ Given no wildcard or callables in data
+ When `_create_form_field` is called
+ Then only the default choices are returned
+ """
+ field_data = {
+ "field_choice_default": "Select an option",
+ "field_choice_wildcard": None,
+ "field_choice_callable": None,
+ "field_label": "My Label",
+ }
+
+ result = _create_form_field(field_data)
+
+ assert isinstance(result, forms.CharField)
+ assert result.label == "My Label"
+ expected_choices = [("", "Select an option")]
+ assert result.widget.choices == expected_choices
+
+ def test_create_form_field_with_wildcard(self):
+ """
+ Given a wildcard in the data
+ When `_create_form_field` is called
+ Then the wildcard choice is added
+ """
+ field_data = {
+ "field_choice_default": "Default",
+ "field_choice_wildcard": "All Items",
+ "field_choice_callable": None,
+ "field_label": "Label",
+ }
+ wildcard_val = "-1"
+
+ result = _create_form_field(field_data, wildcard_id_value=wildcard_val)
+
+ expected_choices = [("", "Default"), ("-1", "All Items")]
+ assert result.widget.choices == expected_choices
+
+ def test_create_form_field_with_callable(self):
+ """
+ Given a callable in the data
+ When `_create_form_field` is called
+ Then the callable choice is added
+ """
+ mock_callable = MagicMock(return_value=[("1", "One"), ("2", "Two")])
+
+ field_data = {
+ "field_choice_default": "Default",
+ "field_choice_wildcard": None,
+ "field_choice_callable": mock_callable,
+ "field_label": "Label",
+ }
+
+ result = _create_form_field(field_data)
+
+ expected_choices = [("", "Default"), ("1", "One"), ("2", "Two")]
+ assert result.widget.choices == expected_choices
+ mock_callable.assert_called_once()
+
+ def test_create_form_field_all_features(self):
+ """
+ Given both a wildcard and callable are in the data
+ When `_create_form_field` is called
+ Then both the wildcard and callable choices are added
+ """
+ """Test combined default, wildcard, and callable choices"""
+ mock_callable = MagicMock(return_value=[("dynamic", "Dynamic")])
+ field_data = {
+ "field_choice_default": "Default",
+ "field_choice_wildcard": "Wildcard",
+ "field_choice_callable": mock_callable,
+ "field_label": "Label",
+ }
+
+ result = _create_form_field(field_data, wildcard_id_value="999")
+
+ expected_choices = [
+ ("", "Default"),
+ ("999", "Wildcard"),
+ ("dynamic", "Dynamic"),
+ ]
+ assert result.widget.choices == expected_choices
+
+ def test_create_form_field_missing_key_error(self):
+ """
+ Given data with missing keys
+ When `_create_form_field` is called
+ Then a key error exception is raised
+ """
+ """Test behavior if a required key is missing from the dict"""
+ field_data = {"field_choice_default": "Missing other keys"}
+
+ with pytest.raises(KeyError):
+ _create_form_field(field_data)
diff --git a/tests/unit/auth_content/test_wagtail_hooks.py b/tests/unit/cms/auth_content/test_wagtail_hooks.py
similarity index 58%
rename from tests/unit/auth_content/test_wagtail_hooks.py
rename to tests/unit/cms/auth_content/test_wagtail_hooks.py
index bf1c76cdb..0a2b19f04 100644
--- a/tests/unit/auth_content/test_wagtail_hooks.py
+++ b/tests/unit/cms/auth_content/test_wagtail_hooks.py
@@ -1,9 +1,23 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
from django.test import TestCase
from django.utils.safestring import SafeData
-from auth_content.models.permission_sets import PermissionSet
-from auth_content.wagtail_hooks import NoEditPermissionPolicy, PermissionSetViewSet
+from cms.auth_content.models.permission_sets import PermissionSet
+from cms.auth_content.wagtail_hooks import (
+ NoEditPermissionPolicy,
+ PermissionSetViewSet,
+ AuthGroup,
+ register_auth_viewset,
+)
+
+
+class TestWagtailHooks(TestCase):
+ def test_register_auth_viewset(self):
+ result = register_auth_viewset()
+ assert result.menu_label == AuthGroup.menu_label
+ assert result.menu_icon == AuthGroup.menu_icon
+ assert result.menu_order == AuthGroup.menu_order
+ assert len(result.items) == 2
class TestPermissionSetDetailsProperty(TestCase):
@@ -41,6 +55,14 @@ def setUp(self):
def test_change_permission_denied(self):
self.assertFalse(self.policy.user_has_permission(self.user, "change"))
+ @patch("wagtail.permission_policies.ModelPermissionPolicy.user_has_permission")
+ def test_user_has_permission_calls_super(self, spy_user_has_permissions: MagicMock):
+ spy_user_has_permissions.return_value = "parent_response"
+ result = self.policy.user_has_permission(self.user, "view")
+
+ spy_user_has_permissions.assert_called_once_with(self.user, "view")
+ assert result == "parent_response"
+
def test_change_permission_denied_for_instance(self):
self.assertFalse(
self.policy.user_has_permission_for_instance(
@@ -48,6 +70,22 @@ def test_change_permission_denied_for_instance(self):
)
)
+ @patch(
+ "wagtail.permission_policies.ModelPermissionPolicy.user_has_permission_for_instance"
+ )
+ def test_user_has_permission_for_instance_calls_super(
+ self, spy_user_has_permissions_for_instance: MagicMock
+ ):
+ spy_user_has_permissions_for_instance.return_value = "parent_response"
+ result = self.policy.user_has_permission_for_instance(
+ self.user, "view", self.instance
+ )
+
+ spy_user_has_permissions_for_instance.assert_called_once_with(
+ self.user, "view", self.instance
+ )
+ assert result == "parent_response"
+
class TestPermissionSetViewSet(TestCase):
diff --git a/tests/unit/cms/dashboard/test_viewsets.py b/tests/unit/cms/dashboard/test_viewsets.py
index 0620091aa..165ccb266 100644
--- a/tests/unit/cms/dashboard/test_viewsets.py
+++ b/tests/unit/cms/dashboard/test_viewsets.py
@@ -1,5 +1,155 @@
+import pytest
+
from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer
-from cms.dashboard.viewsets import CMSDraftPagesViewSet, CMSPagesAPIViewSet
+from cms.dashboard.viewsets import (
+ CMSDraftPagesViewSet,
+ CMSPagesAPIViewSet,
+ check_permissions,
+)
+
+
+class TestCheckPermissions:
+ @pytest.mark.parametrize(
+ "user_permissions, theme_id, sub_theme_id, topic_id",
+ [
+ ([{"theme": {"id": "-1"}}], "10", "20", "30"),
+ ([{"theme": {"id": "10"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"),
+ (
+ [
+ {
+ "theme": {"id": "10"},
+ "sub_theme": {"id": "20"},
+ "topic": {"id": "-1"},
+ }
+ ],
+ "10",
+ "20",
+ "30",
+ ),
+ (
+ [
+ {
+ "theme": {"id": "10"},
+ "sub_theme": {"id": "20"},
+ "topic": {"id": "30"},
+ }
+ ],
+ "10",
+ "20",
+ "30",
+ ),
+ (
+ [
+ {"theme": {"id": "5"}, "sub_theme": {"id": "-1"}},
+ {
+ "theme": {"id": "10"},
+ "sub_theme": {"id": "20"},
+ "topic": {"id": "30"},
+ },
+ ],
+ "10",
+ "20",
+ "30",
+ ),
+ ],
+ )
+ def test_check_permissions_valid_access(
+ self, user_permissions, theme_id, sub_theme_id, topic_id
+ ):
+ """
+ Given a permission set that does grant access to the provided ids
+ When the `check_permissions` function is called
+ Then the function returns true
+ """
+ assert (
+ check_permissions(user_permissions, theme_id, sub_theme_id, topic_id)
+ == True
+ )
+
+ @pytest.mark.parametrize(
+ "user_permissions, theme_id, sub_theme_id, topic_id",
+ [
+ ([{"theme": {"id": "99"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"),
+ (
+ [
+ {
+ "theme": {"id": "10"},
+ "sub_theme": {"id": "99"},
+ "topic": {"id": "-1"},
+ }
+ ],
+ "10",
+ "20",
+ "30",
+ ),
+ (
+ [
+ {
+ "theme": {"id": "10"},
+ "sub_theme": {"id": "20"},
+ "topic": {"id": "99"},
+ }
+ ],
+ "10",
+ "20",
+ "30",
+ ),
+ ([], "10", "20", "30"),
+ ],
+ )
+ def test_check_permissions_invalid_access(
+ self, user_permissions, theme_id, sub_theme_id, topic_id
+ ):
+ """
+ Given a permission set that does not grant access to the provided ids
+ When the `check_permissions` function is called
+ Then the function returns false
+ """
+ assert (
+ check_permissions(user_permissions, theme_id, sub_theme_id, topic_id)
+ == False
+ )
+
+ @pytest.mark.parametrize(
+ "user_permissions, theme_id, sub_theme_id, topic_id",
+ [
+ ([{}], "10", "20", "30"),
+ (None, "10", "20", "30"),
+ ([{"sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], "10", "20", "30"),
+ (
+ [{"theme": {}, "sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}],
+ "10",
+ "20",
+ "30",
+ ),
+ ([{"theme": {"id": "10"}, "topic": {"id": "-1"}}], "10", "20", "30"),
+ (
+ [{"theme": {"id": "10"}, "sub_theme": {}, "topic": {"id": "-1"}}],
+ "10",
+ "20",
+ "30",
+ ),
+ ([{"theme": {"id": "10"}, "sub_theme": {"id": "20"}}], "10", "20", "30"),
+ (
+ [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}, "topic": {}}],
+ "10",
+ "20",
+ "30",
+ ),
+ ],
+ )
+ def test_check_permissions_with_missing_values(
+ self, user_permissions, theme_id, sub_theme_id, topic_id
+ ):
+ """
+ Given a permission set that is missing values
+ When the `check_permissions` function is called
+ Then the function returns false
+ """
+ assert (
+ check_permissions(user_permissions, theme_id, sub_theme_id, topic_id)
+ == False
+ )
class TestCMSDraftPagesViewSet:
diff --git a/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py b/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py
index 3d813e26b..39df8df1d 100644
--- a/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py
+++ b/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py
@@ -96,14 +96,18 @@ def test_returns_correct_dictionary(
"Fake page description",
"Fake methodology content",
"Fake caveats content",
+ 1,
)
spy_build_sections.return_value = []
expected_response = {
"title": "Fake title",
+ "topic": "Fake",
+ "theme": "test",
+ "sub_theme": "test",
"seo_title": "Fake title | UKHSA data dashboard",
"search_description": "Fake page description",
"page_description": "Fake page description",
- "metric": "Fake_metric_name",
+ "metric": 1,
"body": [],
}
@@ -132,6 +136,7 @@ def build_worksheet() -> Worksheet:
work_sheet["E2"] = "Fake page description"
work_sheet["F2"] = "Fake methodology content"
work_sheet["G2"] = "Fake caveats content"
+ work_sheet["H2"] = 1
return work_sheet
@@ -150,10 +155,13 @@ def test_delegates_calls_correctly(
expected_response = [
{
"title": "Fake title",
+ "topic": "Fake",
+ "theme": "test",
+ "sub_theme": "test",
"seo_title": "Fake title | UKHSA data dashboard",
"search_description": "Fake page description",
"page_description": "Fake page description",
- "metric": "Fake_metric_name",
+ "metric": 1,
"body": [
{
"type": "section",
diff --git a/tests/unit/cms/metrics_documentation/data_migration/test_operations.py b/tests/unit/cms/metrics_documentation/data_migration/test_operations.py
index b4741209a..274b78e5c 100644
--- a/tests/unit/cms/metrics_documentation/data_migration/test_operations.py
+++ b/tests/unit/cms/metrics_documentation/data_migration/test_operations.py
@@ -128,9 +128,11 @@ def test_log_recorded_when_metric_not_available_for_child_page(
create_metrics_documentation_parent_page_and_child_entries()
# Then
- expected_log = (
- f"Metrics Documentation Child Entry for {fake_metric} was not created. "
+ expected_log_part_one = "Metrics Documentation Child Entry for "
+ expected_log_part_two = (
+ " was not created. "
"Because the corresponding `Metric` was not created beforehand"
)
- assert expected_log in caplog.text
+ assert expected_log_part_one in caplog.text
+ assert expected_log_part_two in caplog.text
diff --git a/tests/unit/cms/metrics_documentation/models/test_child.py b/tests/unit/cms/metrics_documentation/models/test_child.py
index 08cc67a9b..8f2b60a90 100644
--- a/tests/unit/cms/metrics_documentation/models/test_child.py
+++ b/tests/unit/cms/metrics_documentation/models/test_child.py
@@ -1,12 +1,14 @@
from unittest import mock
+from unittest.mock import MagicMock, patch
from django.core.exceptions import ValidationError
import pytest
from wagtail.admin.panels import FieldPanel
from wagtail.api.conf import APIField
-from cms.metrics_documentation.models import child
+from cms.dashboard.constants import THEME_FIELDS
from cms.metrics_documentation.models.child import (
+ MetricsDocumentationChildEntryAdminForm,
InvalidTopicForChosenMetricForChildEntryError,
)
from tests.fakes.factories.cms.metrics_documentation_child_entry_factory import (
@@ -16,6 +18,167 @@
MODULE_PATH = "cms.metrics_documentation.models.child"
+class TestInvalidTopicForChosenMetricForChildEntryError:
+ def test_exception_has_expected_message(self):
+ actual = InvalidTopicForChosenMetricForChildEntryError(
+ "test_topic", "test_metric"
+ )
+ expected = "InvalidTopicForChosenMetricForChildEntryError('The `test_topic` is not available for selected metric of `test_metric`')"
+
+ assert expected == repr(actual)
+
+
+class TestMetricsDocumentationChildEntryAdminForm:
+ MOCK_THEME_FIELDS = [
+ {"field_name": "theme", "label": "Theme", "required": True},
+ {"field_name": "sub_theme", "label": "Sub Theme", "required": False},
+ {"field_name": "topic", "label": "Topic", "required": False},
+ ]
+
+ def _make_form(self, instance=None):
+ """
+ Instantiate MetricsDocumentationChildEntryAdminForm with all Wagtail
+ internals patched.
+ """
+ with (
+ patch(
+ "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None
+ ),
+ patch(
+ "cms.metrics_documentation.models.child.THEME_FIELDS",
+ self.MOCK_THEME_FIELDS,
+ ),
+ patch(
+ "cms.metrics_documentation.models.child._create_form_field",
+ side_effect=lambda field: MagicMock(name=field["field_name"]),
+ ),
+ ):
+ form = MetricsDocumentationChildEntryAdminForm.__new__(
+ MetricsDocumentationChildEntryAdminForm
+ )
+ form.fields = {}
+ form.instance = instance or MagicMock(pk=None)
+ form.__init__()
+ return form
+
+ def _make_form_with_instance(self, sub_theme=None, topic=None):
+ """
+ Instantiate MetricsDocumentationChildEntryAdminForm with all Wagtail
+ internals patched, and a mocked instance.
+ """
+ instance = MagicMock(pk=1)
+ instance.sub_theme = sub_theme
+ instance.topic = topic
+
+ form = self._make_form(instance=instance)
+
+ for field_name in ("sub_theme", "topic"):
+ mock_widget = MagicMock()
+ mock_widget.choices = []
+ form.fields[field_name] = MagicMock(widget=mock_widget)
+
+ return form
+
+ def test_creates_field_for_every_theme_field(self):
+ """
+ When a new form is instantiated
+ Then a form field is added to `fields` for each entry in `THEME_FIELDS`.
+ """
+ form = self._make_form()
+
+ assert len(form.fields) == 3
+ assert "theme" in form.fields
+ assert "sub_theme" in form.fields
+ assert "topic" in form.fields
+
+ @mock.patch("cms.metrics_documentation.models.child._create_form_field")
+ @mock.patch("wagtail.admin.panels.WagtailAdminPageForm.__init__")
+ def test_field_creation_uses_create_form_field_helper(
+ self, spy_init_admin_form: mock.MagicMock, spy_create_form_field: mock.MagicMock
+ ):
+ """
+ Given a new form is created
+ When init is called on the form
+ Then `_create_form_field` is called once per `THEME_FIELDS` entry.
+ """
+ form = MetricsDocumentationChildEntryAdminForm.__new__(
+ MetricsDocumentationChildEntryAdminForm
+ )
+ form.fields = {}
+ form.instance = MagicMock(pk=None)
+ form.__init__()
+
+ assert spy_create_form_field.call_count == len(THEME_FIELDS)
+ spy_create_form_field.assert_any_call(THEME_FIELDS[0])
+ spy_create_form_field.assert_any_call(THEME_FIELDS[1])
+ spy_create_form_field.assert_any_call(THEME_FIELDS[2])
+ spy_create_form_field.assert_any_call(THEME_FIELDS[3])
+
+ def test_initialize_dependent_fields_called_when_instance_has_pk(self):
+ """
+ Given a new form
+ When an instance has a pk value set
+ Then `_initialize_dependent_fields` is called
+ """
+ instance = MagicMock(pk=42)
+
+ with patch.object(
+ MetricsDocumentationChildEntryAdminForm,
+ "_initialize_dependent_fields",
+ ) as mock_init_deps:
+ self._make_form(instance=instance)
+
+ mock_init_deps.assert_called_once()
+
+ def test_initialize_dependent_fields_not_called_when_instance_has_no_pk(self):
+ """
+ Given a new form
+ When an instance does not have a pk value set
+ Then `_initialize_dependent_fields` is not called
+ """
+ instance = MagicMock(pk=None)
+
+ with patch.object(
+ MetricsDocumentationChildEntryAdminForm,
+ "_initialize_dependent_fields",
+ ) as mock_init_deps:
+ self._make_form(instance=instance)
+
+ mock_init_deps.assert_not_called()
+
+ def test_both_fields_updated_when_both_have_values(self):
+ """
+ Given a new form with a sub_theme and topic
+ When `_initialize_dependent_fields` is called
+ Then both sub_theme and topic choices are set
+ """
+ form = self._make_form_with_instance(sub_theme=3, topic=7)
+
+ form._initialize_dependent_fields()
+
+ assert form.fields["sub_theme"].widget.choices == [
+ ("", "Select theme first"),
+ (3, "Loading... (ID: 3)"),
+ ]
+ assert form.fields["topic"].widget.choices == [
+ ("", "Select sub-theme first"),
+ (7, "Loading... (ID: 7)"),
+ ]
+
+ def test_skips_field_when_value_is_none(self):
+ """
+ Given a new form with no sub_theme or topic
+ When `_initialize_dependent_fields` is called
+ Then widget choices are left untouched.
+ """
+ form = self._make_form_with_instance(sub_theme=None, topic=None)
+ original_choices = form.fields["sub_theme"].widget.choices
+
+ form._initialize_dependent_fields()
+
+ assert form.fields["sub_theme"].widget.choices == original_choices
+
+
class TestMetricsDocumentationChildEntry:
@pytest.mark.parametrize(
"expected_api_field",
@@ -32,10 +195,10 @@ class TestMetricsDocumentationChildEntry:
"page_classification",
],
)
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
def test_has_correct_api_fields(
self,
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
expected_api_field: str,
):
"""
@@ -63,10 +226,10 @@ def test_has_correct_api_fields(
"body",
],
)
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
def test_has_the_correct_content_panels(
self,
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
expected_content_panel_name: str,
):
"""
@@ -89,154 +252,207 @@ def test_has_the_correct_content_panels(
fake_metrics_documentation_child_entry_page, expected_content_panel_name
)
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
- @mock.patch.object(child.MetricsDocumentationChildEntry, "find_topic")
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- def test_get_topic_delegates_calls_correctly(
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ @pytest.mark.parametrize(
+ "metric_id, metric_group",
+ [
+ (1, "cases"),
+ (2, "headline"),
+ (3, "vaccinations"),
+ (4, "deaths"),
+ ],
+ )
+ def test_metric_group_returns_expected_string(
self,
- mock_get_all_unique_metric_names: mock.MagicMock,
- spy_find_topic: mock.MagicMock,
- spy_get_a_list_of_all_topic_names: mock.MagicMock,
+ get_all_metric_names_and_ids: mock.MagicMock,
+ metric_id: int,
+ metric_group: str,
):
"""
Given a blank `MetricsDocumentationChildEntryPage` model.
- When `get_topic()` is called.
- Then the `get_a_list_of_all_topic_names()` method and `find_topic()`
- methods are called.
+ When a metric id is supplied to the `metric` property.
+ Then the metric_group will be correctly extracted from the string.
"""
# Given
- fake_topics = ["COVID-19", "Influenza"]
+ get_all_metric_names_and_ids.return_value = [
+ (1, "COVID-19_cases_rateRollingMean"),
+ (2, "COVID-19_headline_vaccines_autumn23Total"),
+ (3, "COVID-19_vaccinations_autumn22_uptakeByDay"),
+ (4, "COVID-19_deaths_ONSByWeek"),
+ ]
fake_metrics_documentation_child_entry_page = (
FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
)
- fake_metrics_documentation_child_entry_page.metric = (
- "COVID-19_cases_rateRollingMean"
- )
# When
- spy_get_a_list_of_all_topic_names.return_value = fake_topics
- fake_metrics_documentation_child_entry_page.get_topic()
+ fake_metrics_documentation_child_entry_page.metric = metric_id
# Then
- spy_get_a_list_of_all_topic_names.assert_called_once()
- spy_find_topic.assert_called_once()
+ assert fake_metrics_documentation_child_entry_page.metric_group == metric_group
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
@pytest.mark.parametrize(
- "metric_name, metric_group",
- [
- ("COVID-19_cases_rateRollingMean", "cases"),
- ("COVID-19_headline_vaccines_autumn23Total", "headline"),
- ("COVID-19_vaccinations_autumn22_uptakeByDay", "vaccinations"),
- ("COVID-19_deaths_ONSByWeek", "deaths"),
- ],
+ "metric_id",
+ [1, 2, 3, 4, 5],
)
- def test_metric_group_returns_expected_string(
- self,
- mock_get_all_unique_metric_names: mock.MagicMock,
- mock_get_all_topic_names: mock.MagicMock,
- metric_name: str,
- metric_group: str,
+ def test_metric_group_returns_emptry_string_with_missing_values(
+ self, get_all_metric_names_and_ids: mock.MagicMock, metric_id: int
):
"""
Given a blank `MetricsDocumentationChildEntryPage` model.
- When a metric name is supplied to the `metric` property.
- Then the metric_group will be correctly extracted from the string.
+ When a metric id is supplied to the `metric` property with invalid choices returned.
+ Then the metric_group will return an empty string.
"""
# Given
+ get_all_metric_names_and_ids.return_value = [
+ (1, "COVID-19casesrateRollingMean"),
+ (2, "COVID-19_"),
+ (3, ""),
+ (4, None),
+ ]
fake_metrics_documentation_child_entry_page = (
FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
)
# When
- fake_metrics_documentation_child_entry_page.metric = metric_name
+ fake_metrics_documentation_child_entry_page.metric = metric_id
# Then
- assert fake_metrics_documentation_child_entry_page.metric_group == metric_group
+ assert fake_metrics_documentation_child_entry_page.metric_group == ""
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
- @pytest.mark.parametrize(
- "selected_metric, extracted_topic",
- [
- ("COVID-19_cases_rateRollingMean", "COVID-19"),
- ("influenza_headline_ICUHDUadmissionRatePercentChange", "Influenza"),
- ("hMPV_testing_positivityByWeek", "hMPV"),
- ("parainfluenza_headline_positivityLatest", "Parainfluenza"),
- ("rhinovirus_headline_positivityLatest", "Rhinovirus"),
- ("RSV_headline_admissionRateLatest", "RSV"),
- ("adenovirus_headline_positivityLatest", "Adenovirus"),
- ],
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ def test_metric_group_returns_emptry_string_with_empty_metrics(
+ self, get_all_metric_names_and_ids: mock.MagicMock
+ ):
+ """
+ Given a blank `MetricsDocumentationChildEntryPage` model.
+ When a metric id is supplied to the `metric` property with no choices returned.
+ Then the metric_group will return an empty string.
+ """
+ # Given
+ get_all_metric_names_and_ids.return_value = []
+ fake_metrics_documentation_child_entry_page = (
+ FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
+ )
+
+ # When
+ fake_metrics_documentation_child_entry_page.metric = 1
+
+ # Then
+ assert fake_metrics_documentation_child_entry_page.metric_group == ""
+
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
)
- def test_find_topic_returns_expected_topic_name(
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ def test_public_error_raised_if_invalid_classification(
self,
- spy_get_a_list_of_all_topic_names: mock.MagicMock(),
- mock_get_all_unique_metric_names: mock.MagicMock(),
- selected_metric: str,
- extracted_topic: str,
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
):
"""
- Given a blank `MetricsDocumentationChildEntryPage` model
- a list of topics and a metric name.
- When the `find_topic()` method is called.
- Then the expected topic name will be matched from the list
- using the metric name.
+ Given is_public is False (i.e the page is a non public page).
+ When no page classification is given.
+ Then a `ValidationError` is raised.
"""
# Given
fake_metrics_documentation_child_entry_page = (
FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
)
- fake_topics = [
- "COVID-19",
- "Influenza",
- "RSV",
- "hMPV",
- "Parainfluenza",
- "Rhinovirus",
- "Adenovirus",
- ]
- fake_metrics_documentation_child_entry_page.metric = selected_metric
- # When
- return_topic = fake_metrics_documentation_child_entry_page.find_topic(
- topics=fake_topics
+ fake_metrics_documentation_child_entry_page.is_public = False
+ fake_metrics_documentation_child_entry_page.page_classification = None
+ fake_metrics_documentation_child_entry_page.theme = "test"
+ fake_metrics_documentation_child_entry_page.sub_theme = "test"
+ fake_metrics_documentation_child_entry_page.topic = "test"
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_metrics_documentation_child_entry_page.clean()
+
+ assert "Please select a classification level for this non-public page" in str(
+ e.value
)
- # Then
- assert return_topic == extracted_topic
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
+ )
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ def test_public_error_raised_if_invalid_theme(
+ self,
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
+ ):
+ """
+ Given is_public is False (i.e the page is a non public page).
+ When no theme is given.
+ Then a `ValidationError` is raised.
+ """
+ # Given
+ fake_metrics_documentation_child_entry_page = (
+ FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
+ )
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
- def test_find_topic_raises_error(
+ fake_metrics_documentation_child_entry_page.is_public = False
+ fake_metrics_documentation_child_entry_page.page_classification = "test"
+ fake_metrics_documentation_child_entry_page.theme = None
+ fake_metrics_documentation_child_entry_page.sub_theme = "test"
+ fake_metrics_documentation_child_entry_page.topic = "test"
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_metrics_documentation_child_entry_page.clean()
+
+ assert "Please select a theme for this non-public page" in str(e.value)
+
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
+ )
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ def test_public_error_raised_if_invalid_sub_theme(
self,
- mock_get_a_list_of_all_topic_names: mock.MagicMock(),
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
):
"""
- Given a metric name that does not include a valid topic.
- When the `find_topic()` method is called with a list of topics.
- Then an `InvalidTopicForChosenMetricForChildEntryError` is raised.
+ Given is_public is False (i.e the page is a non public page).
+ When no sub theme is given.
+ Then a `ValidationError` is raised.
"""
# Given
- fake_invalid_metric = "invalid_metric_contains_no_topic"
- fake_topics = [
- "COVID-19",
- "Influenza",
- "RSV",
- "hMPV",
- "Parainfluenza",
- "Rhinovirus",
- "Adenovirus",
- ]
fake_metrics_documentation_child_entry_page = (
FakeMetricsDocumentationChildEntryFactory.build_page_from_template()
)
- fake_metrics_documentation_child_entry_page.metric = fake_invalid_metric
- # When / Then
- with pytest.raises(InvalidTopicForChosenMetricForChildEntryError):
- fake_metrics_documentation_child_entry_page.find_topic(topics=fake_topics)
+ fake_metrics_documentation_child_entry_page.is_public = False
+ fake_metrics_documentation_child_entry_page.page_classification = "test"
+ fake_metrics_documentation_child_entry_page.theme = "None"
+ fake_metrics_documentation_child_entry_page.sub_theme = None
+ fake_metrics_documentation_child_entry_page.topic = "test"
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_metrics_documentation_child_entry_page.clean()
+
+ assert "Please select a subtheme for this non-public page" in str(e.value)
@mock.patch(
"cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
@@ -246,18 +462,16 @@ def test_find_topic_raises_error(
"cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
return_value=None,
)
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
- def test_public_error_raised_if_invalid_classification(
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
+ def test_public_error_raised_if_invalid_topic(
self,
- mock_get_a_list_of_all_topic_names: mock.MagicMock(),
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
mock_slug_raise_error,
mock_seo_title_raise_error,
):
"""
Given is_public is False (i.e the page is a non public page).
- When no page classification is given.
+ When no topic is given.
Then a `ValidationError` is raised.
"""
# Given
@@ -266,12 +480,17 @@ def test_public_error_raised_if_invalid_classification(
)
fake_metrics_documentation_child_entry_page.is_public = False
- fake_metrics_documentation_child_entry_page.page_classification = None
+ fake_metrics_documentation_child_entry_page.page_classification = "test"
+ fake_metrics_documentation_child_entry_page.theme = "test"
+ fake_metrics_documentation_child_entry_page.sub_theme = "test"
+ fake_metrics_documentation_child_entry_page.topic = None
# When/Then
- with pytest.raises(ValidationError):
+ with pytest.raises(ValidationError) as e:
fake_metrics_documentation_child_entry_page.clean()
+ assert "Please select a topic for this non-public page" in str(e.value)
+
@mock.patch(
"cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
return_value=None,
@@ -280,12 +499,10 @@ def test_public_error_raised_if_invalid_classification(
"cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
return_value=None,
)
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
def test_public_page_clears_page_classification(
self,
- mock_get_a_list_of_all_topic_names: mock.MagicMock(),
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
mock_slug_raise_error,
mock_seo_title_raise_error,
):
@@ -316,12 +533,10 @@ def test_public_page_clears_page_classification(
"cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
return_value=None,
)
- @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names")
- @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names")
+ @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids")
def test_non_public_page_doesnt_clean_page_classification(
self,
- mock_get_a_list_of_all_topic_names: mock.MagicMock(),
- mock_get_all_unique_metric_names: mock.MagicMock(),
+ mock_get_all_metric_names_and_ids: mock.MagicMock(),
mock_slug_raise_error,
mock_seo_title_raise_error,
):
@@ -337,6 +552,9 @@ def test_non_public_page_doesnt_clean_page_classification(
fake_metrics_documentation_child_entry_page.is_public = False
fake_metrics_documentation_child_entry_page.page_classification = "official"
+ fake_metrics_documentation_child_entry_page.theme = "infectious_disease"
+ fake_metrics_documentation_child_entry_page.sub_theme = "respiratory"
+ fake_metrics_documentation_child_entry_page.topic = "COVID-19"
# When
fake_metrics_documentation_child_entry_page.clean()
diff --git a/tests/unit/cms/topic/test_models.py b/tests/unit/cms/topic/test_models.py
index 9b0c1f156..d6458a5ae 100644
--- a/tests/unit/cms/topic/test_models.py
+++ b/tests/unit/cms/topic/test_models.py
@@ -5,7 +5,7 @@
from django.core.exceptions import ValidationError
-from cms.topic.models import TopicPage
+from cms.topic.models import TopicPage, TopicPageAdminForm
from metrics.domain.charts.colour_scheme import RGBAChartLineColours
from metrics.domain.charts.common_charts.plots.line_multi_coloured.properties import (
@@ -16,6 +16,108 @@
from wagtail.search.index import SearchField
+class TestTopicPageAdminForm:
+ MOCK_THEME_FIELDS = [
+ {"field_name": "theme", "label": "Theme", "required": True},
+ {"field_name": "sub_theme", "label": "Sub Theme", "required": False},
+ ]
+
+ def _make_form(self, instance=None):
+ """
+ Instantiate TopicPageAdminForm with all Wagtail
+ internals patched.
+ """
+ with (
+ mock.patch(
+ "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None
+ ),
+ mock.patch("cms.topic.models.THEME_FIELDS", self.MOCK_THEME_FIELDS),
+ mock.patch(
+ "cms.topic.models._create_form_field",
+ side_effect=lambda field: mock.MagicMock(name=field["field_name"]),
+ ),
+ ):
+ form = TopicPageAdminForm.__new__(TopicPageAdminForm)
+ form.fields = {}
+ form.instance = instance or mock.MagicMock(pk=None)
+ form.__init__()
+ return form
+
+ def test_theme_fields_are_added_on_init(self):
+ """
+ When a new form is instantieated
+ Then a form field is added to `fields` for each entry in `THEME_FIELDS`.
+ """
+ form = self._make_form()
+
+ assert len(form.fields) == 2
+ assert "theme" in form.fields
+ assert "sub_theme" in form.fields
+
+ def test_dependent_fields_initialised_for_saved_instance(self):
+ """
+ Given a new form
+ When an instance has a pk value set
+ Then `_initialize_dependent_fields` is called
+ """
+ with mock.patch.object(
+ TopicPageAdminForm, "_initialize_dependent_fields"
+ ) as init_fields_mock:
+ self._make_form(instance=mock.MagicMock(pk=1))
+ init_fields_mock.assert_called_once()
+
+ def test_dependent_fields_not_initialised_for_new_instance(self):
+ """
+ Given a new form
+ When an instance does not have a pk value set
+ Then `_initialize_dependent_fields` is not called
+ """
+ with mock.patch.object(
+ TopicPageAdminForm, "_initialize_dependent_fields"
+ ) as init_fields_mock:
+ self._make_form(instance=mock.MagicMock(pk=None))
+ init_fields_mock.assert_not_called()
+
+ def test_widget_choices_set_when_sub_theme_has_value(self):
+ """
+ Given a new form with a sub_theme
+ When `_initialize_dependent_fields` is called
+ Then the sub_theme choices are set
+ """
+ instance = mock.MagicMock(pk=1, sub_theme=5, topic=None)
+ form = self._make_form(instance=instance)
+ mock_widget = mock.MagicMock()
+ form.fields["sub_theme"] = mock.MagicMock(widget=mock_widget)
+ form.fields["topic"] = mock.MagicMock(widget=mock.MagicMock())
+
+ form._initialize_dependent_fields()
+
+ assert mock_widget.choices == [
+ ("", "Select theme first"),
+ (5, "Loading... (ID: 5)"),
+ ]
+
+ def test_widget_choices_not_set_when_value_is_none(self):
+ """
+ Given a new form with no sub_theme or topic
+ When `_initialize_dependent_fields` is called
+ Then widget choices are left untouched.
+ """
+ instance = mock.MagicMock(pk=1, sub_theme=None, topic=None)
+ form = self._make_form(instance=instance)
+ mock_widget = mock.MagicMock(choices=[])
+ form.fields["sub_theme"] = mock.MagicMock(widget=mock_widget)
+ form.fields["topic"] = mock.MagicMock(widget=mock.MagicMock(choices=[]))
+
+ form._initialize_dependent_fields()
+
+ assert mock_widget.choices == []
+
+ def test_get_field_choices_returns_correct_structure(self):
+ result = TopicPageAdminForm._get_field_choices(42, "Select theme first")
+ assert result == [("", "Select theme first"), (42, "Loading... (ID: 42)")]
+
+
class TestTopicPage:
@pytest.mark.parametrize(
"expected_search_field",
@@ -714,11 +816,117 @@ def test_public_error_raised_if_invalid_classification(
fake_covid_topic_page.is_public = False
fake_covid_topic_page.page_classification = None
+ fake_covid_topic_page.theme = "test"
+ fake_covid_topic_page.sub_theme = "test"
+ fake_covid_topic_page.topic = "test"
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_covid_topic_page.clean()
+
+ assert "Please select a classification level for this non-public page" in str(
+ e.value
+ )
+
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
+ )
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ def test_public_error_raised_if_invalid_theme(
+ self,
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
+ ):
+ """
+ Given is_public is False (i.e the page is a non public page).
+ When no page theme is given.
+ Then a `ValidationError` is raised.
+ """
+ # Given
+ fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template()
+
+ fake_covid_topic_page.is_public = False
+ fake_covid_topic_page.page_classification = "test"
+ fake_covid_topic_page.theme = None
+ fake_covid_topic_page.sub_theme = "test"
+ fake_covid_topic_page.topic = "test"
# When/Then
- with pytest.raises(ValidationError):
+ with pytest.raises(ValidationError) as e:
fake_covid_topic_page.clean()
+ assert "Please select a theme for this non-public page" in str(e.value)
+
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
+ )
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ def test_public_error_raised_if_invalid_sub_theme(
+ self,
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
+ ):
+ """
+ Given is_public is False (i.e the page is a non public page).
+ When no page sub theme is given.
+ Then a `ValidationError` is raised.
+ """
+ # Given
+ fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template()
+
+ fake_covid_topic_page.is_public = False
+ fake_covid_topic_page.page_classification = "test"
+ fake_covid_topic_page.theme = "test"
+ fake_covid_topic_page.sub_theme = None
+ fake_covid_topic_page.topic = "test"
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_covid_topic_page.clean()
+
+ assert "Please select a sub theme for this non-public page" in str(e.value)
+
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
+ return_value=None,
+ )
+ @mock.patch(
+ "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique",
+ return_value=None,
+ )
+ def test_public_error_raised_if_invalid_topic(
+ self,
+ mock_slug_raise_error,
+ mock_seo_title_raise_error,
+ ):
+ """
+ Given is_public is False (i.e the page is a non public page).
+ When no page topic is given.
+ Then a `ValidationError` is raised.
+ """
+ # Given
+ fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template()
+
+ fake_covid_topic_page.is_public = False
+ fake_covid_topic_page.page_classification = "test"
+ fake_covid_topic_page.theme = "test"
+ fake_covid_topic_page.sub_theme = "test"
+ fake_covid_topic_page.topic = None
+
+ # When/Then
+ with pytest.raises(ValidationError) as e:
+ fake_covid_topic_page.clean()
+
+ assert "Please select a topic for this non-public page" in str(e.value)
+
@mock.patch(
"cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided",
return_value=None,
@@ -772,6 +980,9 @@ def test_non_public_page_doesnt_clean_page_classification(
fake_covid_topic_page.is_public = False
fake_covid_topic_page.page_classification = "official"
+ fake_covid_topic_page.theme = "infectious_disease"
+ fake_covid_topic_page.sub_theme = "respiratory"
+ fake_covid_topic_page.topic = "COVID-19"
# When
fake_covid_topic_page.clean()
diff --git a/tests/unit/metrics/api/serializers/test_geographies.py b/tests/unit/metrics/api/serializers/test_geographies.py
index fa9d1583d..108a7567e 100644
--- a/tests/unit/metrics/api/serializers/test_geographies.py
+++ b/tests/unit/metrics/api/serializers/test_geographies.py
@@ -4,7 +4,7 @@
from rest_framework.exceptions import ValidationError
-from auth_content.constants import WILDCARD_ID_VALUE
+from cms.auth_content.constants import WILDCARD_ID_VALUE
from metrics.data.models.core_models.supporting import Geography
from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE
from metrics.api.serializers.geographies import (
diff --git a/tests/unit/metrics/api/serializers/test_permission_sets.py b/tests/unit/metrics/api/serializers/test_permission_sets.py
index 62ebd2045..7d6e95627 100644
--- a/tests/unit/metrics/api/serializers/test_permission_sets.py
+++ b/tests/unit/metrics/api/serializers/test_permission_sets.py
@@ -3,7 +3,7 @@
import pytest
from rest_framework import serializers as drf_serializers
-from auth_content.constants import WILDCARD_ID_VALUE
+from cms.auth_content.constants import WILDCARD_ID_VALUE
from metrics.api.serializers.permission_sets import (
MetricRequestSerializer,
PermissionSetResponseSerializer,