diff --git a/auth_content/models/__init__.py b/auth_content/models/__init__.py deleted file mode 100644 index 4d59541cb..000000000 --- a/auth_content/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from auth_content.models import permission_sets, users diff --git a/auth_content/__init__.py b/cms/auth_content/__init__.py similarity index 100% rename from auth_content/__init__.py rename to cms/auth_content/__init__.py diff --git a/cms/auth_content/auth_utils.py b/cms/auth_content/auth_utils.py new file mode 100644 index 000000000..cb54ffd23 --- /dev/null +++ b/cms/auth_content/auth_utils.py @@ -0,0 +1,298 @@ +import logging +from collections.abc import Callable + +from django import forms + +from cms.auth_content.constants import WILDCARD_ID_VALUE +from cms.dynamic_content import help_texts +from metrics.data.models.core_models.supporting import ( + Geography, + GeographyType, + Metric, + Topic, +) + +logger = logging.getLogger(__name__) + + +def check_permissions_by_name( + permission_sets, + theme_name, + sub_theme_name, + topic_name, + metric_name, + geography_type, + geography_name, +) -> bool: + """ + This is a wrapper that converts permission resource names + into ids. It is only used to check CHART permissions. + """ + + logger.info("Entered check_permissions_by_name()") + + theme_id, sub_theme_id, topic_id = Topic.objects.get_id_by_name( + theme_name, sub_theme_name, topic_name + ) + metric_id = Metric.objects.get_id_by_name(metric_name) + geography_type_id = GeographyType.objects.get_id_by_name(geography_type) + geography_id = Geography.objects.get_id_by_name(geography_name) + + # Be safe, just in case a NAME doesn't have an ID + if any( + value == -2 + for value in ( + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type_id, + geography_id, + ) + ): + return False + + return check_permission_set( + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type_id, + geography_id, + ) + + +def check_permission_set( + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, +) -> bool: + """ + This is a wrapper that only checks for global permissions, and + delegates further checks to our core permission checking function. + It is only used to check CHART permissions. + + @param {dict} permission_sets which contains a permission_sets list, eg: + { + "permission_sets": [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + "metric": {"id": "-1", "name": "* (All)"}, + "geography_type": {"id": "300", "name": "Nation"}, + "geography": {"id": "-1", "name": "* (All)"}, + } + ], + "summary": {"has_global_access": False}, + } + """ + + logger.info("Entered check_permission_set()") + + if not isinstance(permission_sets, dict): + return False + if not isinstance(permission_sets.get("permission_sets"), list): + return False + if not isinstance(permission_sets.get("summary"), dict): + return False + if not isinstance(permission_sets.get("summary").get("has_global_access"), bool): + return False + + if permission_sets.get("summary").get("has_global_access"): + return True + + return check_permissions( + permission_sets.get("permission_sets"), + theme_id, + sub_theme_id, + topic_id, + metric_id, + geography_type, + geography_id, + ) + + +def check_permissions( + permission_sets, + theme_id, + sub_theme_id, + topic_id, + metric_id=None, + geography_type=None, + geography_id=None, +) -> bool: + """ + This is our core permission-checking function It is + used to check both PAGE & CHART permissions. + + Metric- and geography-related permissions must be + evaluated separately (spec says). + + @param {list} permission_sets, eg: + [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + "metric": {"id": "-1", "name": "* (All)"}, + "geography_type": {"id": "300", "name": "Nation"}, + "geography": {"id": "-1", "name": "* (All)"}, + } + ] + """ + + logger.info("Entered check_permissions()") + + if not isinstance(permission_sets, list): + return False + + for permission_set in permission_sets: + if geography_type and geography_id: + # CHART permissions + if check_metric_related_permissions( + permission_set, theme_id, sub_theme_id, topic_id, metric_id + ) and check_geography_permissions( + permission_set, geography_type, geography_id + ): + return True + else: + # PAGE permissions + if check_metric_related_permissions( + permission_set, theme_id, sub_theme_id, topic_id, metric_id + ): + return True + + return False + + +def check_metric_related_permissions( + permission_set, + theme_id, + sub_theme_id, + topic_id, + metric_id=None, +) -> bool: + """ + Make sure that every theme, sub_theme, topic and metric + match or have a wildcard at the end (only look at the + first 4 attributes of permission_set). + + @param {dict} permission_set, eg: + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + "metric": {"id": "-1", "name": "* (All)"}, + "geography_type": {"id": "300", "name": "Nation"}, + "geography": {"id": "-1", "name": "* (All)"}, + } + """ + + logger.info("Entered check_metric_related_permissions()") + + if not isinstance(permission_set, dict): + return False + + theme_id = str(theme_id) + sub_theme_id = str(sub_theme_id) + topic_id = str(topic_id) + metric_id = str(metric_id) + + permission_theme_id = str(permission_set.get("theme", {}).get("id")) + permission_sub_theme_id = str(permission_set.get("sub_theme", {}).get("id")) + permission_topic_id = str(permission_set.get("topic", {}).get("id")) + permission_metric_id = str(permission_set.get("metric", {}).get("id")) + + if permission_theme_id == WILDCARD_ID_VALUE: + return True + + if permission_theme_id == theme_id and permission_sub_theme_id == WILDCARD_ID_VALUE: + return True + + if ( + permission_theme_id == theme_id + and permission_sub_theme_id == sub_theme_id + and permission_topic_id in {WILDCARD_ID_VALUE, topic_id} + ): + return True + + if ( + permission_theme_id == theme_id + and permission_sub_theme_id == sub_theme_id + and permission_topic_id == topic_id + and permission_metric_id in {WILDCARD_ID_VALUE, metric_id} + ): + return True + + return False + + +def check_geography_permissions( + permission_set, + geography_type=None, + geography_id=None, +) -> bool: + """ + Make sure that both geography_type and geography + match or have a wildcard at the end (only look at the + first 2 attributes of permission_set). + + @param {dict} permission_set, eg: + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + "metric": {"id": "-1", "name": "* (All)"}, + "geography_type": {"id": "300", "name": "Nation"}, + "geography": {"id": "-1", "name": "* (All)"}, + } + """ + + logger.info("Entered check_geography_permissions()") + + if not isinstance(permission_set, dict): + return False + + geography_type = str(geography_type) + geography_id = str(geography_id) + + permission_geography_type = str(permission_set.get("geography_type", {}).get("id")) + permission_geography_id = str(permission_set.get("geography", {}).get("id")) + + if permission_geography_type == WILDCARD_ID_VALUE: + return True + + if permission_geography_type == geography_type and permission_geography_id in { + WILDCARD_ID_VALUE, + geography_id, + }: + return True + + return False + + +def _create_form_field( + field: dict[str, str | Callable | None], wildcard_id_value=None +) -> forms.CharField: + choices = [ + ("", field["field_choice_default"]), + ] + + if field["field_choice_wildcard"]: + choices += [(wildcard_id_value, field["field_choice_wildcard"])] + + if field["field_choice_callable"]: + choices += field["field_choice_callable"]() + + return forms.CharField( + required=False, + label=field["field_label"], + widget=forms.Select(choices=choices), + help_text=help_texts.NON_PUBLIC_PAGE_REQUIRED, + ) diff --git a/auth_content/constants.py b/cms/auth_content/constants.py similarity index 100% rename from auth_content/constants.py rename to cms/auth_content/constants.py diff --git a/auth_content/migrations/0001_initial.py b/cms/auth_content/migrations/0001_initial.py similarity index 100% rename from auth_content/migrations/0001_initial.py rename to cms/auth_content/migrations/0001_initial.py diff --git a/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py b/cms/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py similarity index 100% rename from auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py rename to cms/auth_content/migrations/0002_alter_permissionset_geography_type_and_more.py diff --git a/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py b/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py new file mode 100644 index 000000000..1b66d2218 --- /dev/null +++ b/cms/auth_content/migrations/0003_permissionset_display_name_and_more.py @@ -0,0 +1,31 @@ +# Generated by Django 5.2.13 on 2026-05-22 08:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("auth_content", "0002_alter_permissionset_geography_type_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="permissionset", + name="display_name", + field=models.CharField( + blank=True, + help_text="\nThis is an (optional) user readable name for the permission set. If not set, a default autogenerated name will be used.\n", + max_length=255, + null=True, + ), + ), + migrations.AddConstraint( + model_name="permissionset", + constraint=models.UniqueConstraint( + condition=models.Q(("display_name__isnull", False)), + fields=("display_name",), + name="unique_non_null_display_name", + ), + ), + ] diff --git a/auth_content/migrations/__init__.py b/cms/auth_content/migrations/__init__.py similarity index 100% rename from auth_content/migrations/__init__.py rename to cms/auth_content/migrations/__init__.py diff --git a/cms/auth_content/models/__init__.py b/cms/auth_content/models/__init__.py new file mode 100644 index 000000000..93449c3ce --- /dev/null +++ b/cms/auth_content/models/__init__.py @@ -0,0 +1,2 @@ +from cms.auth_content.models import users +from cms.auth_content.models import permission_sets diff --git a/auth_content/models/permission_sets.py b/cms/auth_content/models/permission_sets.py similarity index 87% rename from auth_content/models/permission_sets.py rename to cms/auth_content/models/permission_sets.py index b3f70937e..ba5e2ca8b 100644 --- a/auth_content/models/permission_sets.py +++ b/cms/auth_content/models/permission_sets.py @@ -1,13 +1,13 @@ -from collections.abc import Callable from itertools import starmap -from django import forms from django.core.exceptions import ValidationError from django.db import models -from wagtail.admin.forms import WagtailAdminModelForm +from wagtail.admin.forms import WagtailAdminPageForm from wagtail.admin.panels import FieldPanel, mark_safe -from auth_content.constants import PERMISSION_SET_FIELDS, WILDCARD_ID_VALUE +from cms.auth_content.auth_utils import _create_form_field +from cms.auth_content.constants import PERMISSION_SET_FIELDS, WILDCARD_ID_VALUE +from cms.dynamic_content import help_texts from cms.metrics_interface.field_choices_callables import ( get_all_geography_names_and_codes, get_all_geography_type_names_and_ids, @@ -18,41 +18,14 @@ ) -def get_theme_child_map(): - """Returns an object of all parent to child mappings - e.g. - { - infectious_disease: [vaccine_preventable, respiratory ....], - extreme_event: [weather_alert, mortality_report...] - ... - } - - """ - return {} - - -def _create_form_field(field: dict[str, str | Callable | None]) -> forms.CharField: - choices = [ - ("", field["field_choice_default"]), - ] - - if field["field_choice_wildcard"]: - choices += [(WILDCARD_ID_VALUE, field["field_choice_wildcard"])] - - if field["field_choice_callable"]: - choices += field["field_choice_callable"]() - - return forms.CharField( - required=True, label=field["field_label"], widget=forms.Select(choices=choices) - ) - - -class PermissionSetForm(WagtailAdminModelForm): +class PermissionSetForm(WagtailAdminPageForm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for field in PERMISSION_SET_FIELDS: - self.fields[field["field_name"]] = _create_form_field(field) + self.fields[field["field_name"]] = _create_form_field( + field, WILDCARD_ID_VALUE + ) if self.instance and self.instance.pk: self._initialize_dependent_fields() @@ -110,6 +83,9 @@ def clean(self): return cleaned_data + class Media: + js = ["js/permission_set.js"] + class PermissionSet(models.Model): name = models.CharField( @@ -118,6 +94,12 @@ class PermissionSet(models.Model): editable=False, help_text="Auto-generated display name", ) + display_name = models.CharField( + max_length=255, + blank=True, + null=True, + help_text=help_texts.PERMISSION_SET_DISPLAY_NAME, + ) theme = models.CharField(max_length=255, blank=False, default="") sub_theme = models.CharField(max_length=255, blank=False, default="") topic = models.CharField(max_length=255, blank=False, default="") @@ -133,6 +115,7 @@ def permission_set_details(self): return mark_safe("
".join(parts)) panels = [ + FieldPanel("display_name"), FieldPanel("theme"), FieldPanel("sub_theme"), FieldPanel("topic"), @@ -153,7 +136,12 @@ class Meta: "geography", ], name="unique_permission_set", - ) + ), + models.UniqueConstraint( + fields=["display_name"], + condition=models.Q(display_name__isnull=False), + name="unique_non_null_display_name", + ), ] def save(self, *args, **kwargs): @@ -241,4 +229,4 @@ def _find_label_in_choices(choices: list[tuple], value: str) -> str: ) def __str__(self): - return self.name or f"Permission Set {self.id}" + return self.display_name or self.name or f"Permission Set {self.id}" diff --git a/auth_content/models/users.py b/cms/auth_content/models/users.py similarity index 100% rename from auth_content/models/users.py rename to cms/auth_content/models/users.py diff --git a/auth_content/static/js/permission_set.js b/cms/auth_content/static/js/permission_set.js similarity index 100% rename from auth_content/static/js/permission_set.js rename to cms/auth_content/static/js/permission_set.js diff --git a/auth_content/wagtail_hooks.py b/cms/auth_content/wagtail_hooks.py similarity index 77% rename from auth_content/wagtail_hooks.py rename to cms/auth_content/wagtail_hooks.py index 9a60e2fd2..77852a1d1 100644 --- a/auth_content/wagtail_hooks.py +++ b/cms/auth_content/wagtail_hooks.py @@ -1,5 +1,3 @@ -from django.templatetags.static import static -from django.utils.html import format_html from wagtail import hooks from wagtail.admin.viewsets.model import ( ModelPermissionPolicy, @@ -7,8 +5,8 @@ ModelViewSetGroup, ) -from auth_content.models.permission_sets import PermissionSet -from auth_content.models.users import User +from cms.auth_content.models.permission_sets import PermissionSet +from cms.auth_content.models.users import User class NoEditPermissionPolicy(ModelPermissionPolicy): @@ -48,8 +46,3 @@ class AuthGroup(ModelViewSetGroup): @hooks.register("register_admin_viewset") def register_auth_viewset(): return AuthGroup() - - -@hooks.register("insert_editor_js") -def permission_set_js(): - return format_html('', static("js/permission_set.js")) diff --git a/cms/dashboard/constants.py b/cms/dashboard/constants.py new file mode 100644 index 000000000..6eed766d7 --- /dev/null +++ b/cms/dashboard/constants.py @@ -0,0 +1,35 @@ +from cms.metrics_interface.field_choices_callables import ( + get_all_metric_names_and_ids, + get_all_theme_names_and_ids, +) + +THEME_FIELDS = [ + { + "field_name": "theme", + "field_label": "Theme", + "field_choice_default": "----------", + "field_choice_wildcard": None, + "field_choice_callable": get_all_theme_names_and_ids, + }, + { + "field_name": "sub_theme", + "field_label": "Sub Theme", + "field_choice_default": "Select theme first", + "field_choice_wildcard": None, + "field_choice_callable": None, + }, + { + "field_name": "topic", + "field_label": "Topic", + "field_choice_default": "Select sub-theme first", + "field_choice_wildcard": None, + "field_choice_callable": None, + }, + { + "field_name": "metric", + "field_label": "Metric", + "field_choice_default": "Select topic first", + "field_choice_wildcard": None, + "field_choice_callable": get_all_metric_names_and_ids, + }, +] diff --git a/cms/dashboard/static/js/classification_toggle.js b/cms/dashboard/static/js/classification_toggle.js deleted file mode 100644 index 909e4a3af..000000000 --- a/cms/dashboard/static/js/classification_toggle.js +++ /dev/null @@ -1,28 +0,0 @@ -;(function () { - function toggleClassification() { - /* - When the is_public box is checked, this will clear any selected page_classification, - and disable the field. If the is_public box is then unchecked, it will re-enable the field - */ - const isPublicCheckbox = document.querySelector('input[name="is_public"]') - const classificationField = document.querySelector( - 'select[name="page_classification"]', - ) - - if (!isPublicCheckbox || !classificationField) return - - if (isPublicCheckbox.checked) { - classificationField.value = "" - classificationField.disabled = true - } else { - classificationField.disabled = false - } - } - - document.addEventListener("DOMContentLoaded", toggleClassification) - document.addEventListener("change", function (e) { - if (e.target.name === "is_public") { - toggleClassification() - } - }) -})() diff --git a/cms/dashboard/static/js/toggle_available_fields_on_is_public.js b/cms/dashboard/static/js/toggle_available_fields_on_is_public.js new file mode 100644 index 000000000..8ad1f004e --- /dev/null +++ b/cms/dashboard/static/js/toggle_available_fields_on_is_public.js @@ -0,0 +1,276 @@ +;(function () { + "use strict" + let theme, subTheme, topic, metric, isPublicCheckbox; + let originalMetricOptions; + + function toggleAvailableFields() { + /* + When the is_public box is checked, this will clear any selected page_classification, + and disable the field. If the is_public box is then unchecked, it will re-enable the field + */ + + const fields = { + classification: document.querySelector( + 'select[name="page_classification"]', + ), + theme: theme, + subTheme: subTheme, + topic: topic, + // metric: metric, + } + + if (isPublicCheckbox.checked) { + Object.values(fields).forEach(disableField) + clearDropdown(fields.subTheme, "Select theme first") + clearDropdown(fields.topic, "Select sub-theme first") + // clearDropdown(fields.metric, "Select topic first") + restoreMetricOptions() + fields.theme.value = "" + } else { + if (!theme.value && !subTheme.value && !topic.value) { + clearDropdown(metric, "Select topic first") + } + Object.values(fields).forEach(enableField) + fields.classification.value="official_sensitive" + } + } + + function restoreMetricOptions() { + clearDropdown(metric, "----------") + originalMetricOptions.forEach(option => { + if (option.text !== "Select topic first") { + metric.appendChild(option.cloneNode(true)); + } + }); + } + + function disableField(field) { + field.disabled = true + } + + function enableField(field) { + field.disabled = false + } + + /** + * Generic function to fetch choices from the API + * @param {string} endpoint - The API endpoint (e.g., 'subthemes', 'topics') + * @param {string} dataItemId - The ID value to pass + * @returns {Promise} Array of choices [[id, name], ...] + */ + async function fetchChoices(endpoint, dataItemId) { + try { + const url = `/api/data-hierarchy/${endpoint}/${dataItemId}` + const response = await fetch(url) + + if (!response.ok) { + const errorData = await response.json() + console.error(`API error: ${errorData.error || "Unknown error"}`) + return [] + } + + const data = await response.json() + return data.choices || [] + } catch (error) { + console.error(`Error fetching ${endpoint}:`, error) + return [] + } + } + + /** + * Generic function to populate a dropdown with choices + * @param {HTMLSelectElement} dropdown - The select element to populate + * @param {Array} choices - Array of [id, name] tuples + */ + function populateDropdown(dropdown, choices, metrics = null) { + const currentValue = dropdown.value + dropdown.disabled = false + dropdown.innerHTML = "" + + //dropdown empty + const nullOption = document.createElement("option") + nullOption.value = "" + nullOption.textContent = "--------" + dropdown.appendChild(nullOption) + + choices.forEach(([id, name]) => { + const option = document.createElement("option") + option.value = id + option.textContent = name + dropdown.appendChild(option) + }) + + if (currentValue) { + dropdown.value = currentValue + } + } + + function clearDropdown(dropdown, message = "Select parent first") { + dropdown.innerHTML = "" + + const option = document.createElement("option") + option.value = "" + option.textContent = message + dropdown.appendChild(option) + + dropdown.value = "" + } + + /** + * Handle theme selection change + */ + async function handleThemeChange() { + const themeValue = theme.value + + // Clear all dependent dropdowns + if (!themeValue || themeValue === "") { + clearDropdown(subTheme, "Select theme first") + clearDropdown(topic, "Select sub-theme first") + clearDropdown(metric, "Select topic first"); + return + } + + clearDropdown(subTheme, "Select theme") + clearDropdown(topic, "Select sub-theme") + clearDropdown(metric, "Select topic first"); + + // Fetch and populate sub-themes + const choices = await fetchChoices("subthemes", themeValue) + + if (choices.length > 0) { + populateDropdown(subTheme, choices) + } else { + clearDropdown(subTheme, "No sub-themes available") + } + } + + /** + * Handle sub-theme selection change + */ + async function handleSubThemeChange() { + const subThemeValue = subTheme.value + + if (!subThemeValue || subThemeValue === "") { + // No sub-theme selected - clear children + clearDropdown(topic, "Select sub-theme first") + return + } + + // Clear dependent dropdowns + clearDropdown(topic, "Select sub-theme") + clearDropdown(metric, "Select topic first"); + + // Fetch and populate topics + const choices = await fetchChoices("topics", subThemeValue) + + if (choices.length > 0) { + populateDropdown(topic, choices) + } else { + clearDropdown(topic, "No topics available") + } + } + + /** + * Handle topic selection change + */ + async function handleTopicChange() { + const topicValue = topic.value; + + if (!topicValue || topicValue === "") { + // No topic selected - clear metrics + clearDropdown(metric, "Select topic first"); + return; + } + + clearDropdown(metric, "--------"); + + // Fetch and populate metrics + const choices = await fetchChoices("metrics", topicValue); + + if (choices.length > 0) { + populateDropdown(metric, choices, "* All metrics"); + } else { + clearDropdown(metric, "No metrics available"); + } + } + + /** + * Initialize dropdowns for edit mode + * Loads the dropdown options based on saved values + */ + async function initializeEditMode() { + // Store original values before we start manipulating dropdowns + const savedTheme = theme.value + const savedSubTheme = subTheme.value + const savedTopic = topic.value + const savedMetric = metric ? metric.value : undefined + + // If theme has a value (not empty), load sub-themes + if (savedTheme && savedTheme !== "") { + const subThemeChoices = await fetchChoices("subthemes", savedTheme) + if (subThemeChoices.length > 0) { + populateDropdown(subTheme, subThemeChoices) + subTheme.value = savedSubTheme // Restore selection + } + + // If sub-theme has a value, load topics + if (savedSubTheme && savedSubTheme !== "") { + const topicChoices = await fetchChoices("topics", savedSubTheme) + if (topicChoices.length > 0) { + populateDropdown(topic, topicChoices) + topic.value = savedTopic // Restore selection + } + + if (savedTopic && savedTopic !== "") { + const metricChoices = await fetchChoices("metrics", savedTopic) + if (metricChoices.length > 0) { + populateDropdown(metric, metricChoices) + metric.value = savedMetric // Restore selection + } + } + } + } + } + + function initialize() { + // Get dropdown elements + + isPublicCheckbox = document.querySelector('input[name="is_public"]') + theme = document.querySelector('select[name="theme"]') + subTheme = document.querySelector('select[name="sub_theme"]') + topic = document.querySelector('select[name="topic"]') + metric = document.querySelector('select[name="metric"]') + // Take a copy of all available metrics so they can be restored if this becomes a public page + originalMetricOptions = Array.from(metric.options).map(option => option.cloneNode(true)); + + // Exit if not on page with themes and is_public toggle + if (!theme || !subTheme || !topic || !isPublicCheckbox) { + console.error("No theme dropdowns found on this page") + return + } + + toggleAvailableFields() + + // Add event listeners + theme.addEventListener("change", handleThemeChange) + subTheme.addEventListener("change", handleSubThemeChange) + topic.addEventListener("change", handleTopicChange); + + const isEditMode = theme.value || subTheme.value || topic.value || metric.value + + if (isEditMode) { + initializeEditMode() + } else { + clearDropdown(subTheme, "Select theme first") + clearDropdown(topic, "Select sub-theme first") + clearDropdown(metric, "Select topic first"); + } + } + + document.addEventListener("DOMContentLoaded", initialize) + document.addEventListener("change", function (e) { + if (e.target.name === "is_public") { + toggleAvailableFields() + } + }) +})() diff --git a/cms/dashboard/viewsets.py b/cms/dashboard/viewsets.py index 949c50d18..89b2c1ddc 100644 --- a/cms/dashboard/viewsets.py +++ b/cms/dashboard/viewsets.py @@ -1,3 +1,6 @@ +from itertools import chain + +from django.db.models import Exists, OuterRef, Q from django.urls import path from django.urls.resolvers import RoutePattern from drf_spectacular.utils import extend_schema @@ -6,11 +9,15 @@ from wagtail.api.v2.views import PagesAPIViewSet from caching.private_api.decorators import cache_response +from cms.auth_content.auth_utils import check_permissions from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer +from cms.metrics_documentation.models.child import MetricsDocumentationChildEntry +from cms.topic.models import TopicPage @extend_schema(tags=["cms"]) class CMSPagesAPIViewSet(PagesAPIViewSet): + # This is the /pages (or proxy/pages env dependent endpoint) permission_classes = [] base_serializer_class = ListablePageSerializer listing_default_fields = PagesAPIViewSet.listing_default_fields + ["show_in_menus"] @@ -39,7 +46,69 @@ def get_queryset(self): """ queryset = super().get_queryset() - return queryset.specific() + + req = self.request + + if req.auth is None: + filtered_queryset = queryset.annotate( + is_public_topic_page=Exists( + TopicPage.objects.filter( + page_ptr_id=OuterRef("pk"), + is_public=True, + ) + ), + is_public_metrics_doc_child_page=Exists( + MetricsDocumentationChildEntry.objects.filter( + page_ptr_id=OuterRef("pk"), + is_public=True, + ) + ), + ).filter( + Q(is_public_topic_page=True) + | Q(is_public_metrics_doc_child_page=True) + | ~Q( + content_type__model__in=[ + "topicpage", + "metricsdocumentationchildentry", + ] + ) + ) + + else: + has_global_access = req.user.permission_sets["summary"]["has_global_access"] + + if has_global_access: + filtered_queryset = queryset + + else: + user_permissions = req.user.permission_sets + pages_to_check = chain( + ((page.id, page.topicpage) for page in queryset.type(TopicPage)), + ( + (page.id, page.metricsdocumentationchildentry) + for page in queryset.type(MetricsDocumentationChildEntry) + ), + ) + allowed_page_ids = [ + page_id + for page_id, page in pages_to_check + if page.is_public + or check_permissions( + user_permissions, + page.theme, + page.sub_theme, + page.topic, + ) + ] + + public_pages = queryset.not_type( + TopicPage, MetricsDocumentationChildEntry + ) + permitted_private_pages = queryset.filter(id__in=allowed_page_ids) + + filtered_queryset = public_pages | permitted_private_pages + + return filtered_queryset.specific() @cache_response() def listing_view(self, request: Request) -> Response: diff --git a/cms/dynamic_content/help_texts.py b/cms/dynamic_content/help_texts.py index 389f8ac2c..6d959eefb 100644 --- a/cms/dynamic_content/help_texts.py +++ b/cms/dynamic_content/help_texts.py @@ -626,6 +626,10 @@ The classification level of all data on this page (only applies to non-public pages). Defaults to `Official-Sensitive`. """ +NON_PUBLIC_PAGE_REQUIRED: str = """ +This field is required for a non-public page. +""" + SECTION_FOOTER_BLOCKS: str = """ This is an optional footer for content sections to allow additional supporting information to be linked too. (E.g. a link to furhter information about how we define an outbreak) """ @@ -641,3 +645,6 @@ SECTION_FOOTER_LINK: str = """ This is a link component that allows the user to setup an internal or external link along with a short description of the link's content. """ +PERMISSION_SET_DISPLAY_NAME: str = """ +This is an (optional) user readable name for the permission set. If not set, a default autogenerated name will be used. +""" diff --git a/cms/metrics_documentation/data_migration/child_entries.py b/cms/metrics_documentation/data_migration/child_entries.py index 8aa39c00f..3d08bc33a 100644 --- a/cms/metrics_documentation/data_migration/child_entries.py +++ b/cms/metrics_documentation/data_migration/child_entries.py @@ -37,10 +37,14 @@ def build_entry_from_row_data(*, row: tuple[str, ...]) -> dict[str, str | list[d """ title: str = row[0] page_description: str = row[4] - metric: str = row[1] + metric = row[7] + topic = row[1].split("_")[0] sections: list[tuple[str, str]] = gather_sections_and_omit_if_needed(row=row) return { "title": title, + "topic": topic, + "theme": "test", + "sub_theme": "test", "seo_title": f"{title} | UKHSA data dashboard", "search_description": page_description, "page_description": page_description, diff --git a/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx b/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx index 3efd453c9..6d71013d1 100644 Binary files a/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx and b/cms/metrics_documentation/data_migration/source_data/metrics_definitions_migration_edit.xlsx differ diff --git a/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py b/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py new file mode 100644 index 000000000..7c663a527 --- /dev/null +++ b/cms/metrics_documentation/migrations/0016_metricsdocumentationchildentry_sub_theme_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 5.2.13 on 2026-04-29 10:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "metrics_documentation", + "0015_alter_metricsdocumentationchildentry_page_classification", + ), + ] + + operations = [ + migrations.AddField( + model_name="metricsdocumentationchildentry", + name="sub_theme", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + migrations.AddField( + model_name="metricsdocumentationchildentry", + name="theme", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + ] diff --git a/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py b/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py new file mode 100644 index 000000000..5d23bdbd8 --- /dev/null +++ b/cms/metrics_documentation/migrations/0017_alter_metricsdocumentationchildentry_topic.py @@ -0,0 +1,21 @@ +# Generated by Django 5.2.13 on 2026-05-05 09:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ( + "metrics_documentation", + "0016_metricsdocumentationchildentry_sub_theme_and_more", + ), + ] + + operations = [ + migrations.AlterField( + model_name="metricsdocumentationchildentry", + name="topic", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + ] diff --git a/cms/metrics_documentation/models/child.py b/cms/metrics_documentation/models/child.py index 9dada0e7d..618eff3fd 100644 --- a/cms/metrics_documentation/models/child.py +++ b/cms/metrics_documentation/models/child.py @@ -12,13 +12,14 @@ from wagtail.api import APIField from wagtail.search import index +from cms.auth_content.auth_utils import _create_form_field +from cms.dashboard.constants import THEME_FIELDS from cms.dashboard.models import DataClassificationLevels, UKHSAPage from cms.dynamic_content import help_texts from cms.dynamic_content.access import ALLOWABLE_BODY_CONTENT_TEXT_SECTION from cms.dynamic_content.announcements import Announcement from cms.metrics_interface.field_choices_callables import ( - get_a_list_of_all_topic_names, - get_all_unique_metric_names, + get_all_metric_names_and_ids, ) logger = logging.getLogger(__name__) @@ -31,8 +32,35 @@ def __init__(self, topic: str, metric: str): class MetricsDocumentationChildEntryAdminForm(WagtailAdminPageForm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for field in THEME_FIELDS: + self.fields[field["field_name"]] = _create_form_field(field) + + if self.instance and self.instance.pk: + self._initialize_dependent_fields() + + def _initialize_dependent_fields(self): + """Initialize choices for cascading dependent fields""" + dependent_fields = { + "sub_theme": ("Select theme first"), + "topic": ("Select sub-theme first"), + } + + for field_name, (placeholder) in dependent_fields.items(): + value = getattr(self.instance, field_name, None) + if value: + choices = self._get_field_choices(value, placeholder) + self.fields[field_name].widget.choices = choices + + @staticmethod + def _get_field_choices(value, placeholder): + """Generate choices list based on field value""" + return [("", placeholder), (value, f"Loading... (ID: {value})")] + class Media: - js = ["js/classification_toggle.js"] + js = ["js/toggle_available_fields_on_is_public.js"] class MetricsDocumentationChildEntry(UKHSAPage): @@ -51,24 +79,36 @@ class MetricsDocumentationChildEntry(UKHSAPage): null=True, blank=True, ) - topic = models.CharField( + + theme = models.CharField( max_length=255, + blank=True, default="", + null=True, ) + sub_theme = models.CharField( + max_length=255, + blank=True, + default="", + null=True, + ) + topic = models.CharField(max_length=255, blank=True, default="", null=True) body = ALLOWABLE_BODY_CONTENT_TEXT_SECTION # Fields to index for searching within the CMS application. search_fields = UKHSAPage.search_fields + [ - index.SearchField("metric"), index.SearchField("body"), ] # Content panels to render for editing within the CMS application. content_panels = UKHSAPage.content_panels + [ FieldPanel("page_description"), - FieldPanel("metric"), FieldPanel("is_public"), FieldPanel("page_classification"), + FieldPanel("theme"), + FieldPanel("sub_theme"), + FieldPanel("topic"), + FieldPanel("metric"), FieldPanel("body"), ] @@ -113,42 +153,7 @@ def __init__(self, *args, **kwargs): load in the names dynamically from the metrics interface. """ super().__init__(*args, **kwargs) - self._meta.get_field("metric").choices = get_all_unique_metric_names() - - def find_topic(self, *, topics: list[str]) -> str: - """Finds the required topic from a list of strings based on the metric name. - - Args: - topics: list of strings representing topic names. - - Returns: - A string of the topic checked against the models metric value. - - Raises: - `InvalidTopicForChosenMetricForChildEntry`: If the - selected metric cannot be matched to a `Topic` - - """ - extracted_topic = self.metric.split("_")[0].lower() - try: - return next(topic for topic in topics if extracted_topic == topic.lower()) - except StopIteration as error: - logger.info( - "StopIteration Error: extracted topic not present in the topics list. %s", - extracted_topic, - ) - raise InvalidTopicForChosenMetricForChildEntryError( - topic=extracted_topic, metric=self.metric - ) from error - - def get_topic(self) -> str: - """Finds the required topic name based on the selected metric name. - - Returns: - a topic name as a string - """ - topics = get_a_list_of_all_topic_names() - return self.find_topic(topics=topics) + self._meta.get_field("metric").choices = get_all_metric_names_and_ids() def save(self, *args, **kwargs): """Retrieves a topic based on the selected metric @@ -156,12 +161,22 @@ def save(self, *args, **kwargs): Notes: This method will not be called when using `bulk_create()` """ - self.topic = self.get_topic() super().save(*args, **kwargs) @property def metric_group(self) -> str: - return self.metric.split("_")[1] + field = self._meta.get_field("metric") + choices = getattr(field, "choices", []) or [] + + display_name = next( + (item[1] for item in choices if item[0] == self.metric), None + ) + + if not display_name or "_" not in display_name: + return "" + + parts = display_name.split("_") + return parts[1] if len(parts) > 1 else "" def clean(self): super().clean() @@ -169,13 +184,29 @@ def clean(self): # If is_public is true, automatically clear classification if self.is_public: self.page_classification = None - # If not public page, classification must be chosen + self.theme = None + self.sub_theme = None + self.topic = None + + # If not public page, non-public fields must be set elif not self.page_classification: raise ValidationError( { "page_classification": "Please select a classification level for this non-public page" } ) + elif not self.theme: + raise ValidationError( + {"theme": "Please select a theme for this non-public page"} + ) + elif not self.sub_theme: + raise ValidationError( + {"sub_theme": "Please select a subtheme for this non-public page"} + ) + elif not self.topic: + raise ValidationError( + {"topic": "Please select a topic for this non-public page"} + ) class MetricsDocumentationChildPageAnnouncement(Announcement): diff --git a/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py b/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py new file mode 100644 index 000000000..ff9b0ef09 --- /dev/null +++ b/cms/topic/migrations/0032_topicpage_sub_theme_topicpage_theme_topicpage_topic.py @@ -0,0 +1,46 @@ +# Generated by Django 5.2.13 on 2026-04-29 10:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("topic", "0031_alter_topicpage_page_classification"), + ] + + operations = [ + migrations.AddField( + model_name="topicpage", + name="sub_theme", + field=models.CharField( + blank=True, + default="", + help_text="\nThe subtheme must be provided for a non-public page.\n", + max_length=255, + null=True, + ), + ), + migrations.AddField( + model_name="topicpage", + name="theme", + field=models.CharField( + blank=True, + default="", + help_text="\nThe theme must be provided for a non-public page.\n", + max_length=255, + null=True, + ), + ), + migrations.AddField( + model_name="topicpage", + name="topic", + field=models.CharField( + blank=True, + default="", + help_text="\nThe topic must be provided for a non-public page.\n", + max_length=255, + null=True, + ), + ), + ] diff --git a/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py b/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py new file mode 100644 index 000000000..bf95c455e --- /dev/null +++ b/cms/topic/migrations/0033_alter_topicpage_sub_theme_alter_topicpage_theme_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.13 on 2026-04-29 10:54 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("topic", "0032_topicpage_sub_theme_topicpage_theme_topicpage_topic"), + ] + + operations = [ + migrations.AlterField( + model_name="topicpage", + name="sub_theme", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + migrations.AlterField( + model_name="topicpage", + name="theme", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + migrations.AlterField( + model_name="topicpage", + name="topic", + field=models.CharField(blank=True, default="", max_length=255, null=True), + ), + ] diff --git a/cms/topic/models.py b/cms/topic/models.py index 9e687cc38..f39ecd3d9 100644 --- a/cms/topic/models.py +++ b/cms/topic/models.py @@ -14,6 +14,8 @@ from wagtail.fields import RichTextField from wagtail.search import index +from cms.auth_content.auth_utils import _create_form_field +from cms.dashboard.constants import THEME_FIELDS from cms.dashboard.enums import ( DEFAULT_RELATED_LINKS_LAYOUT_FIELD_LENGTH, RelatedLinksLayoutEnum, @@ -36,8 +38,35 @@ class TopicPageAdminForm(WagtailAdminPageForm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for field in THEME_FIELDS: + self.fields[field["field_name"]] = _create_form_field(field) + + if self.instance and self.instance.pk: + self._initialize_dependent_fields() + + def _initialize_dependent_fields(self): + """Initialize choices for cascading dependent fields""" + dependent_fields = { + "sub_theme": ("Select theme first"), + "topic": ("Select sub-theme first"), + } + + for field_name, (placeholder) in dependent_fields.items(): + value = getattr(self.instance, field_name, None) + if value: + choices = self._get_field_choices(value, placeholder) + self.fields[field_name].widget.choices = choices + + @staticmethod + def _get_field_choices(value, placeholder): + """Generate choices list based on field value""" + return [("", placeholder), (value, f"Loading... (ID: {value})")] + class Media: - js = ["js/classification_toggle.js"] + js = ["js/toggle_available_fields_on_is_public.js"] class TopicPage(UKHSAPage): @@ -66,6 +95,10 @@ class TopicPage(UKHSAPage): null=True, ) + theme = models.CharField(max_length=255, blank=True, default="", null=True) + sub_theme = models.CharField(max_length=255, blank=True, default="", null=True) + topic = models.CharField(max_length=255, blank=True, default="", null=True) + related_links_layout = models.CharField( verbose_name="Layout", help_text=help_texts.RELATED_LINKS_LAYOUT_FIELD, @@ -87,6 +120,9 @@ class TopicPage(UKHSAPage): FieldPanel("enable_area_selector"), FieldPanel("is_public"), FieldPanel("page_classification"), + FieldPanel("theme"), + FieldPanel("sub_theme"), + FieldPanel("topic"), FieldPanel("page_description"), FieldPanel("body"), ] @@ -230,16 +266,32 @@ def last_updated_at(self) -> datetime.datetime: def clean(self): super().clean() - # If is_public is true, automatically clear classification + # If is_public is true, automatically clear non-public fields if self.is_public: self.page_classification = None - # If not public page, classification must be chosen + self.theme = None + self.sub_theme = None + self.topic = None + + # If not public page, non-public fields must be set elif not self.page_classification: raise ValidationError( { "page_classification": "Please select a classification level for this non-public page" } ) + elif not self.theme: + raise ValidationError( + {"theme": "Please select a theme for this non-public page"} + ) + elif not self.sub_theme: + raise ValidationError( + {"sub_theme": "Please select a sub theme for this non-public page"} + ) + elif not self.topic: + raise ValidationError( + {"topic": "Please select a topic for this non-public page"} + ) class TopicPageRelatedLink(UKHSAPageRelatedLink): diff --git a/common/auth/cognito_jwt/user_manager.py b/common/auth/cognito_jwt/user_manager.py index 3a69470a4..a101168e0 100644 --- a/common/auth/cognito_jwt/user_manager.py +++ b/common/auth/cognito_jwt/user_manager.py @@ -17,6 +17,23 @@ def get_or_create_for_cognito(jwt_payload): try: username = jwt_payload["entraObjectId"] permission_sets = jwt_payload["permissionSets"] + + # DEBUGGING: Manual testing (just for now) + # username = "{YOUR_ENTRA_OBJECT_ID}" + # permission_sets = { + # "permission_sets": [ + # { + # "theme": {"id": "100", "name": "immunisation"}, + # "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + # "topic": {"id": "-1", "name": "* (All)"}, + # "metric": {"id": "-1", "name": "* (All)"}, + # "geography_type": {"id": "300", "name": "Nation"}, + # "geography": {"id": "400", "name": "England"}, + # } + # ], + # "summary": {"has_global_access": False}, + # } + if not permission_sets: logger.debug( "Empty permissionSets in token for user: '%s'", diff --git a/tests/unit/auth_content/__init__.py b/metrics/api/middleware/__init__.py similarity index 100% rename from tests/unit/auth_content/__init__.py rename to metrics/api/middleware/__init__.py diff --git a/metrics/api/middleware/sql_debug.py b/metrics/api/middleware/sql_debug.py new file mode 100644 index 000000000..dc00d0762 --- /dev/null +++ b/metrics/api/middleware/sql_debug.py @@ -0,0 +1,24 @@ +from django.db import connection + + +def _print_sql(execute, sql, params, many, context): + print(f"\n[SQL] {sql}") + if params: + print(f"[PARAMS] {params}") + return execute(sql, params, many, context) + + +class SQLDebugMiddleware: + """ + Middleware that prints the raw SQL and params for every DB query made + during a request/response cycle. + + Only intended for local development — add to MIDDLEWARE in local.py. + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + with connection.execute_wrapper(_print_sql): + return self.get_response(request) diff --git a/metrics/api/serializers/geographies.py b/metrics/api/serializers/geographies.py index f19bdf6a3..d2c07d8d3 100644 --- a/metrics/api/serializers/geographies.py +++ b/metrics/api/serializers/geographies.py @@ -3,12 +3,12 @@ from django.db.models import QuerySet from rest_framework import serializers -from auth_content.constants import WILDCARD_ID_VALUE from metrics.api.serializers import help_texts from metrics.data.in_memory_models.geography_relationships.handlers import ( get_upstream_relationships_for_geography, ) from metrics.data.managers.core_models.time_series import CoreTimeSeriesQuerySet +from metrics.data.models.constants import PERMISSION_SET_WILDCARD_ID_VALUE from metrics.data.models.core_models import ( CoreTimeSeries, Geography, @@ -234,7 +234,7 @@ def geography_manager(self): @staticmethod def validate_geography_type_id(value: str) -> str | int: """Validate geography_type_id is either wildcard or a valid integer""" - if value == WILDCARD_ID_VALUE: + if value == PERMISSION_SET_WILDCARD_ID_VALUE: return value try: @@ -253,8 +253,10 @@ def data(self) -> dict[str, list[list[str, str]]]: geography_type_id = self.validated_data["geography_type_id"] # Handle wildcard - if geography_type_id == WILDCARD_ID_VALUE: - return {"choices": [[WILDCARD_ID_VALUE, "* (All geographies)"]]} + if geography_type_id == PERMISSION_SET_WILDCARD_ID_VALUE: + return { + "choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All geographies)"]] + } parent_geography_type_id = int(geography_type_id) geographies = ( diff --git a/metrics/api/serializers/permission_sets.py b/metrics/api/serializers/permission_sets.py index 03c7e4413..f75cb6923 100644 --- a/metrics/api/serializers/permission_sets.py +++ b/metrics/api/serializers/permission_sets.py @@ -1,13 +1,13 @@ from django.db.models import QuerySet from rest_framework import serializers -from auth_content.constants import WILDCARD_ID_VALUE +from metrics.data.models.constants import PERMISSION_SET_WILDCARD_ID_VALUE from metrics.data.models.core_models.supporting import Metric, SubTheme, Topic def _validate_input_id(value, field_name): """Validate theme_id is either wildcard or a valid integer""" - if value == WILDCARD_ID_VALUE: + if value == PERMISSION_SET_WILDCARD_ID_VALUE: return value try: @@ -46,8 +46,10 @@ def data(self) -> dict: """ theme_id = self.validated_data["theme_id"] - if theme_id == WILDCARD_ID_VALUE: - return {"choices": [[WILDCARD_ID_VALUE, "* (All sub-themes)"]]} + if theme_id == PERMISSION_SET_WILDCARD_ID_VALUE: + return { + "choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All sub-themes)"]] + } parent_theme_id = int(theme_id) sub_theme_tuples = _queryset_to_id_name_tuples( @@ -87,8 +89,8 @@ def data(self) -> dict: """ sub_theme_id = self.validated_data["sub_theme_id"] - if sub_theme_id == WILDCARD_ID_VALUE: - return {"choices": [[WILDCARD_ID_VALUE, "* (All topics)"]]} + if sub_theme_id == PERMISSION_SET_WILDCARD_ID_VALUE: + return {"choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All topics)"]]} parent_sub_theme_id = int(sub_theme_id) topic_tuples = _queryset_to_id_name_tuples( @@ -129,8 +131,8 @@ def data(self) -> dict: """ topic_id = self.validated_data["topic_id"] - if topic_id == WILDCARD_ID_VALUE: - return {"choices": [[WILDCARD_ID_VALUE, "* (All metrics)"]]} + if topic_id == PERMISSION_SET_WILDCARD_ID_VALUE: + return {"choices": [[PERMISSION_SET_WILDCARD_ID_VALUE, "* (All metrics)"]]} parent_topic_id = int(topic_id) metric_tuples = _queryset_to_id_name_tuples( diff --git a/metrics/api/serializers/user.py b/metrics/api/serializers/user.py index 66263373a..24f0dbcff 100644 --- a/metrics/api/serializers/user.py +++ b/metrics/api/serializers/user.py @@ -3,7 +3,7 @@ from django.db.models import QuerySet from rest_framework import serializers -from auth_content.models.users import User +from cms.auth_content.models.users import User from metrics.utils.permission_grouping import group_by_geography_type, group_by_theme from metrics.utils.permission_hierarchy import ( build_permission_hierarchy, diff --git a/metrics/api/settings/default.py b/metrics/api/settings/default.py index 78af6c97d..32dbdd2b3 100644 --- a/metrics/api/settings/default.py +++ b/metrics/api/settings/default.py @@ -52,6 +52,7 @@ "metrics.api", "cms.acknowledgement", "cms.home", + "cms.auth_content", "cms.topic", "cms.topics_list", "cms.dashboard", @@ -78,7 +79,6 @@ "wagtail_trash", "modelcluster", "taggit", - "auth_content", ] MIDDLEWARE = [ diff --git a/metrics/api/settings/local.py b/metrics/api/settings/local.py index 8d0cf9636..2a3f10dfa 100644 --- a/metrics/api/settings/local.py +++ b/metrics/api/settings/local.py @@ -29,6 +29,7 @@ MIDDLEWARE += [ "debug_toolbar.middleware.DebugToolbarMiddleware", + "metrics.api.middleware.sql_debug.SQLDebugMiddleware", ] INTERNAL_IPS = ["127.0.0.1"] diff --git a/metrics/api/views/permission_sets.py b/metrics/api/views/permission_sets.py index 7a76eaf8f..91f94712e 100644 --- a/metrics/api/views/permission_sets.py +++ b/metrics/api/views/permission_sets.py @@ -42,7 +42,7 @@ class TopicsBySubThemeView(APIView): permission_classes = [] def get(self, request, sub_theme_id, *args, **kwargs): # noqa: PLR6301 - """API endpoint to fetch sub-themes based on selected theme.""" + """API endpoint to fetch topics based on selected sub-theme.""" serializer = TopicRequestSerializer(data={"sub_theme_id": sub_theme_id}) serializer.is_valid(raise_exception=True) return Response(serializer.data()) @@ -59,7 +59,7 @@ class MetricsByTopicView(APIView): permission_classes = [] def get(self, request, topic_id, *args, **kwargs): # noqa: PLR6301 - """API endpoint to fetch sub-themes based on selected theme.""" + """API endpoint to fetch metrics based on selected topic.""" serializer = MetricRequestSerializer(data={"topic_id": topic_id}) serializer.is_valid(raise_exception=True) return Response(serializer.data()) diff --git a/metrics/data/managers/core_models/geography.py b/metrics/data/managers/core_models/geography.py index 5ff2ce8cb..046721983 100644 --- a/metrics/data/managers/core_models/geography.py +++ b/metrics/data/managers/core_models/geography.py @@ -47,6 +47,19 @@ def get_name_by_code(self, geography_code: str) -> str | None: .first() ) + def get_id_by_name(self, geography_name: str) -> int: + """ + Gets the geography ID for a given geography name. + + Args: + geography_name: The name of the geography to look up + + Returns: + The geography ID if found, or -2 otherwise + """ + record = self.filter(name=geography_name).first() + return int(record.id) if record else -2 + def get_all_geography_codes_by_geography_type( self, geography_type_name: str ) -> Self: @@ -167,6 +180,23 @@ def get_name_by_code(self, geography_code: int) -> str | None: """ return self.get_queryset().get_name_by_code(geography_code) + def get_id_by_name(self, geography_name: str) -> int: + """Gets the geography ID which matches the given geography name. + + Args: + geography_name: The name of the geography to look up + + Returns: + The geography ID if found, -2 otherwise + + Examples: + >>> GeographyManager.get_id_by_name("England") + 6 + >>> GeographyManager.get_id_by_name("Unknown geography") + -2 + """ + return self.get_queryset().get_id_by_name(geography_name) + def get_all_names(self) -> GeographyQuerySet: """Gets all available deduplicated geography names as a flat list queryset. diff --git a/metrics/data/managers/core_models/geography_type.py b/metrics/data/managers/core_models/geography_type.py index ae45b36b7..df85ccd5b 100644 --- a/metrics/data/managers/core_models/geography_type.py +++ b/metrics/data/managers/core_models/geography_type.py @@ -41,6 +41,19 @@ def get_name_by_id(self, geography_type_id: int) -> str | None: """ return self.filter(id=geography_type_id).values_list("name", flat=True).first() + def get_id_by_name(self, geography_type_name: str) -> int: + """ + Gets the geography type ID for a given geography type name. + + Args: + geography_type_name: The name of the geography type to look up + + Returns: + The geography type ID if found, or -2 otherwise + """ + record = self.filter(name=geography_type_name).first() + return int(record.id) if record else -2 + def get_all_names_and_ids(self) -> models.QuerySet: """Gets all available geography_type names as a flat list queryset. @@ -77,6 +90,23 @@ def get_name_by_id(self, geography_type_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(geography_type_id) + def get_id_by_name(self, geography_type_name: str) -> int: + """Gets the geography type ID which matches the given geography type name. + + Args: + geography_type_name: The name of the geography type to look up + + Returns: + The geography type ID if found, -2 otherwise + + Examples: + >>> GeographyTypeManager.get_id_by_name("Nation") + 5 + >>> GeographyTypeManager.get_id_by_name("Unknown type") + -2 + """ + return self.get_queryset().get_id_by_name(geography_type_name) + def get_all_names(self) -> GeographyTypeQuerySet: """Gets all available geography_type names as a flat list queryset. diff --git a/metrics/data/managers/core_models/metric.py b/metrics/data/managers/core_models/metric.py index e9fcb961d..66a3ec60c 100644 --- a/metrics/data/managers/core_models/metric.py +++ b/metrics/data/managers/core_models/metric.py @@ -29,6 +29,19 @@ def get_name_by_id(self, metric_id: int) -> str | None: """ return self.filter(id=metric_id).values_list("name", flat=True).first() + def get_id_by_name(self, metric_name: str) -> int: + """ + Gets the metric ID for a given metric name. + + Args: + metric_name: The name of the metric to look up + + Returns: + The metric ID if found, or -2 otherwise + """ + record = self.filter(name=metric_name).first() + return int(record.id) if record else -2 + def get_all_names(self) -> models.QuerySet: """Gets all available metric names as a flat list queryset. @@ -146,6 +159,23 @@ def get_name_by_id(self, metric_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(metric_id) + def get_id_by_name(self, metric_name: str) -> int: + """Gets the metric ID which matches the given metric name. + + Args: + metric_name: The name of the metric to look up + + Returns: + The metric ID if found, -2 otherwise + + Examples: + >>> MetricManager.get_id_by_name("COVID-19_cases_countRollingMean") + 4 + >>> MetricManager.get_id_by_name("Unknown metric") + -2 + """ + return self.get_queryset().get_id_by_name(metric_name) + def get_all_names(self) -> MetricQuerySet: """Gets all available metric names as a flat list queryset. diff --git a/metrics/data/managers/core_models/time_series.py b/metrics/data/managers/core_models/time_series.py index aed8df7f6..5b63aaca2 100644 --- a/metrics/data/managers/core_models/time_series.py +++ b/metrics/data/managers/core_models/time_series.py @@ -6,6 +6,7 @@ """ import datetime +import logging from collections.abc import Iterable from typing import Self @@ -15,12 +16,13 @@ from metrics.api.permissions.fluent_permissions import ( is_public_data_only_enforced, - validate_permissions_for_non_public, ) from metrics.data.models import RBACPermission ALLOWABLE_METRIC_VALUE_RANGE_TYPE = tuple[str | float | int, str | float | int] +logger = logging.getLogger(__name__) + class CoreTimeSeriesQuerySet(models.QuerySet): """Custom queryset which can be used by the `CoreTimeSeriesManager`""" @@ -171,8 +173,10 @@ def query_for_data( stratum: str | None = None, sex: str | None = None, age: str | None = None, + theme: str, + sub_theme: str, metric_value_ranges: list[tuple[str | float | int]] | None = None, - restrict_to_public: bool = True, + permission_sets: dict, ) -> models.QuerySet: """Filters for a N-item list of dicts by the given params if `fields_to_export` is used. @@ -212,14 +216,18 @@ def query_for_data( Note that options are `M`, `F`, or `ALL`. age: The age range to apply additional filtering to. E.g. `0_4` would be used to capture the age of 0-4 years old + theme: The name of the theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. + sub_theme: The name of the sub theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. metric_value_ranges: List of tuples whereby each tuple represents a permissible metric value range. i.e. to filter for all record with values between 0 -> 80 AND 90 -> 100, this can be provided as `[(0, 80), (90, 100)]`. - restrict_to_public: Boolean switch to restrict the query - to only return public records. - If False, then non-public records will be included. + permission_sets: The JWT permissions extracted from the Cognito token. Returns: QuerySet: An ordered queryset from lowest -> highest @@ -231,6 +239,9 @@ def query_for_data( ]>` """ + + logger.info("Entered query_for_data()") + queryset = self.filter( metric__topic__name=topic, metric__name=metric, @@ -245,9 +256,30 @@ def query_for_data( sex=sex, age=age, ) + public_queryset = queryset.filter(is_public=True) - if restrict_to_public: - queryset = queryset.filter(is_public=True) + if permission_sets: + logger.info("Entered if permission_sets clause") + + # WORKAROUND: Cos circular import error when at the top of the file + from cms.auth_content.auth_utils import check_permissions_by_name + + if check_permissions_by_name( + permission_sets, + theme, + sub_theme, + topic, + metric, + geography_type, + geography, + ): + logger.info("Entered check_permissions_by_name() if clause") + + queryset = public_queryset + queryset.filter(is_public=False) + else: + logger.info("Entered else permission_sets clause") + + queryset = public_queryset queryset = self._exclude_data_under_embargo(queryset=queryset) queryset = self._filter_for_metric_value_ranges( @@ -533,6 +565,7 @@ def query_for_data( sub_theme: str = "", metric_value_ranges: list[str | float | int] | None = None, rbac_permissions: Iterable[RBACPermission] | None = None, + permission_sets: dict, ) -> CoreTimeSeriesQuerySet: """Filters for a 2-item object by the given params. Slices all values older than the `date_from`. @@ -582,6 +615,7 @@ def query_for_data( rbac_permissions: The RBAC permissions available to the given request. This dictates whether the given request is permitted access to non-public data or not. + permission_sets: The JWT permissions extracted from the Cognito token. Notes: If we have the following input `queryset`: @@ -611,20 +645,12 @@ def query_for_data( ]>` """ - rbac_permissions: Iterable[RBACPermission] = rbac_permissions or [] - has_access_to_non_public_data: bool = validate_permissions_for_non_public( - theme=theme, - sub_theme=sub_theme, - topic=topic, - metric=metric, - geography_type=geography_type, - geography=geography, - rbac_permissions=rbac_permissions, - ) return self.get_queryset().query_for_data( fields_to_export=fields_to_export, field_to_order_by=field_to_order_by, + theme=theme, + sub_theme=sub_theme, topic=topic, metric=metric, date_from=date_from, @@ -635,7 +661,7 @@ def query_for_data( sex=sex, age=age, metric_value_ranges=metric_value_ranges, - restrict_to_public=not has_access_to_non_public_data, + permission_sets=permission_sets, ) def query_for_superseded_data( diff --git a/metrics/data/managers/core_models/topic.py b/metrics/data/managers/core_models/topic.py index 00b6f0632..cd67e9f1a 100644 --- a/metrics/data/managers/core_models/topic.py +++ b/metrics/data/managers/core_models/topic.py @@ -40,6 +40,40 @@ def get_name_by_id(self, topic_id: int) -> str | None: """ return self.filter(id=topic_id).values_list("name", flat=True).first() + def get_id_by_name( + self, theme_name: str, sub_theme_name: str, topic_name: str + ) -> tuple[int, int, int]: + """ + Gets the theme, sub-theme and topic IDs matching the given names. + + Args: + theme_name: The name of the parent theme + sub_theme_name: The name of the parent sub-theme + topic_name: The name of the topic to look up + + Returns: + A tuple of (theme_id, sub_theme_id, topic_id) if found, + or the tuple (-2, -2, -2) otherwise + """ + record = self.filter( + sub_theme__theme__name=theme_name, + sub_theme__name=sub_theme_name, + name=topic_name, + ).first() + + if record: + return ( + int(record.sub_theme.theme_id), + int(record.sub_theme_id), + int(record.id), + ) + + return ( + -2, + -2, + -2, + ) + def get_all_unique_names(self) -> models.QuerySet: """Gets all available unique topic names as a flat list queryset. @@ -113,6 +147,30 @@ def get_name_by_id(self, topic_id: int) -> str | None: """ return self.get_queryset().get_name_by_id(topic_id) + def get_id_by_name( + self, theme_name: str, sub_theme_name: str, topic_name: str + ) -> tuple[int, int, int]: + """Gets the theme, sub-theme and topic IDs matching the given names. + + Args: + theme_name: The name of the parent theme + sub_theme_name: The name of the parent sub-theme + topic_name: The name of the topic to look up + + Returns: + A tuple of (theme_id, sub_theme_id, topic_id) if found, + or (-2, -2, -2) if not found. + + Examples: + >>> TopicManager.get_id_by_name("Infectious disease", "Respiratory", "COVID-19") + (1, 2, 3) + >>> TopicManager.get_id_by_name("Unknown", "Unknown", "Unknown") + (-2, -2, -2) + """ + return self.get_queryset().get_id_by_name( + theme_name, sub_theme_name, topic_name + ) + def get_all_names(self) -> TopicQuerySet: """Gets all available topic names as a flat list queryset. diff --git a/metrics/data/managers/rbac_models/user.py b/metrics/data/managers/rbac_models/user.py index f4a3194d1..b2967685f 100644 --- a/metrics/data/managers/rbac_models/user.py +++ b/metrics/data/managers/rbac_models/user.py @@ -9,7 +9,7 @@ from django.db import models -from auth_content.models.permission_sets import PermissionSet +from cms.auth_content.models.permission_sets import PermissionSet class UserQuerySet(models.QuerySet): diff --git a/metrics/data/models/constants.py b/metrics/data/models/constants.py index d5fcfba29..2388368e7 100644 --- a/metrics/data/models/constants.py +++ b/metrics/data/models/constants.py @@ -5,3 +5,4 @@ METRIC_FREQUENCY_MAX_CHAR_CONSTRAINT: int = 1 METRIC_VALUE_MAX_DIGITS: int = 11 METRIC_VALUE_DECIMAL_PLACES: int = 4 +PERMISSION_SET_WILDCARD_ID_VALUE: str = "-1" diff --git a/metrics/domain/models/charts/common.py b/metrics/domain/models/charts/common.py index 317301450..7608749d3 100644 --- a/metrics/domain/models/charts/common.py +++ b/metrics/domain/models/charts/common.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Iterable from decimal import Decimal from typing import Literal @@ -5,6 +6,8 @@ from pydantic.main import BaseModel from rest_framework.request import Request +logger = logging.getLogger(__name__) + class BaseChartRequestParams(BaseModel): file_format: Literal["png", "svg", "jpg", "jpeg", "json", "csv"] @@ -24,6 +27,16 @@ class BaseChartRequestParams(BaseModel): class Config: arbitrary_types_allowed = True + @property + def permission_sets(self) -> dict: + """Extract JWT permissions from the authenticated request""" + + logger.info("Entered BaseChartRequestParams.permission_sets") + + return getattr(self.request.user, "permission_sets", {}) + @property def rbac_permissions(self) -> Iterable["RBACPermission"]: + """TODO: RBAC-based permissions are legacy and will be removed in a future release""" + return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/interfaces/plots/access.py b/metrics/interfaces/plots/access.py index 6e4eac34d..e8bcb0406 100644 --- a/metrics/interfaces/plots/access.py +++ b/metrics/interfaces/plots/access.py @@ -161,8 +161,12 @@ def get_queryset_from_core_model_manager( plot_params["fields_to_export"].append("upper_confidence") plot_params["fields_to_export"].append("lower_confidence") + logger.info("Entered access.py") + return self.core_model_manager.query_for_data( - **plot_params, rbac_permissions=self.chart_request_params.rbac_permissions + **plot_params, + rbac_permissions=self.chart_request_params.rbac_permissions, # old permissions (remove) + permission_sets=self.chart_request_params.permission_sets, # new permissions ) def build_plot_data_from_parameters_with_complete_queryset( diff --git a/metrics/utils/permission_hierarchy.py b/metrics/utils/permission_hierarchy.py index 53c9861a3..c5a182204 100644 --- a/metrics/utils/permission_hierarchy.py +++ b/metrics/utils/permission_hierarchy.py @@ -10,7 +10,7 @@ from django.db.models import QuerySet -from auth_content.models.permission_sets import PermissionSet +from cms.auth_content.models.permission_sets import PermissionSet from metrics.data.models.core_models.supporting import ( Geography, GeographyType, diff --git a/pyproject.toml b/pyproject.toml index 170a4e864..2ce7e729a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -225,6 +225,10 @@ ignore_imports = [ "metrics.api.urls_construction -> cms.snippets.views", "metrics.api.urls_construction -> feedback.api.urls", "cms.dynamic_content.elements -> validation.data_transfer_models.base", + "cms.auth_content.models.users -> metrics.data.managers.rbac_models.user", # Allow auth_content to be moved into CMS, consider refactor + "metrics.api.serializers.user -> cms.auth_content.models.users", # Allow auth_content to be moved into CMS, consider refactor + "metrics.data.managers.rbac_models.user -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor + "metrics.utils.permission_hierarchy -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor ] [[tool.importlinter.contracts]] @@ -237,7 +241,14 @@ layers = [ ] ignore_imports = [ "metrics.data.managers.core_models.time_series -> metrics.api.permissions.fluent_permissions", - "metrics.data.managers.core_models.headline -> metrics.api.permissions.fluent_permissions" + "metrics.data.managers.core_models.headline -> metrics.api.permissions.fluent_permissions", + "metrics.data.managers.rbac_models.user -> cms.auth_content.models.permission_sets", # Allow auth_content to be moved into CMS, consider refactor + "cms.auth_content.models.permission_sets -> cms.metrics_interface.field_choices_callables", # Allow auth_content to be moved into CMS, consider refactor + "cms.metrics_interface.field_choices_callables -> cms.metrics_interface", # Allow auth_content to be moved into CMS, consider refactor + "cms.metrics_interface -> cms.metrics_interface.interface", # Allow auth_content to be moved into CMS, consider refactor + "cms.metrics_interface.interface -> metrics.domain.charts.colour_scheme", # Allow auth_content to be moved into CMS, consider refactor + "cms.metrics_interface.interface -> metrics.domain.charts.common_charts.plots.line_multi_coloured.properties", # Allow auth_content to be moved into CMS, consider refactor + "cms.metrics_interface.interface -> metrics.domain.common.utils", # Allow auth_content to be moved into CMS, consider refactor ] [[tool.importlinter.contracts]] diff --git a/tests/factories/auth_content/models/permission_sets.py b/tests/factories/auth_content/models/permission_sets.py index 7f73e093f..633765392 100644 --- a/tests/factories/auth_content/models/permission_sets.py +++ b/tests/factories/auth_content/models/permission_sets.py @@ -1,6 +1,6 @@ import factory -from auth_content.models.permission_sets import PermissionSet +from cms.auth_content.models.permission_sets import PermissionSet class PermissionSetFactory(factory.django.DjangoModelFactory): diff --git a/tests/factories/auth_content/models/users.py b/tests/factories/auth_content/models/users.py index df74effbd..ff76242f1 100644 --- a/tests/factories/auth_content/models/users.py +++ b/tests/factories/auth_content/models/users.py @@ -1,7 +1,7 @@ import factory -from auth_content.models.permission_sets import PermissionSet -from auth_content.models.users import User +from cms.auth_content.models.permission_sets import PermissionSet +from cms.auth_content.models.users import User class UserFactory(factory.django.DjangoModelFactory): diff --git a/tests/integration/cms/dashboard/__init__.py b/tests/integration/cms/dashboard/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/cms/dashboard/test_viewsets.py b/tests/integration/cms/dashboard/test_viewsets.py new file mode 100644 index 000000000..5a9ef4e6a --- /dev/null +++ b/tests/integration/cms/dashboard/test_viewsets.py @@ -0,0 +1,219 @@ +import pytest +from unittest.mock import MagicMock +from django.test import RequestFactory +from rest_framework.request import Request +from wagtail.models import Page + +from cms.common.models import CommonPage +from cms.dashboard.viewsets import CMSPagesAPIViewSet +from cms.metrics_documentation.models.child import MetricsDocumentationChildEntry +from cms.topic.models import TopicPage +from metrics.data.models.core_models import Metric, Topic + + +class MockPermissionSets(list): + def __init__(self, permissions, has_global_access=False): + super().__init__(permissions) + self._summary = {"has_global_access": has_global_access} + + def __getitem__(self, key): + if key == "permission_sets": + return list(self) + if key == "summary": + return self._summary + return super().__getitem__(key) + + +@pytest.mark.django_db +class TestCMSPagesAPIViewSetPermissions: + + @pytest.fixture + def setup_pages(self): + influenza_topic = Topic.objects.create(name="Influenza") + metric = Metric.objects.create( + name="influenza_headline_positivityLatest", topic=influenza_topic + ) + + covid_topic = Topic.objects.create(name="COVID-19") + private_metric = Metric.objects.create( + name="COVID-19_headline_cases_7DayTotals", topic=covid_topic + ) + private_metric_two = Metric.objects.create( + name="COVID-19_headline_7DayAdmissionsChange", topic=covid_topic + ) + + home = Page.objects.get(id=2) + + public_topic = TopicPage( + title="Public Topic", + page_description="test", + slug="public-topic", + is_public=True, + theme="1", + seo_title="public-topic", + ) + home.add_child(instance=public_topic) + + private_topic = TopicPage( + title="Private Topic", + page_description="test", + slug="private-topic", + is_public=False, + theme="1", + sub_theme="test", + topic="test", + page_classification="official_sensitive", + seo_title="private-topic", + ) + home.add_child(instance=private_topic) + + public_metrics = MetricsDocumentationChildEntry( + title="Public Metric", + page_description="test", + slug="public-metric", + metric=metric.pk, + is_public=True, + seo_title="public-metrics", + ) + home.add_child(instance=public_metrics) + + private_metrics = MetricsDocumentationChildEntry( + title="Private Metric", + page_description="test", + slug="private-metric", + theme="2", + sub_theme="test", + topic="test", + metric=private_metric.pk, + is_public=False, + seo_title="private-metrics", + ) + home.add_child(instance=private_metrics) + + private_metrics_two = MetricsDocumentationChildEntry( + title="Private Metric 2", + page_description="test", + slug="private-metric-two", + theme="1", + sub_theme="test", + topic="test", + metric=private_metric_two.pk, + is_public=False, + seo_title="private-metrics-two", + ) + home.add_child(instance=private_metrics_two) + + standard_page = CommonPage( + title="Standard", body="test", slug="standard", seo_title="standard-page" + ) + home.add_child(instance=standard_page) + + return { + "public_topic": public_topic, + "private_topic": private_topic, + "public_metrics": public_metrics, + "private_metrics": private_metrics, + "standard_page": standard_page, + } + + def test_anonymous_user_access(self, setup_pages): + """ + Given a request is made by an unauthenticated user + When the queryset is retrieved + Then only public pages are returned + """ + # Given + rf = RequestFactory() + url = "/api/v2/pages/" + django_request = rf.get(url) + + request = Request(django_request) + + mock_user = MagicMock() + request.user = mock_user + + view = CMSPagesAPIViewSet() + view.request = request + + # When + result = view.get_queryset() + + # Then + titles = [p.title for p in result] + assert "Public Topic" in titles + assert "Public Metric" in titles + assert "Standard" in titles + assert "Private Topic" not in titles + assert "Private Metric" not in titles + assert "Private Metric 2" not in titles + + def test_global_access_user(self, setup_pages): + """ + Given a request is made by an authenticated user with global access + When the queryset is retrieved + Then all pages are returned + """ + # Given + rf = RequestFactory() + url = "/api/v2/pages/" + django_request = rf.get(url) + + request = Request(django_request) + + mock_user = MagicMock() + mock_user.permission_sets = MockPermissionSets( + [], + has_global_access=True, + ) + + request.user = mock_user + request.auth = "token" + + view = CMSPagesAPIViewSet() + view.request = request + + # When + result = view.get_queryset() + + # Then + titles = [p.title for p in result] + assert "Public Topic" in titles + assert "Public Metric" in titles + assert "Standard" in titles + assert "Private Topic" in titles + assert "Private Metric" in titles + assert "Private Metric 2" in titles + + def test_restricted_user_with_permission(self, setup_pages): + """ + Given a request is made by an authenticated user with access to some private pages + When the queryset is retrieved + Then only the pages the user has access to are returned + """ + # Given + rf = RequestFactory() + url = "/api/v2/pages/" + django_request = rf.get(url) + + request = Request(django_request) + + mock_user = MagicMock() + mock_user.permission_sets = MockPermissionSets( + [{"theme": {"id": "1"}, "sub_theme": {"id": "-1"}}], + has_global_access=False, + ) + + request.user = mock_user + request.auth = "token" + + view = CMSPagesAPIViewSet() + view.request = request + + # When + result = view.get_queryset() + + # Then + titles = [p.title for p in result] + assert "Private Topic" in titles + assert "Private Metric 2" in titles + assert "Private Metric" not in titles diff --git a/tests/integration/cms/dynamic_content/test_page_link_chooser.py b/tests/integration/cms/dynamic_content/test_page_link_chooser.py index 3853a0067..12fb0681d 100644 --- a/tests/integration/cms/dynamic_content/test_page_link_chooser.py +++ b/tests/integration/cms/dynamic_content/test_page_link_chooser.py @@ -28,6 +28,9 @@ def test_page_chooser_returns_full_url( path="abc", depth=1, title="abc", + theme="test", + sub_theme="test", + topic=1, live=True, seo_title="ABC", ) diff --git a/tests/integration/cms/metrics_documentation/data_migration/test_operations.py b/tests/integration/cms/metrics_documentation/data_migration/test_operations.py index 4e678a351..b8f464414 100644 --- a/tests/integration/cms/metrics_documentation/data_migration/test_operations.py +++ b/tests/integration/cms/metrics_documentation/data_migration/test_operations.py @@ -67,9 +67,12 @@ def test_removes_all_child_entries(self): path="abc", depth=1, title="Test", + theme="test", + sub_theme="test", slug="test", page_description="xyz", - metric=metric.name, + metric=metric.pk, + topic=metric.topic, seo_title="Test", ) assert MetricsDocumentationChildEntry.objects.exists() @@ -146,45 +149,40 @@ def test_creates_correct_child_entries(self, dashboard_root_page: UKHSARootPage) Then the correct child entries are created for the corresponding `Metric` records """ + # Given - _seed_truncated_test_data_with_split_auth() - healthcare_admission_metric = Metric.objects.get( - name="RSV_healthcare_admissionRateByWeek" + + entries = get_metrics_definitions() + assert entries, "No metric definitions found" + + test_entry = entries[0] + + topic = Topic.objects.create(name=test_entry["topic"]) + + metric = Metric.objects.create( + id=test_entry["metric"], + name=f"metric-{test_entry['metric']}", + topic=topic, ) # When create_metrics_documentation_parent_page_and_child_entries() # Then - healthcare_admission_rate_child_entry = ( - MetricsDocumentationChildEntry.objects.get( - metric=healthcare_admission_metric.name - ) - ) - assert healthcare_admission_rate_child_entry.metric_group == "healthcare" - expected_title = "RSV healthcare admission rate by week" - assert ( - healthcare_admission_rate_child_entry.slug - == expected_title.lower().replace(" ", "-") - ) - assert healthcare_admission_rate_child_entry.topic == "RSV" - assert healthcare_admission_rate_child_entry.title == expected_title - assert ( - healthcare_admission_rate_child_entry.seo_title - == f"{expected_title} | UKHSA data dashboard" - ) + child_entry = MetricsDocumentationChildEntry.objects.get(metric=metric.pk) + + expected_title = test_entry["title"] + + assert child_entry.title == expected_title + assert child_entry.slug == expected_title.lower().replace(" ", "-") + assert child_entry.topic == test_entry["topic"] + assert child_entry.seo_title == test_entry["seo_title"] - expected_page_description = ( - "This metric shows the rate per 100,000 people of the total number of people " - "with confirmed RSV admitted to hospital " - "(general admissions plus admissions to ICU and HDU) " - "in the 7 days up to and including the date shown." - ) assert ( - healthcare_admission_rate_child_entry.search_description - == healthcare_admission_rate_child_entry.page_description - == expected_page_description + child_entry.search_description + == child_entry.page_description + == test_entry["page_description"] ) @pytest.mark.django_db diff --git a/tests/integration/cms/metrics_documentation/models/test_child.py b/tests/integration/cms/metrics_documentation/models/test_child.py index 024dc6b53..da59bd2a9 100644 --- a/tests/integration/cms/metrics_documentation/models/test_child.py +++ b/tests/integration/cms/metrics_documentation/models/test_child.py @@ -17,25 +17,31 @@ def test_metric_is_unique(self): """ # Given metric_name = "influenza_headline_positivityLatest" - Metric.objects.create(name=metric_name) + created_metric = Metric.objects.create(name=metric_name) Topic.objects.create(name=metric_name.split("_")[0].title()) - _create_metrics_documentation_child_entry(metric_name=metric_name, path="doc_1") + _create_metrics_documentation_child_entry( + metric_name=metric_name, metric_id=created_metric.pk, path="doc_1" + ) # When / Then with pytest.raises(ValidationError): _create_metrics_documentation_child_entry( - metric_name=metric_name, path="doc_2" + metric_name=metric_name, metric_id=created_metric.pk, path="doc_2" ) def _create_metrics_documentation_child_entry( metric_name: str, + metric_id: int, path: str, ) -> MetricsDocumentationChildEntry: MetricsDocumentationChildEntry.objects.create( - metric=metric_name, + metric=metric_id, title=metric_name, + theme="test", + sub_theme="test", + topic=1, path=path, depth=1, slug=metric_name, diff --git a/tests/integration/cms/topic/test_managers.py b/tests/integration/cms/topic/test_managers.py index 0776a59a2..af36e4753 100644 --- a/tests/integration/cms/topic/test_managers.py +++ b/tests/integration/cms/topic/test_managers.py @@ -19,6 +19,9 @@ def test_get_live_pages(self): path="abc", depth=1, title="abc", + theme="test", + topic=1, + sub_theme="test", live=True, seo_title="ABC", ) @@ -26,6 +29,9 @@ def test_get_live_pages(self): path="def", depth=1, title="def", + theme="test2", + topic=2, + sub_theme="test2", live=False, seo_title="DEF", ) diff --git a/tests/integration/metrics/api/views/test_geographies.py b/tests/integration/metrics/api/views/test_geographies.py index 31b15c627..685595f63 100644 --- a/tests/integration/metrics/api/views/test_geographies.py +++ b/tests/integration/metrics/api/views/test_geographies.py @@ -4,7 +4,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from auth_content.constants import WILDCARD_ID_VALUE +from cms.auth_content.constants import WILDCARD_ID_VALUE from tests.factories.metrics.geography import GeographyFactory from tests.factories.metrics.time_series import CoreTimeSeriesFactory from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE diff --git a/tests/integration/metrics/api/views/test_permission_sets.py b/tests/integration/metrics/api/views/test_permission_sets.py index 7acc8d51a..1bb4a9096 100644 --- a/tests/integration/metrics/api/views/test_permission_sets.py +++ b/tests/integration/metrics/api/views/test_permission_sets.py @@ -4,7 +4,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from auth_content.constants import WILDCARD_ID_VALUE +from cms.auth_content.constants import WILDCARD_ID_VALUE from tests.factories.metrics.metric import MetricFactory from tests.factories.metrics.sub_theme import SubThemeFactory from tests.factories.metrics.topic import TopicFactory diff --git a/tests/integration/metrics/api/views/test_user.py b/tests/integration/metrics/api/views/test_user.py index 0742215d3..1264f058e 100644 --- a/tests/integration/metrics/api/views/test_user.py +++ b/tests/integration/metrics/api/views/test_user.py @@ -5,7 +5,7 @@ from rest_framework.response import Response from rest_framework.test import APIClient -from auth_content.constants import WILDCARD_ID_VALUE +from cms.auth_content.constants import WILDCARD_ID_VALUE from tests.factories.auth_content.models.permission_sets import PermissionSetFactory from tests.factories.auth_content.models.users import UserFactory from tests.factories.metrics.metric import MetricFactory diff --git a/tests/integration/metrics/data/managers/rbac_models/test_user.py b/tests/integration/metrics/data/managers/rbac_models/test_user.py index 8e085950f..7c75ec8ac 100644 --- a/tests/integration/metrics/data/managers/rbac_models/test_user.py +++ b/tests/integration/metrics/data/managers/rbac_models/test_user.py @@ -1,6 +1,6 @@ import pytest -from auth_content.models.users import User +from cms.auth_content.models.users import User from tests.factories.auth_content.models.permission_sets import PermissionSetFactory from tests.factories.auth_content.models.users import UserFactory diff --git a/tests/integration/metrics/utils/test_permission_hierarchy.py b/tests/integration/metrics/utils/test_permission_hierarchy.py index 62ee32dba..851860656 100644 --- a/tests/integration/metrics/utils/test_permission_hierarchy.py +++ b/tests/integration/metrics/utils/test_permission_hierarchy.py @@ -491,7 +491,7 @@ def test_single_permission_no_deduplication(self): Then no deduplication occurs """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) perm = PermissionSetFactory.create_permission_set( @@ -521,7 +521,7 @@ def test_removes_fully_subsumed_permission(self): Then only permission A is returned """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) @@ -564,7 +564,7 @@ def test_keeps_independent_permissions(self): Then both are kept """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) @@ -603,7 +603,7 @@ def test_complex_multi_level_deduplication(self): Then correct deduplication occurs """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) @@ -658,7 +658,7 @@ def test_summary_contains_correct_statistics(self): Then summary contains accurate statistics """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) @@ -696,7 +696,7 @@ def test_hierarchy_structure_is_correct(self): Then each permission has correct structure """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) perm = PermissionSetFactory.create_permission_set( @@ -743,7 +743,7 @@ def test_returns_normalized_permission_list(self): Then returns list of NormalizedPermission objects """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) perm = PermissionSetFactory.create_permission_set( @@ -772,7 +772,7 @@ def test_deduplicates_permissions(self): Then subsumed permissions are removed """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) @@ -810,7 +810,7 @@ def test_empty_queryset_returns_empty_hierarchy(self): Then returns empty hierarchy with zero counts """ # Given - from auth_content.models.permission_sets import PermissionSet + from cms.auth_content.models.permission_sets import PermissionSet # When result = build_permission_hierarchy(PermissionSet.objects.none()) @@ -829,7 +829,7 @@ def test_handles_null_fields_gracefully(self): Then handles gracefully without errors """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) perm = PermissionSetFactory.create( @@ -859,7 +859,7 @@ def test_all_wildcards_at_different_levels(self): Then correctly handles all wildcard combinations """ # Given - from auth_content.models.users import User + from cms.auth_content.models.users import User user = User.objects.create(user_id=uuid4()) perm1 = PermissionSetFactory.create_permission_set( diff --git a/tests/unit/cms/auth_content/__init__.py b/tests/unit/cms/auth_content/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/cms/auth_content/models/test_permission_sets.py b/tests/unit/cms/auth_content/models/test_permission_sets.py new file mode 100644 index 000000000..1ca0eb2d0 --- /dev/null +++ b/tests/unit/cms/auth_content/models/test_permission_sets.py @@ -0,0 +1,165 @@ +import pytest +from unittest.mock import MagicMock, patch + +from django.core.exceptions import ValidationError +from cms.auth_content.models.permission_sets import PermissionSet, PermissionSetForm + + +class TestPermissionSetForm: + MOCK_PERMISSION_SET_FIELDS = [ + {"field_name": "theme", "field_label": "Theme"}, + {"field_name": "sub_theme", "field_label": "Sub Theme"}, + {"field_name": "topic", "field_label": "Topic"}, + {"field_name": "metric", "field_label": "Metric"}, + {"field_name": "geography", "field_label": "Geography"}, + ] + + def _make_form(self, instance=None, queryset_exists=False): + """ + Instantiate PermissionSetForm with all Wagtail + internals patched. + """ + with ( + patch( + "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None + ), + patch( + "cms.auth_content.models.permission_sets.PERMISSION_SET_FIELDS", + self.MOCK_PERMISSION_SET_FIELDS, + ), + patch( + "cms.auth_content.models.permission_sets._create_form_field", + side_effect=lambda field, wildcard: MagicMock(name=field["field_name"]), + ), + ): + form = PermissionSetForm.__new__(PermissionSetForm) + form.fields = {} + form.instance = instance or MagicMock(pk=None) + default_data = { + "theme": 1, + "sub_theme": 2, + "topic": 3, + "metric": 4, + "geography_type": 5, + "geography": 6, + } + form.cleaned_data = default_data + form.__init__() + + mock_qs = MagicMock() + mock_qs.exists.return_value = queryset_exists + mock_qs.exclude.return_value = mock_qs + + form._mock_qs = mock_qs + return form + + def test_init_sets_up_fields(self): + """ + When a new form is instantiated + Then a form field is added to `fields` for each entry in `PERMISSION_SET_FIELDS` + """ + form = self._make_form() + + assert len(form.fields) == 5 + assert "theme" in form.fields + assert "sub_theme" in form.fields + assert "topic" in form.fields + assert "metric" in form.fields + assert "geography" in form.fields + + def test_initialize_dependent_fields(self): + """ + Given a new form + When an instance has a pk value set + Then `_initialize_dependent_fields` is called + """ + instance = MagicMock(pk=1, sub_theme=1, topic=2, metric=3, geography=4) + form = self._make_form(instance=instance) + + assert form.fields["sub_theme"].widget.choices == [ + ("", "Select theme first"), + (1, "Loading... (ID: 1)"), + ] + assert form.fields["topic"].widget.choices == [ + ("", "Select sub-theme first"), + (2, "Loading... (ID: 2)"), + ] + assert form.fields["metric"].widget.choices == [ + ("", "Select topic first"), + (3, "Loading... (ID: 3)"), + ] + assert form.fields["geography"].widget.choices == [ + ("", "Select geography type first"), + (4, "Loading... (ID: 4)"), + ] + + def test_get_field_choices(self): + """ + When the static function `_get_field_choices` is called without a wildcard + Then the result returns the placeholder value + """ + result = PermissionSetForm._get_field_choices("test", "placeholder", None) + assert result == [("", "placeholder"), ("test", "Loading... (ID: test)")] + + def test_get_field_choices_wildcard_match(self): + """ + When the static function `_get_field_choices` is called with a wildcard + Then the result returns the wildcard match + """ + result = PermissionSetForm._get_field_choices("-1", "placeholder", "wildcard") + assert result == [("-1", "wildcard")] + + @patch("cms.auth_content.models.permission_sets.PermissionSet.objects.filter") + def test_validation_error_raised_if_queryset_duplicated( + self, mock_query_filter: MagicMock + ): + """ + Given a form is created with an existing queryset match + When `clean` is called + Then a `ValidationError` is raised + """ + form = self._make_form(queryset_exists=True) + mock_query_filter.return_value = form._mock_qs + + with pytest.raises(ValidationError) as e: + form.clean() + + assert ( + "A permission set with this exact combination already exists. Please modify your selection to create a unique permission set." + in str(e.value) + ) + + @patch("cms.auth_content.models.permission_sets.PermissionSet.objects.filter") + def test_returns_cleaned_data_when_no_duplicate_exists( + self, mock_query_filter: MagicMock + ): + """ + Given a form is created without an existing queryset match + When `clean` is called + Then the cleaned data is returned + """ + instance = MagicMock(pk=1) + form = self._make_form(instance=instance, queryset_exists=False) + mock_query_filter.return_value = form._mock_qs + + result = form.clean() + + assert result == form.cleaned_data + + +class TestPermissionSet: + def test_get_choice_label(self): + """ + Given a blank `PermissionSet` + When `_get_choice_label` is called with an unknown field and value + Then the unknown value is returned + """ + test_permission_set = PermissionSet() + unknown_field = "unknown_field_type" + test_value = "12345" + + # When + result = test_permission_set._get_choice_label(unknown_field, test_value) + + # Then + assert result == test_value diff --git a/tests/unit/cms/auth_content/test_auth_utils.py b/tests/unit/cms/auth_content/test_auth_utils.py new file mode 100644 index 000000000..ba64ad56c --- /dev/null +++ b/tests/unit/cms/auth_content/test_auth_utils.py @@ -0,0 +1,103 @@ +import pytest +from unittest.mock import MagicMock +from django import forms + +from cms.auth_content.auth_utils import _create_form_field + + +class TestCreateFormField: + def test_create_form_field_basic(self): + """ + Given no wildcard or callables in data + When `_create_form_field` is called + Then only the default choices are returned + """ + field_data = { + "field_choice_default": "Select an option", + "field_choice_wildcard": None, + "field_choice_callable": None, + "field_label": "My Label", + } + + result = _create_form_field(field_data) + + assert isinstance(result, forms.CharField) + assert result.label == "My Label" + expected_choices = [("", "Select an option")] + assert result.widget.choices == expected_choices + + def test_create_form_field_with_wildcard(self): + """ + Given a wildcard in the data + When `_create_form_field` is called + Then the wildcard choice is added + """ + field_data = { + "field_choice_default": "Default", + "field_choice_wildcard": "All Items", + "field_choice_callable": None, + "field_label": "Label", + } + wildcard_val = "-1" + + result = _create_form_field(field_data, wildcard_id_value=wildcard_val) + + expected_choices = [("", "Default"), ("-1", "All Items")] + assert result.widget.choices == expected_choices + + def test_create_form_field_with_callable(self): + """ + Given a callable in the data + When `_create_form_field` is called + Then the callable choice is added + """ + mock_callable = MagicMock(return_value=[("1", "One"), ("2", "Two")]) + + field_data = { + "field_choice_default": "Default", + "field_choice_wildcard": None, + "field_choice_callable": mock_callable, + "field_label": "Label", + } + + result = _create_form_field(field_data) + + expected_choices = [("", "Default"), ("1", "One"), ("2", "Two")] + assert result.widget.choices == expected_choices + mock_callable.assert_called_once() + + def test_create_form_field_all_features(self): + """ + Given both a wildcard and callable are in the data + When `_create_form_field` is called + Then both the wildcard and callable choices are added + """ + """Test combined default, wildcard, and callable choices""" + mock_callable = MagicMock(return_value=[("dynamic", "Dynamic")]) + field_data = { + "field_choice_default": "Default", + "field_choice_wildcard": "Wildcard", + "field_choice_callable": mock_callable, + "field_label": "Label", + } + + result = _create_form_field(field_data, wildcard_id_value="999") + + expected_choices = [ + ("", "Default"), + ("999", "Wildcard"), + ("dynamic", "Dynamic"), + ] + assert result.widget.choices == expected_choices + + def test_create_form_field_missing_key_error(self): + """ + Given data with missing keys + When `_create_form_field` is called + Then a key error exception is raised + """ + """Test behavior if a required key is missing from the dict""" + field_data = {"field_choice_default": "Missing other keys"} + + with pytest.raises(KeyError): + _create_form_field(field_data) diff --git a/tests/unit/auth_content/test_wagtail_hooks.py b/tests/unit/cms/auth_content/test_wagtail_hooks.py similarity index 58% rename from tests/unit/auth_content/test_wagtail_hooks.py rename to tests/unit/cms/auth_content/test_wagtail_hooks.py index bf1c76cdb..0a2b19f04 100644 --- a/tests/unit/auth_content/test_wagtail_hooks.py +++ b/tests/unit/cms/auth_content/test_wagtail_hooks.py @@ -1,9 +1,23 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from django.test import TestCase from django.utils.safestring import SafeData -from auth_content.models.permission_sets import PermissionSet -from auth_content.wagtail_hooks import NoEditPermissionPolicy, PermissionSetViewSet +from cms.auth_content.models.permission_sets import PermissionSet +from cms.auth_content.wagtail_hooks import ( + NoEditPermissionPolicy, + PermissionSetViewSet, + AuthGroup, + register_auth_viewset, +) + + +class TestWagtailHooks(TestCase): + def test_register_auth_viewset(self): + result = register_auth_viewset() + assert result.menu_label == AuthGroup.menu_label + assert result.menu_icon == AuthGroup.menu_icon + assert result.menu_order == AuthGroup.menu_order + assert len(result.items) == 2 class TestPermissionSetDetailsProperty(TestCase): @@ -41,6 +55,14 @@ def setUp(self): def test_change_permission_denied(self): self.assertFalse(self.policy.user_has_permission(self.user, "change")) + @patch("wagtail.permission_policies.ModelPermissionPolicy.user_has_permission") + def test_user_has_permission_calls_super(self, spy_user_has_permissions: MagicMock): + spy_user_has_permissions.return_value = "parent_response" + result = self.policy.user_has_permission(self.user, "view") + + spy_user_has_permissions.assert_called_once_with(self.user, "view") + assert result == "parent_response" + def test_change_permission_denied_for_instance(self): self.assertFalse( self.policy.user_has_permission_for_instance( @@ -48,6 +70,22 @@ def test_change_permission_denied_for_instance(self): ) ) + @patch( + "wagtail.permission_policies.ModelPermissionPolicy.user_has_permission_for_instance" + ) + def test_user_has_permission_for_instance_calls_super( + self, spy_user_has_permissions_for_instance: MagicMock + ): + spy_user_has_permissions_for_instance.return_value = "parent_response" + result = self.policy.user_has_permission_for_instance( + self.user, "view", self.instance + ) + + spy_user_has_permissions_for_instance.assert_called_once_with( + self.user, "view", self.instance + ) + assert result == "parent_response" + class TestPermissionSetViewSet(TestCase): diff --git a/tests/unit/cms/dashboard/test_viewsets.py b/tests/unit/cms/dashboard/test_viewsets.py index 0620091aa..165ccb266 100644 --- a/tests/unit/cms/dashboard/test_viewsets.py +++ b/tests/unit/cms/dashboard/test_viewsets.py @@ -1,5 +1,155 @@ +import pytest + from cms.dashboard.serializers import CMSDraftPagesSerializer, ListablePageSerializer -from cms.dashboard.viewsets import CMSDraftPagesViewSet, CMSPagesAPIViewSet +from cms.dashboard.viewsets import ( + CMSDraftPagesViewSet, + CMSPagesAPIViewSet, + check_permissions, +) + + +class TestCheckPermissions: + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{"theme": {"id": "-1"}}], "10", "20", "30"), + ([{"theme": {"id": "10"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "-1"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + {"theme": {"id": "5"}, "sub_theme": {"id": "-1"}}, + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "30"}, + }, + ], + "10", + "20", + "30", + ), + ], + ) + def test_check_permissions_valid_access( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """ + Given a permission set that does grant access to the provided ids + When the `check_permissions` function is called + Then the function returns true + """ + assert ( + check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) + == True + ) + + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{"theme": {"id": "99"}, "sub_theme": {"id": "-1"}}], "10", "20", "30"), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "99"}, + "topic": {"id": "-1"}, + } + ], + "10", + "20", + "30", + ), + ( + [ + { + "theme": {"id": "10"}, + "sub_theme": {"id": "20"}, + "topic": {"id": "99"}, + } + ], + "10", + "20", + "30", + ), + ([], "10", "20", "30"), + ], + ) + def test_check_permissions_invalid_access( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """ + Given a permission set that does not grant access to the provided ids + When the `check_permissions` function is called + Then the function returns false + """ + assert ( + check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) + == False + ) + + @pytest.mark.parametrize( + "user_permissions, theme_id, sub_theme_id, topic_id", + [ + ([{}], "10", "20", "30"), + (None, "10", "20", "30"), + ([{"sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], "10", "20", "30"), + ( + [{"theme": {}, "sub_theme": {"id": "-1"}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + ), + ([{"theme": {"id": "10"}, "topic": {"id": "-1"}}], "10", "20", "30"), + ( + [{"theme": {"id": "10"}, "sub_theme": {}, "topic": {"id": "-1"}}], + "10", + "20", + "30", + ), + ([{"theme": {"id": "10"}, "sub_theme": {"id": "20"}}], "10", "20", "30"), + ( + [{"theme": {"id": "10"}, "sub_theme": {"id": "20"}, "topic": {}}], + "10", + "20", + "30", + ), + ], + ) + def test_check_permissions_with_missing_values( + self, user_permissions, theme_id, sub_theme_id, topic_id + ): + """ + Given a permission set that is missing values + When the `check_permissions` function is called + Then the function returns false + """ + assert ( + check_permissions(user_permissions, theme_id, sub_theme_id, topic_id) + == False + ) class TestCMSDraftPagesViewSet: diff --git a/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py b/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py index 3d813e26b..39df8df1d 100644 --- a/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py +++ b/tests/unit/cms/metrics_documentation/data_migration/test_child_entries.py @@ -96,14 +96,18 @@ def test_returns_correct_dictionary( "Fake page description", "Fake methodology content", "Fake caveats content", + 1, ) spy_build_sections.return_value = [] expected_response = { "title": "Fake title", + "topic": "Fake", + "theme": "test", + "sub_theme": "test", "seo_title": "Fake title | UKHSA data dashboard", "search_description": "Fake page description", "page_description": "Fake page description", - "metric": "Fake_metric_name", + "metric": 1, "body": [], } @@ -132,6 +136,7 @@ def build_worksheet() -> Worksheet: work_sheet["E2"] = "Fake page description" work_sheet["F2"] = "Fake methodology content" work_sheet["G2"] = "Fake caveats content" + work_sheet["H2"] = 1 return work_sheet @@ -150,10 +155,13 @@ def test_delegates_calls_correctly( expected_response = [ { "title": "Fake title", + "topic": "Fake", + "theme": "test", + "sub_theme": "test", "seo_title": "Fake title | UKHSA data dashboard", "search_description": "Fake page description", "page_description": "Fake page description", - "metric": "Fake_metric_name", + "metric": 1, "body": [ { "type": "section", diff --git a/tests/unit/cms/metrics_documentation/data_migration/test_operations.py b/tests/unit/cms/metrics_documentation/data_migration/test_operations.py index b4741209a..274b78e5c 100644 --- a/tests/unit/cms/metrics_documentation/data_migration/test_operations.py +++ b/tests/unit/cms/metrics_documentation/data_migration/test_operations.py @@ -128,9 +128,11 @@ def test_log_recorded_when_metric_not_available_for_child_page( create_metrics_documentation_parent_page_and_child_entries() # Then - expected_log = ( - f"Metrics Documentation Child Entry for {fake_metric} was not created. " + expected_log_part_one = "Metrics Documentation Child Entry for " + expected_log_part_two = ( + " was not created. " "Because the corresponding `Metric` was not created beforehand" ) - assert expected_log in caplog.text + assert expected_log_part_one in caplog.text + assert expected_log_part_two in caplog.text diff --git a/tests/unit/cms/metrics_documentation/models/test_child.py b/tests/unit/cms/metrics_documentation/models/test_child.py index 08cc67a9b..8f2b60a90 100644 --- a/tests/unit/cms/metrics_documentation/models/test_child.py +++ b/tests/unit/cms/metrics_documentation/models/test_child.py @@ -1,12 +1,14 @@ from unittest import mock +from unittest.mock import MagicMock, patch from django.core.exceptions import ValidationError import pytest from wagtail.admin.panels import FieldPanel from wagtail.api.conf import APIField -from cms.metrics_documentation.models import child +from cms.dashboard.constants import THEME_FIELDS from cms.metrics_documentation.models.child import ( + MetricsDocumentationChildEntryAdminForm, InvalidTopicForChosenMetricForChildEntryError, ) from tests.fakes.factories.cms.metrics_documentation_child_entry_factory import ( @@ -16,6 +18,167 @@ MODULE_PATH = "cms.metrics_documentation.models.child" +class TestInvalidTopicForChosenMetricForChildEntryError: + def test_exception_has_expected_message(self): + actual = InvalidTopicForChosenMetricForChildEntryError( + "test_topic", "test_metric" + ) + expected = "InvalidTopicForChosenMetricForChildEntryError('The `test_topic` is not available for selected metric of `test_metric`')" + + assert expected == repr(actual) + + +class TestMetricsDocumentationChildEntryAdminForm: + MOCK_THEME_FIELDS = [ + {"field_name": "theme", "label": "Theme", "required": True}, + {"field_name": "sub_theme", "label": "Sub Theme", "required": False}, + {"field_name": "topic", "label": "Topic", "required": False}, + ] + + def _make_form(self, instance=None): + """ + Instantiate MetricsDocumentationChildEntryAdminForm with all Wagtail + internals patched. + """ + with ( + patch( + "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None + ), + patch( + "cms.metrics_documentation.models.child.THEME_FIELDS", + self.MOCK_THEME_FIELDS, + ), + patch( + "cms.metrics_documentation.models.child._create_form_field", + side_effect=lambda field: MagicMock(name=field["field_name"]), + ), + ): + form = MetricsDocumentationChildEntryAdminForm.__new__( + MetricsDocumentationChildEntryAdminForm + ) + form.fields = {} + form.instance = instance or MagicMock(pk=None) + form.__init__() + return form + + def _make_form_with_instance(self, sub_theme=None, topic=None): + """ + Instantiate MetricsDocumentationChildEntryAdminForm with all Wagtail + internals patched, and a mocked instance. + """ + instance = MagicMock(pk=1) + instance.sub_theme = sub_theme + instance.topic = topic + + form = self._make_form(instance=instance) + + for field_name in ("sub_theme", "topic"): + mock_widget = MagicMock() + mock_widget.choices = [] + form.fields[field_name] = MagicMock(widget=mock_widget) + + return form + + def test_creates_field_for_every_theme_field(self): + """ + When a new form is instantiated + Then a form field is added to `fields` for each entry in `THEME_FIELDS`. + """ + form = self._make_form() + + assert len(form.fields) == 3 + assert "theme" in form.fields + assert "sub_theme" in form.fields + assert "topic" in form.fields + + @mock.patch("cms.metrics_documentation.models.child._create_form_field") + @mock.patch("wagtail.admin.panels.WagtailAdminPageForm.__init__") + def test_field_creation_uses_create_form_field_helper( + self, spy_init_admin_form: mock.MagicMock, spy_create_form_field: mock.MagicMock + ): + """ + Given a new form is created + When init is called on the form + Then `_create_form_field` is called once per `THEME_FIELDS` entry. + """ + form = MetricsDocumentationChildEntryAdminForm.__new__( + MetricsDocumentationChildEntryAdminForm + ) + form.fields = {} + form.instance = MagicMock(pk=None) + form.__init__() + + assert spy_create_form_field.call_count == len(THEME_FIELDS) + spy_create_form_field.assert_any_call(THEME_FIELDS[0]) + spy_create_form_field.assert_any_call(THEME_FIELDS[1]) + spy_create_form_field.assert_any_call(THEME_FIELDS[2]) + spy_create_form_field.assert_any_call(THEME_FIELDS[3]) + + def test_initialize_dependent_fields_called_when_instance_has_pk(self): + """ + Given a new form + When an instance has a pk value set + Then `_initialize_dependent_fields` is called + """ + instance = MagicMock(pk=42) + + with patch.object( + MetricsDocumentationChildEntryAdminForm, + "_initialize_dependent_fields", + ) as mock_init_deps: + self._make_form(instance=instance) + + mock_init_deps.assert_called_once() + + def test_initialize_dependent_fields_not_called_when_instance_has_no_pk(self): + """ + Given a new form + When an instance does not have a pk value set + Then `_initialize_dependent_fields` is not called + """ + instance = MagicMock(pk=None) + + with patch.object( + MetricsDocumentationChildEntryAdminForm, + "_initialize_dependent_fields", + ) as mock_init_deps: + self._make_form(instance=instance) + + mock_init_deps.assert_not_called() + + def test_both_fields_updated_when_both_have_values(self): + """ + Given a new form with a sub_theme and topic + When `_initialize_dependent_fields` is called + Then both sub_theme and topic choices are set + """ + form = self._make_form_with_instance(sub_theme=3, topic=7) + + form._initialize_dependent_fields() + + assert form.fields["sub_theme"].widget.choices == [ + ("", "Select theme first"), + (3, "Loading... (ID: 3)"), + ] + assert form.fields["topic"].widget.choices == [ + ("", "Select sub-theme first"), + (7, "Loading... (ID: 7)"), + ] + + def test_skips_field_when_value_is_none(self): + """ + Given a new form with no sub_theme or topic + When `_initialize_dependent_fields` is called + Then widget choices are left untouched. + """ + form = self._make_form_with_instance(sub_theme=None, topic=None) + original_choices = form.fields["sub_theme"].widget.choices + + form._initialize_dependent_fields() + + assert form.fields["sub_theme"].widget.choices == original_choices + + class TestMetricsDocumentationChildEntry: @pytest.mark.parametrize( "expected_api_field", @@ -32,10 +195,10 @@ class TestMetricsDocumentationChildEntry: "page_classification", ], ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") def test_has_correct_api_fields( self, - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), expected_api_field: str, ): """ @@ -63,10 +226,10 @@ def test_has_correct_api_fields( "body", ], ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") def test_has_the_correct_content_panels( self, - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), expected_content_panel_name: str, ): """ @@ -89,154 +252,207 @@ def test_has_the_correct_content_panels( fake_metrics_documentation_child_entry_page, expected_content_panel_name ) - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") - @mock.patch.object(child.MetricsDocumentationChildEntry, "find_topic") - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - def test_get_topic_delegates_calls_correctly( + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + @pytest.mark.parametrize( + "metric_id, metric_group", + [ + (1, "cases"), + (2, "headline"), + (3, "vaccinations"), + (4, "deaths"), + ], + ) + def test_metric_group_returns_expected_string( self, - mock_get_all_unique_metric_names: mock.MagicMock, - spy_find_topic: mock.MagicMock, - spy_get_a_list_of_all_topic_names: mock.MagicMock, + get_all_metric_names_and_ids: mock.MagicMock, + metric_id: int, + metric_group: str, ): """ Given a blank `MetricsDocumentationChildEntryPage` model. - When `get_topic()` is called. - Then the `get_a_list_of_all_topic_names()` method and `find_topic()` - methods are called. + When a metric id is supplied to the `metric` property. + Then the metric_group will be correctly extracted from the string. """ # Given - fake_topics = ["COVID-19", "Influenza"] + get_all_metric_names_and_ids.return_value = [ + (1, "COVID-19_cases_rateRollingMean"), + (2, "COVID-19_headline_vaccines_autumn23Total"), + (3, "COVID-19_vaccinations_autumn22_uptakeByDay"), + (4, "COVID-19_deaths_ONSByWeek"), + ] fake_metrics_documentation_child_entry_page = ( FakeMetricsDocumentationChildEntryFactory.build_page_from_template() ) - fake_metrics_documentation_child_entry_page.metric = ( - "COVID-19_cases_rateRollingMean" - ) # When - spy_get_a_list_of_all_topic_names.return_value = fake_topics - fake_metrics_documentation_child_entry_page.get_topic() + fake_metrics_documentation_child_entry_page.metric = metric_id # Then - spy_get_a_list_of_all_topic_names.assert_called_once() - spy_find_topic.assert_called_once() + assert fake_metrics_documentation_child_entry_page.metric_group == metric_group - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") @pytest.mark.parametrize( - "metric_name, metric_group", - [ - ("COVID-19_cases_rateRollingMean", "cases"), - ("COVID-19_headline_vaccines_autumn23Total", "headline"), - ("COVID-19_vaccinations_autumn22_uptakeByDay", "vaccinations"), - ("COVID-19_deaths_ONSByWeek", "deaths"), - ], + "metric_id", + [1, 2, 3, 4, 5], ) - def test_metric_group_returns_expected_string( - self, - mock_get_all_unique_metric_names: mock.MagicMock, - mock_get_all_topic_names: mock.MagicMock, - metric_name: str, - metric_group: str, + def test_metric_group_returns_emptry_string_with_missing_values( + self, get_all_metric_names_and_ids: mock.MagicMock, metric_id: int ): """ Given a blank `MetricsDocumentationChildEntryPage` model. - When a metric name is supplied to the `metric` property. - Then the metric_group will be correctly extracted from the string. + When a metric id is supplied to the `metric` property with invalid choices returned. + Then the metric_group will return an empty string. """ # Given + get_all_metric_names_and_ids.return_value = [ + (1, "COVID-19casesrateRollingMean"), + (2, "COVID-19_"), + (3, ""), + (4, None), + ] fake_metrics_documentation_child_entry_page = ( FakeMetricsDocumentationChildEntryFactory.build_page_from_template() ) # When - fake_metrics_documentation_child_entry_page.metric = metric_name + fake_metrics_documentation_child_entry_page.metric = metric_id # Then - assert fake_metrics_documentation_child_entry_page.metric_group == metric_group + assert fake_metrics_documentation_child_entry_page.metric_group == "" - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") - @pytest.mark.parametrize( - "selected_metric, extracted_topic", - [ - ("COVID-19_cases_rateRollingMean", "COVID-19"), - ("influenza_headline_ICUHDUadmissionRatePercentChange", "Influenza"), - ("hMPV_testing_positivityByWeek", "hMPV"), - ("parainfluenza_headline_positivityLatest", "Parainfluenza"), - ("rhinovirus_headline_positivityLatest", "Rhinovirus"), - ("RSV_headline_admissionRateLatest", "RSV"), - ("adenovirus_headline_positivityLatest", "Adenovirus"), - ], + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + def test_metric_group_returns_emptry_string_with_empty_metrics( + self, get_all_metric_names_and_ids: mock.MagicMock + ): + """ + Given a blank `MetricsDocumentationChildEntryPage` model. + When a metric id is supplied to the `metric` property with no choices returned. + Then the metric_group will return an empty string. + """ + # Given + get_all_metric_names_and_ids.return_value = [] + fake_metrics_documentation_child_entry_page = ( + FakeMetricsDocumentationChildEntryFactory.build_page_from_template() + ) + + # When + fake_metrics_documentation_child_entry_page.metric = 1 + + # Then + assert fake_metrics_documentation_child_entry_page.metric_group == "" + + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, ) - def test_find_topic_returns_expected_topic_name( + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + def test_public_error_raised_if_invalid_classification( self, - spy_get_a_list_of_all_topic_names: mock.MagicMock(), - mock_get_all_unique_metric_names: mock.MagicMock(), - selected_metric: str, - extracted_topic: str, + mock_get_all_metric_names_and_ids: mock.MagicMock(), + mock_slug_raise_error, + mock_seo_title_raise_error, ): """ - Given a blank `MetricsDocumentationChildEntryPage` model - a list of topics and a metric name. - When the `find_topic()` method is called. - Then the expected topic name will be matched from the list - using the metric name. + Given is_public is False (i.e the page is a non public page). + When no page classification is given. + Then a `ValidationError` is raised. """ # Given fake_metrics_documentation_child_entry_page = ( FakeMetricsDocumentationChildEntryFactory.build_page_from_template() ) - fake_topics = [ - "COVID-19", - "Influenza", - "RSV", - "hMPV", - "Parainfluenza", - "Rhinovirus", - "Adenovirus", - ] - fake_metrics_documentation_child_entry_page.metric = selected_metric - # When - return_topic = fake_metrics_documentation_child_entry_page.find_topic( - topics=fake_topics + fake_metrics_documentation_child_entry_page.is_public = False + fake_metrics_documentation_child_entry_page.page_classification = None + fake_metrics_documentation_child_entry_page.theme = "test" + fake_metrics_documentation_child_entry_page.sub_theme = "test" + fake_metrics_documentation_child_entry_page.topic = "test" + + # When/Then + with pytest.raises(ValidationError) as e: + fake_metrics_documentation_child_entry_page.clean() + + assert "Please select a classification level for this non-public page" in str( + e.value ) - # Then - assert return_topic == extracted_topic + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, + ) + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + def test_public_error_raised_if_invalid_theme( + self, + mock_get_all_metric_names_and_ids: mock.MagicMock(), + mock_slug_raise_error, + mock_seo_title_raise_error, + ): + """ + Given is_public is False (i.e the page is a non public page). + When no theme is given. + Then a `ValidationError` is raised. + """ + # Given + fake_metrics_documentation_child_entry_page = ( + FakeMetricsDocumentationChildEntryFactory.build_page_from_template() + ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") - def test_find_topic_raises_error( + fake_metrics_documentation_child_entry_page.is_public = False + fake_metrics_documentation_child_entry_page.page_classification = "test" + fake_metrics_documentation_child_entry_page.theme = None + fake_metrics_documentation_child_entry_page.sub_theme = "test" + fake_metrics_documentation_child_entry_page.topic = "test" + + # When/Then + with pytest.raises(ValidationError) as e: + fake_metrics_documentation_child_entry_page.clean() + + assert "Please select a theme for this non-public page" in str(e.value) + + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, + ) + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + def test_public_error_raised_if_invalid_sub_theme( self, - mock_get_a_list_of_all_topic_names: mock.MagicMock(), - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), + mock_slug_raise_error, + mock_seo_title_raise_error, ): """ - Given a metric name that does not include a valid topic. - When the `find_topic()` method is called with a list of topics. - Then an `InvalidTopicForChosenMetricForChildEntryError` is raised. + Given is_public is False (i.e the page is a non public page). + When no sub theme is given. + Then a `ValidationError` is raised. """ # Given - fake_invalid_metric = "invalid_metric_contains_no_topic" - fake_topics = [ - "COVID-19", - "Influenza", - "RSV", - "hMPV", - "Parainfluenza", - "Rhinovirus", - "Adenovirus", - ] fake_metrics_documentation_child_entry_page = ( FakeMetricsDocumentationChildEntryFactory.build_page_from_template() ) - fake_metrics_documentation_child_entry_page.metric = fake_invalid_metric - # When / Then - with pytest.raises(InvalidTopicForChosenMetricForChildEntryError): - fake_metrics_documentation_child_entry_page.find_topic(topics=fake_topics) + fake_metrics_documentation_child_entry_page.is_public = False + fake_metrics_documentation_child_entry_page.page_classification = "test" + fake_metrics_documentation_child_entry_page.theme = "None" + fake_metrics_documentation_child_entry_page.sub_theme = None + fake_metrics_documentation_child_entry_page.topic = "test" + + # When/Then + with pytest.raises(ValidationError) as e: + fake_metrics_documentation_child_entry_page.clean() + + assert "Please select a subtheme for this non-public page" in str(e.value) @mock.patch( "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", @@ -246,18 +462,16 @@ def test_find_topic_raises_error( "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", return_value=None, ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") - def test_public_error_raised_if_invalid_classification( + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") + def test_public_error_raised_if_invalid_topic( self, - mock_get_a_list_of_all_topic_names: mock.MagicMock(), - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), mock_slug_raise_error, mock_seo_title_raise_error, ): """ Given is_public is False (i.e the page is a non public page). - When no page classification is given. + When no topic is given. Then a `ValidationError` is raised. """ # Given @@ -266,12 +480,17 @@ def test_public_error_raised_if_invalid_classification( ) fake_metrics_documentation_child_entry_page.is_public = False - fake_metrics_documentation_child_entry_page.page_classification = None + fake_metrics_documentation_child_entry_page.page_classification = "test" + fake_metrics_documentation_child_entry_page.theme = "test" + fake_metrics_documentation_child_entry_page.sub_theme = "test" + fake_metrics_documentation_child_entry_page.topic = None # When/Then - with pytest.raises(ValidationError): + with pytest.raises(ValidationError) as e: fake_metrics_documentation_child_entry_page.clean() + assert "Please select a topic for this non-public page" in str(e.value) + @mock.patch( "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", return_value=None, @@ -280,12 +499,10 @@ def test_public_error_raised_if_invalid_classification( "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", return_value=None, ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") def test_public_page_clears_page_classification( self, - mock_get_a_list_of_all_topic_names: mock.MagicMock(), - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), mock_slug_raise_error, mock_seo_title_raise_error, ): @@ -316,12 +533,10 @@ def test_public_page_clears_page_classification( "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", return_value=None, ) - @mock.patch(f"{MODULE_PATH}.get_all_unique_metric_names") - @mock.patch(f"{MODULE_PATH}.get_a_list_of_all_topic_names") + @mock.patch(f"{MODULE_PATH}.get_all_metric_names_and_ids") def test_non_public_page_doesnt_clean_page_classification( self, - mock_get_a_list_of_all_topic_names: mock.MagicMock(), - mock_get_all_unique_metric_names: mock.MagicMock(), + mock_get_all_metric_names_and_ids: mock.MagicMock(), mock_slug_raise_error, mock_seo_title_raise_error, ): @@ -337,6 +552,9 @@ def test_non_public_page_doesnt_clean_page_classification( fake_metrics_documentation_child_entry_page.is_public = False fake_metrics_documentation_child_entry_page.page_classification = "official" + fake_metrics_documentation_child_entry_page.theme = "infectious_disease" + fake_metrics_documentation_child_entry_page.sub_theme = "respiratory" + fake_metrics_documentation_child_entry_page.topic = "COVID-19" # When fake_metrics_documentation_child_entry_page.clean() diff --git a/tests/unit/cms/topic/test_models.py b/tests/unit/cms/topic/test_models.py index 9b0c1f156..d6458a5ae 100644 --- a/tests/unit/cms/topic/test_models.py +++ b/tests/unit/cms/topic/test_models.py @@ -5,7 +5,7 @@ from django.core.exceptions import ValidationError -from cms.topic.models import TopicPage +from cms.topic.models import TopicPage, TopicPageAdminForm from metrics.domain.charts.colour_scheme import RGBAChartLineColours from metrics.domain.charts.common_charts.plots.line_multi_coloured.properties import ( @@ -16,6 +16,108 @@ from wagtail.search.index import SearchField +class TestTopicPageAdminForm: + MOCK_THEME_FIELDS = [ + {"field_name": "theme", "label": "Theme", "required": True}, + {"field_name": "sub_theme", "label": "Sub Theme", "required": False}, + ] + + def _make_form(self, instance=None): + """ + Instantiate TopicPageAdminForm with all Wagtail + internals patched. + """ + with ( + mock.patch( + "wagtail.admin.panels.WagtailAdminPageForm.__init__", return_value=None + ), + mock.patch("cms.topic.models.THEME_FIELDS", self.MOCK_THEME_FIELDS), + mock.patch( + "cms.topic.models._create_form_field", + side_effect=lambda field: mock.MagicMock(name=field["field_name"]), + ), + ): + form = TopicPageAdminForm.__new__(TopicPageAdminForm) + form.fields = {} + form.instance = instance or mock.MagicMock(pk=None) + form.__init__() + return form + + def test_theme_fields_are_added_on_init(self): + """ + When a new form is instantieated + Then a form field is added to `fields` for each entry in `THEME_FIELDS`. + """ + form = self._make_form() + + assert len(form.fields) == 2 + assert "theme" in form.fields + assert "sub_theme" in form.fields + + def test_dependent_fields_initialised_for_saved_instance(self): + """ + Given a new form + When an instance has a pk value set + Then `_initialize_dependent_fields` is called + """ + with mock.patch.object( + TopicPageAdminForm, "_initialize_dependent_fields" + ) as init_fields_mock: + self._make_form(instance=mock.MagicMock(pk=1)) + init_fields_mock.assert_called_once() + + def test_dependent_fields_not_initialised_for_new_instance(self): + """ + Given a new form + When an instance does not have a pk value set + Then `_initialize_dependent_fields` is not called + """ + with mock.patch.object( + TopicPageAdminForm, "_initialize_dependent_fields" + ) as init_fields_mock: + self._make_form(instance=mock.MagicMock(pk=None)) + init_fields_mock.assert_not_called() + + def test_widget_choices_set_when_sub_theme_has_value(self): + """ + Given a new form with a sub_theme + When `_initialize_dependent_fields` is called + Then the sub_theme choices are set + """ + instance = mock.MagicMock(pk=1, sub_theme=5, topic=None) + form = self._make_form(instance=instance) + mock_widget = mock.MagicMock() + form.fields["sub_theme"] = mock.MagicMock(widget=mock_widget) + form.fields["topic"] = mock.MagicMock(widget=mock.MagicMock()) + + form._initialize_dependent_fields() + + assert mock_widget.choices == [ + ("", "Select theme first"), + (5, "Loading... (ID: 5)"), + ] + + def test_widget_choices_not_set_when_value_is_none(self): + """ + Given a new form with no sub_theme or topic + When `_initialize_dependent_fields` is called + Then widget choices are left untouched. + """ + instance = mock.MagicMock(pk=1, sub_theme=None, topic=None) + form = self._make_form(instance=instance) + mock_widget = mock.MagicMock(choices=[]) + form.fields["sub_theme"] = mock.MagicMock(widget=mock_widget) + form.fields["topic"] = mock.MagicMock(widget=mock.MagicMock(choices=[])) + + form._initialize_dependent_fields() + + assert mock_widget.choices == [] + + def test_get_field_choices_returns_correct_structure(self): + result = TopicPageAdminForm._get_field_choices(42, "Select theme first") + assert result == [("", "Select theme first"), (42, "Loading... (ID: 42)")] + + class TestTopicPage: @pytest.mark.parametrize( "expected_search_field", @@ -714,11 +816,117 @@ def test_public_error_raised_if_invalid_classification( fake_covid_topic_page.is_public = False fake_covid_topic_page.page_classification = None + fake_covid_topic_page.theme = "test" + fake_covid_topic_page.sub_theme = "test" + fake_covid_topic_page.topic = "test" + + # When/Then + with pytest.raises(ValidationError) as e: + fake_covid_topic_page.clean() + + assert "Please select a classification level for this non-public page" in str( + e.value + ) + + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, + ) + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + def test_public_error_raised_if_invalid_theme( + self, + mock_slug_raise_error, + mock_seo_title_raise_error, + ): + """ + Given is_public is False (i.e the page is a non public page). + When no page theme is given. + Then a `ValidationError` is raised. + """ + # Given + fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template() + + fake_covid_topic_page.is_public = False + fake_covid_topic_page.page_classification = "test" + fake_covid_topic_page.theme = None + fake_covid_topic_page.sub_theme = "test" + fake_covid_topic_page.topic = "test" # When/Then - with pytest.raises(ValidationError): + with pytest.raises(ValidationError) as e: fake_covid_topic_page.clean() + assert "Please select a theme for this non-public page" in str(e.value) + + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, + ) + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + def test_public_error_raised_if_invalid_sub_theme( + self, + mock_slug_raise_error, + mock_seo_title_raise_error, + ): + """ + Given is_public is False (i.e the page is a non public page). + When no page sub theme is given. + Then a `ValidationError` is raised. + """ + # Given + fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template() + + fake_covid_topic_page.is_public = False + fake_covid_topic_page.page_classification = "test" + fake_covid_topic_page.theme = "test" + fake_covid_topic_page.sub_theme = None + fake_covid_topic_page.topic = "test" + + # When/Then + with pytest.raises(ValidationError) as e: + fake_covid_topic_page.clean() + + assert "Please select a sub theme for this non-public page" in str(e.value) + + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", + return_value=None, + ) + @mock.patch( + "cms.dashboard.models.UKHSAPage._raise_error_if_slug_not_unique", + return_value=None, + ) + def test_public_error_raised_if_invalid_topic( + self, + mock_slug_raise_error, + mock_seo_title_raise_error, + ): + """ + Given is_public is False (i.e the page is a non public page). + When no page topic is given. + Then a `ValidationError` is raised. + """ + # Given + fake_covid_topic_page = FakeTopicPageFactory.build_covid_19_page_from_template() + + fake_covid_topic_page.is_public = False + fake_covid_topic_page.page_classification = "test" + fake_covid_topic_page.theme = "test" + fake_covid_topic_page.sub_theme = "test" + fake_covid_topic_page.topic = None + + # When/Then + with pytest.raises(ValidationError) as e: + fake_covid_topic_page.clean() + + assert "Please select a topic for this non-public page" in str(e.value) + @mock.patch( "cms.dashboard.models.UKHSAPage._raise_error_if_seo_title_tag_not_provided", return_value=None, @@ -772,6 +980,9 @@ def test_non_public_page_doesnt_clean_page_classification( fake_covid_topic_page.is_public = False fake_covid_topic_page.page_classification = "official" + fake_covid_topic_page.theme = "infectious_disease" + fake_covid_topic_page.sub_theme = "respiratory" + fake_covid_topic_page.topic = "COVID-19" # When fake_covid_topic_page.clean() diff --git a/tests/unit/metrics/api/serializers/test_geographies.py b/tests/unit/metrics/api/serializers/test_geographies.py index fa9d1583d..108a7567e 100644 --- a/tests/unit/metrics/api/serializers/test_geographies.py +++ b/tests/unit/metrics/api/serializers/test_geographies.py @@ -4,7 +4,7 @@ from rest_framework.exceptions import ValidationError -from auth_content.constants import WILDCARD_ID_VALUE +from cms.auth_content.constants import WILDCARD_ID_VALUE from metrics.data.models.core_models.supporting import Geography from validation.geography_code import UNITED_KINGDOM_GEOGRAPHY_CODE from metrics.api.serializers.geographies import ( diff --git a/tests/unit/metrics/api/serializers/test_permission_sets.py b/tests/unit/metrics/api/serializers/test_permission_sets.py index 62ebd2045..7d6e95627 100644 --- a/tests/unit/metrics/api/serializers/test_permission_sets.py +++ b/tests/unit/metrics/api/serializers/test_permission_sets.py @@ -3,7 +3,7 @@ import pytest from rest_framework import serializers as drf_serializers -from auth_content.constants import WILDCARD_ID_VALUE +from cms.auth_content.constants import WILDCARD_ID_VALUE from metrics.api.serializers.permission_sets import ( MetricRequestSerializer, PermissionSetResponseSerializer,