# Sentinel returned by the managers' ``get_id_by_name()`` lookups when no
# record matches the given name.
_UNKNOWN_ID = -2


def _permitted_id(permission_set, key):
    """Return ``permission_set[key]["id"]`` as a string, or ``None``.

    ``None`` is returned when the entry is missing, is not a dict, or has
    no ``id``. Callers treat ``None`` as "grants nothing", so a malformed
    permission set can never match a request by accident.
    """
    entry = permission_set.get(key)
    if not isinstance(entry, dict):
        return None

    raw_id = entry.get("id")
    return None if raw_id is None else str(raw_id)


def permission_hierarchy_allows(permitted_ids, requested_ids, wildcard) -> bool:
    """Walk one permission hierarchy from its broadest to narrowest level.

    At each level the permitted id must either be the wildcard (which grants
    everything beneath it, so the walk stops with ``True``) or equal the
    requested id (so the walk continues to the next level). Ids are compared
    as strings, so ``100`` and ``"100"`` are considered equal.

    Args:
        permitted_ids: Sequence of ids granted by a single permission set,
            ordered broadest first. ``None`` marks a missing id.
        requested_ids: Sequence of ids of the requested resource,
            in the same order and of the same length.
        wildcard: The id value which means "all".

    Returns:
        True if the permitted ids cover the requested ids, False otherwise.
        Empty or mismatched-length inputs are rejected (fail closed).
    """
    if not permitted_ids or len(permitted_ids) != len(requested_ids):
        return False

    wildcard = str(wildcard)

    for permitted, requested in zip(permitted_ids, requested_ids):
        if permitted is None:
            return False
        if str(permitted) == wildcard:
            return True
        if requested is None or str(permitted) != str(requested):
            return False

    # Every level matched exactly
    return True


def check_permissions_by_name(
    permission_sets,
    theme_name,
    sub_theme_name,
    topic_name,
    metric_name,
    geography_type,
    geography_name,
) -> bool:
    """
    This is a wrapper that converts permission resource names
    into ids. It is only used to check CHART permissions.

    Args:
        permission_sets: The dict holding the `permission_sets` list
            and the `summary`, see `check_permission_set()`.
        theme_name: The name of the theme
        sub_theme_name: The name of the sub theme
        topic_name: The name of the topic
        metric_name: The name of the metric
        geography_type: The NAME of the geography type
        geography_name: The name of the geography

    Returns:
        True if access is permitted, False otherwise.
        False is also returned when any name cannot be resolved to an ID.
    """

    logger.debug("Entered check_permissions_by_name()")

    theme_id, sub_theme_id, topic_id = Topic.objects.get_id_by_name(
        theme_name, sub_theme_name, topic_name
    )
    metric_id = Metric.objects.get_id_by_name(metric_name)
    geography_type_id = GeographyType.objects.get_id_by_name(geography_type)
    # NOTE(review): the geography is resolved by name alone - confirm that
    # geography names are unique across geography types.
    geography_id = Geography.objects.get_id_by_name(geography_name)

    resolved_ids = (
        theme_id,
        sub_theme_id,
        topic_id,
        metric_id,
        geography_type_id,
        geography_id,
    )

    # Be safe, just in case a NAME doesn't have an ID
    if _UNKNOWN_ID in resolved_ids:
        return False

    return check_permission_set(permission_sets, *resolved_ids)


def check_permission_set(
    permission_sets,
    theme_id,
    sub_theme_id,
    topic_id,
    metric_id,
    geography_type,
    geography_id,
) -> bool:
    """
    This is a wrapper that only checks for global permissions, and
    delegates further checks to our core permission checking function.
    It is only used to check CHART permissions.

    Note that `geography_type` is the ID of the geography type here.

    @param {dict} permission_sets which contains a permission_sets list, eg:
        {
            "permission_sets": [
                {
                    "theme": {"id": "100", "name": "immunisation"},
                    "sub_theme": {"id": "200", "name": "childhood-vaccines"},
                    "topic": {"id": "-1", "name": "* (All)"},
                    "metric": {"id": "-1", "name": "* (All)"},
                    "geography_type": {"id": "300", "name": "Nation"},
                    "geography": {"id": "-1", "name": "* (All)"},
                }
            ],
            "summary": {"has_global_access": False},
        }

    Returns:
        True if the payload grants global access or one of its permission
        sets covers the requested resource. False otherwise, including
        when the payload does not have the expected shape.
    """

    logger.debug("Entered check_permission_set()")

    if not isinstance(permission_sets, dict):
        return False

    granted_sets = permission_sets.get("permission_sets")
    summary = permission_sets.get("summary")
    if not isinstance(granted_sets, list) or not isinstance(summary, dict):
        return False

    # Must be a real boolean, a truthy string like "false" is rejected
    has_global_access = summary.get("has_global_access")
    if not isinstance(has_global_access, bool):
        return False

    if has_global_access:
        return True

    return check_permissions(
        granted_sets,
        theme_id,
        sub_theme_id,
        topic_id,
        metric_id,
        geography_type,
        geography_id,
    )


def check_permissions(
    permission_sets,
    theme_id,
    sub_theme_id,
    topic_id,
    metric_id=None,
    geography_type=None,
    geography_id=None,
) -> bool:
    """
    This is our core permission-checking function. It is
    used to check both PAGE & CHART permissions.

    Metric- and geography-related permissions must be
    evaluated separately (spec says), but both must be
    granted by the SAME permission set for a CHART.

    A request is treated as a CHART check as soon as either
    `geography_type` or `geography_id` is supplied, so that a
    half-supplied geography fails closed rather than silently
    skipping the geography check.

    @param {list} permission_sets, eg:
        [
            {
                "theme": {"id": "100", "name": "immunisation"},
                "sub_theme": {"id": "200", "name": "childhood-vaccines"},
                "topic": {"id": "-1", "name": "* (All)"},
                "metric": {"id": "-1", "name": "* (All)"},
                "geography_type": {"id": "300", "name": "Nation"},
                "geography": {"id": "-1", "name": "* (All)"},
            }
        ]

    Returns:
        True if any single permission set grants access, False otherwise.
    """

    logger.debug("Entered check_permissions()")

    if not isinstance(permission_sets, list):
        return False

    is_chart_check = geography_type is not None or geography_id is not None

    for permission_set in permission_sets:
        if not check_metric_related_permissions(
            permission_set, theme_id, sub_theme_id, topic_id, metric_id
        ):
            continue

        # PAGE permissions stop here, CHART permissions also need geography
        if not is_chart_check or check_geography_permissions(
            permission_set, geography_type, geography_id
        ):
            return True

    return False


def check_metric_related_permissions(
    permission_set,
    theme_id,
    sub_theme_id,
    topic_id,
    metric_id=None,
) -> bool:
    """
    Make sure that every theme, sub_theme, topic and metric
    match or have a wildcard at the end (only look at the
    first 4 attributes of permission_set).

    When `metric_id` is None (PAGE checks) the metric level is
    not evaluated, so an exact topic match is sufficient.
    When `metric_id` is given (CHART checks) an exact topic match
    must also be backed by a matching or wildcard metric.

    @param {dict} permission_set, eg:
        {
            "theme": {"id": "100", "name": "immunisation"},
            "sub_theme": {"id": "200", "name": "childhood-vaccines"},
            "topic": {"id": "-1", "name": "* (All)"},
            "metric": {"id": "-1", "name": "* (All)"},
            "geography_type": {"id": "300", "name": "Nation"},
            "geography": {"id": "-1", "name": "* (All)"},
        }

    Returns:
        True if the permission set covers the requested resource.
    """

    logger.debug("Entered check_metric_related_permissions()")

    if not isinstance(permission_set, dict):
        return False

    levels = ["theme", "sub_theme", "topic"]
    requested_ids = [theme_id, sub_theme_id, topic_id]

    if metric_id is not None:
        levels.append("metric")
        requested_ids.append(metric_id)

    permitted_ids = [_permitted_id(permission_set, level) for level in levels]

    return permission_hierarchy_allows(
        permitted_ids, requested_ids, WILDCARD_ID_VALUE
    )


def check_geography_permissions(
    permission_set,
    geography_type=None,
    geography_id=None,
) -> bool:
    """
    Make sure that both geography_type and geography
    match or have a wildcard at the end (only look at the
    last 2 attributes of permission_set).

    Note that `geography_type` is the ID of the geography type here.

    @param {dict} permission_set, eg:
        {
            "theme": {"id": "100", "name": "immunisation"},
            "sub_theme": {"id": "200", "name": "childhood-vaccines"},
            "topic": {"id": "-1", "name": "* (All)"},
            "metric": {"id": "-1", "name": "* (All)"},
            "geography_type": {"id": "300", "name": "Nation"},
            "geography": {"id": "-1", "name": "* (All)"},
        }

    Returns:
        True if the permission set covers the requested geography.
    """

    logger.debug("Entered check_geography_permissions()")

    if not isinstance(permission_set, dict):
        return False

    permitted_ids = [
        _permitted_id(permission_set, "geography_type"),
        _permitted_id(permission_set, "geography"),
    ]

    return permission_hierarchy_allows(
        permitted_ids, [geography_type, geography_id], WILDCARD_ID_VALUE
    )
def _print_sql(execute, sql, params, many, context):
    """
    Database execute-wrapper which prints a statement before running it.

    The arguments are passed straight through to `execute`, and the
    result of `execute` is returned unchanged. The params are only
    printed when there are any.
    """
    print(f"\n[SQL] {sql}")
    if params:
        print(f"[PARAMS] {params}")
    return execute(sql, params, many, context)


class SQLDebugMiddleware:
    """
    Middleware that prints the raw SQL and params for every DB query made
    during a request/response cycle, on every configured database
    connection (not just the default one).

    Only intended for local development — add to MIDDLEWARE in local.py.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Imported here so that this dev-only helper adds no import-time
        # cost or dependencies to the module.
        from contextlib import ExitStack

        from django.db import connections

        with ExitStack() as stack:
            # Each connection keeps its own list of execute wrappers,
            # so every alias has to be instrumented individually.
            for db_connection in connections.all():
                stack.enter_context(db_connection.execute_wrapper(_print_sql))
            return self.get_response(request)
def get_id_by_name(self, geography_type_name: str) -> int:
    """
    Gets the geography type ID for a given geography type name.

    Only the `id` column is fetched, rather than the whole record.

    Args:
        geography_type_name: The name of the geography type to look up

    Returns:
        The geography type ID if found, or -2 otherwise
    """
    matched_id = (
        self.filter(name=geography_type_name).values_list("id", flat=True).first()
    )
    # Compare against None explicitly so that an ID of 0 is still returned
    return -2 if matched_id is None else int(matched_id)
def get_id_by_name(self, metric_name: str) -> int:
    """
    Gets the metric ID for a given metric name.

    Only the `id` column is fetched, rather than the whole record.

    Args:
        metric_name: The name of the metric to look up

    Returns:
        The metric ID if found, or -2 otherwise
    """
    matched_id = self.filter(name=metric_name).values_list("id", flat=True).first()
    # Compare against None explicitly so that an ID of 0 is still returned
    return -2 if matched_id is None else int(matched_id)
diff --git a/metrics/data/managers/core_models/time_series.py b/metrics/data/managers/core_models/time_series.py index aed8df7f6..5b63aaca2 100644 --- a/metrics/data/managers/core_models/time_series.py +++ b/metrics/data/managers/core_models/time_series.py @@ -6,6 +6,7 @@ """ import datetime +import logging from collections.abc import Iterable from typing import Self @@ -15,12 +16,13 @@ from metrics.api.permissions.fluent_permissions import ( is_public_data_only_enforced, - validate_permissions_for_non_public, ) from metrics.data.models import RBACPermission ALLOWABLE_METRIC_VALUE_RANGE_TYPE = tuple[str | float | int, str | float | int] +logger = logging.getLogger(__name__) + class CoreTimeSeriesQuerySet(models.QuerySet): """Custom queryset which can be used by the `CoreTimeSeriesManager`""" @@ -171,8 +173,10 @@ def query_for_data( stratum: str | None = None, sex: str | None = None, age: str | None = None, + theme: str, + sub_theme: str, metric_value_ranges: list[tuple[str | float | int]] | None = None, - restrict_to_public: bool = True, + permission_sets: dict, ) -> models.QuerySet: """Filters for a N-item list of dicts by the given params if `fields_to_export` is used. @@ -212,14 +216,18 @@ def query_for_data( Note that options are `M`, `F`, or `ALL`. age: The age range to apply additional filtering to. E.g. `0_4` would be used to capture the age of 0-4 years old + theme: The name of the theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. + sub_theme: The name of the sub theme being queried. + This is only used to determine permissions for + the non-public portion of the requested dataset. metric_value_ranges: List of tuples whereby each tuple represents a permissible metric value range. i.e. to filter for all record with values between 0 -> 80 AND 90 -> 100, this can be provided as `[(0, 80), (90, 100)]`. 
- restrict_to_public: Boolean switch to restrict the query - to only return public records. - If False, then non-public records will be included. + permission_sets: The JWT permissions extracted from the Cognito token. Returns: QuerySet: An ordered queryset from lowest -> highest @@ -231,6 +239,9 @@ def query_for_data( ]>` """ + + logger.info("Entered query_for_data()") + queryset = self.filter( metric__topic__name=topic, metric__name=metric, @@ -245,9 +256,30 @@ def query_for_data( sex=sex, age=age, ) + public_queryset = queryset.filter(is_public=True) - if restrict_to_public: - queryset = queryset.filter(is_public=True) + if permission_sets: + logger.info("Entered if permission_sets clause") + + # WORKAROUND: Cos circular import error when at the top of the file + from cms.auth_content.auth_utils import check_permissions_by_name + + if check_permissions_by_name( + permission_sets, + theme, + sub_theme, + topic, + metric, + geography_type, + geography, + ): + logger.info("Entered check_permissions_by_name() if clause") + + queryset = public_queryset + queryset.filter(is_public=False) + else: + logger.info("Entered else permission_sets clause") + + queryset = public_queryset queryset = self._exclude_data_under_embargo(queryset=queryset) queryset = self._filter_for_metric_value_ranges( @@ -533,6 +565,7 @@ def query_for_data( sub_theme: str = "", metric_value_ranges: list[str | float | int] | None = None, rbac_permissions: Iterable[RBACPermission] | None = None, + permission_sets: dict, ) -> CoreTimeSeriesQuerySet: """Filters for a 2-item object by the given params. Slices all values older than the `date_from`. @@ -582,6 +615,7 @@ def query_for_data( rbac_permissions: The RBAC permissions available to the given request. This dictates whether the given request is permitted access to non-public data or not. + permission_sets: The JWT permissions extracted from the Cognito token. 
def get_id_by_name(
    self, theme_name: str, sub_theme_name: str, topic_name: str
) -> tuple[int, int, int]:
    """
    Gets the theme, sub-theme and topic IDs matching the given names.

    All 3 IDs are fetched in a single query, so the related
    sub-theme record does not need to be loaded separately.

    Args:
        theme_name: The name of the parent theme
        sub_theme_name: The name of the parent sub-theme
        topic_name: The name of the topic to look up

    Returns:
        A tuple of (theme_id, sub_theme_id, topic_id) if found,
        or the tuple (-2, -2, -2) otherwise
    """
    id_row = (
        self.filter(
            sub_theme__theme__name=theme_name,
            sub_theme__name=sub_theme_name,
            name=topic_name,
        )
        .values_list("sub_theme__theme_id", "sub_theme_id", "id")
        .first()
    )

    if id_row is None:
        return (-2, -2, -2)

    theme_id, sub_theme_id, topic_id = id_row
    return (int(theme_id), int(sub_theme_id), int(topic_id))
diff --git a/metrics/domain/models/charts/common.py b/metrics/domain/models/charts/common.py index 317301450..7608749d3 100644 --- a/metrics/domain/models/charts/common.py +++ b/metrics/domain/models/charts/common.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Iterable from decimal import Decimal from typing import Literal @@ -5,6 +6,8 @@ from pydantic.main import BaseModel from rest_framework.request import Request +logger = logging.getLogger(__name__) + class BaseChartRequestParams(BaseModel): file_format: Literal["png", "svg", "jpg", "jpeg", "json", "csv"] @@ -24,6 +27,16 @@ class BaseChartRequestParams(BaseModel): class Config: arbitrary_types_allowed = True + @property + def permission_sets(self) -> dict: + """Extract JWT permissions from the authenticated request""" + + logger.info("Entered BaseChartRequestParams.permission_sets") + + return getattr(self.request.user, "permission_sets", {}) + @property def rbac_permissions(self) -> Iterable["RBACPermission"]: + """TODO: RBAC-based permissions are legacy and will be removed in a future release""" + return getattr(self.request, "rbac_permissions", []) diff --git a/metrics/interfaces/plots/access.py b/metrics/interfaces/plots/access.py index 6e4eac34d..e8bcb0406 100644 --- a/metrics/interfaces/plots/access.py +++ b/metrics/interfaces/plots/access.py @@ -161,8 +161,12 @@ def get_queryset_from_core_model_manager( plot_params["fields_to_export"].append("upper_confidence") plot_params["fields_to_export"].append("lower_confidence") + logger.info("Entered access.py") + return self.core_model_manager.query_for_data( - **plot_params, rbac_permissions=self.chart_request_params.rbac_permissions + **plot_params, + rbac_permissions=self.chart_request_params.rbac_permissions, # old permissions (remove) + permission_sets=self.chart_request_params.permission_sets, # new permissions ) def build_plot_data_from_parameters_with_complete_queryset(