diff --git a/common/auth/cognito_jwt/backend.py b/common/auth/cognito_jwt/backend.py index 1fed42fec..b27462767 100644 --- a/common/auth/cognito_jwt/backend.py +++ b/common/auth/cognito_jwt/backend.py @@ -37,7 +37,11 @@ class JSONWebTokenAuthentication(BaseAuthentication): """ def authenticate(self, request): - """Entrypoint for Django Rest Framework""" + """ + The JWT token has arrived at an API endpoint and the journey starts here. + Entrypoint for the Django Rest Framework. + """ + jwt_token = self.get_jwt_token(request) if jwt_token is None: return None @@ -46,7 +50,8 @@ def authenticate(self, request): try: token_validator = self.get_token_validator(request) jwt_payload = token_validator.validate(jwt_token) - except TokenError: + except TokenError as error: + logger.warning("JWT validation failed: %s", error) raise exceptions.AuthenticationFailed from None custom_user_manager = self.get_custom_user_manager() diff --git a/common/auth/cognito_jwt/user_manager.py b/common/auth/cognito_jwt/user_manager.py index 9d8c522c8..0f2c0bb43 100644 --- a/common/auth/cognito_jwt/user_manager.py +++ b/common/auth/cognito_jwt/user_manager.py @@ -1,8 +1,14 @@ import logging +from typing import TYPE_CHECKING from django.contrib.auth import get_user_model from django.contrib.auth.models import BaseUserManager +from metrics.utils.permission_hierarchy import convert_permission_set_into_hierarchy + +if TYPE_CHECKING: # just for IDE checks + from rest_framework.request import Request + logger = logging.getLogger(__name__) @@ -14,9 +20,23 @@ def get_or_create_for_cognito(jwt_payload): We don't need to store or retrieve any info, we use what's in the JWT, so this speeds up the request by removing the need for any DB access """ + try: username = jwt_payload["entraObjectId"] - permission_sets = jwt_payload["permissionSets"] + raw_permission_sets = jwt_payload["permissionSets"] + + # Manual testing (just for now) + # username = "678a605b-16f3-4342-9f02-db74613701ac" + # raw_permission_sets = { + # "permission_sets": [ + # { + # "theme": {"id": "100", "name": "immunisation"}, + # "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + # "topic": {"id": "-1", "name": "* (All)"}, + # } + # ], + # "summary": {"has_global_access": False}, + # } except KeyError: logger.debug( "Error getting entraObjectId and/or permissionSets field(s)" @@ -25,7 +45,50 @@ def get_or_create_for_cognito(jwt_payload): ) return None + permission_sets = convert_permission_set_into_hierarchy(raw_permission_sets) + permission_count = len(permission_sets.get("permission_set_hierarchy", [])) + has_global_access = bool(permission_sets.get("has_global_access", False)) + + logger.info( + "JWT token for user '%s' with permissions: permission_count=%d, has_global_access=%s", + username, permission_count, has_global_access, + ) + user_class = get_user_model() user = user_class(username=username) user.permission_sets = permission_sets + return user + + +def extract_jwt_permissions(*, request: "Request | None") -> dict: + """ + Extract the normalized JWT permissions dict from an authenticated request. + + Reads `request.user.permission_sets`, which is set by CognitoManager + during JWT authentication. Lives here because it is the counterpart to + the code above that "writes" user.permission_sets. + + @param {Request | None} request, eg: + + + @return {dict}, eg: + { + "permission_set_hierarchy": [ + {"theme": {"id": "100", "name": "immunisation"}, ...} + ], + "has_global_access": False + } + """ + if request is None: + return {} + + user = getattr(request, "user", None) + if user is None: + return {} + + permission_sets = getattr(user, "permission_sets", {}) + if not permission_sets: + return {} + + return permission_sets diff --git a/metrics/api/views/charts/single_category_charts.py b/metrics/api/views/charts/single_category_charts.py index 2767c3982..43fbe8623 100644 --- a/metrics/api/views/charts/single_category_charts.py +++ b/metrics/api/views/charts/single_category_charts.py @@ -4,11 +4,13 @@ from django.http import FileResponse from drf_spectacular.utils import OpenApiExample, extend_schema from rest_framework import permissions +from rest_framework.authentication import SessionAuthentication from rest_framework.response import Response from rest_framework.views import APIView import config from caching.private_api.decorators import cache_response +from common.auth.cognito_jwt import JSONWebTokenAuthentication from metrics.api.decorators.auth import require_authorisation from metrics.api.enums import AppMode from metrics.api.serializers import ChartsSerializer @@ -218,6 +220,7 @@ def post(cls, request, *args, **kwargs): class EncodedChartsView(APIView): + authentication_classes = [SessionAuthentication, JSONWebTokenAuthentication] permission_classes = [] @classmethod diff --git a/metrics/data/managers/core_models/headline.py b/metrics/data/managers/core_models/headline.py index f33bf913c..14ecd01c5 100644 --- a/metrics/data/managers/core_models/headline.py +++ b/metrics/data/managers/core_models/headline.py @@ -15,6 +15,9 @@ from metrics.api.permissions.fluent_permissions import ( validate_permissions_for_non_public, ) +from metrics.utils.permissions import ( + check_any_permissions_allow_access, +) class CoreHeadlineQuerySet(models.QuerySet): @@ -334,6 +337,7 @@ def query_for_data( theme: str = "", sub_theme: str = "", rbac_permissions: Iterable["RBACPermission"] | None = None, + jwt_permissions: dict | None = None, **kwargs, ): """Filters for a N-item list of dicts by the given params if `fields_to_export` is used. @@ -373,6 +377,10 @@ def query_for_data( rbac_permissions: The RBAC permissions available to the given request. This dictates whether the given request is permitted access to non-public data or not. + The new JWT-based authorization below takes precedence + over RBAC permissions, which is not in use anymore. + jwt_permissions: The JWT permissions extracted from the Cognito token. + Contains 'permission_set_hierarchy' (list) and 'has_global_access' (bool). Returns: Queryset of (x_axis, y_axis) where x_axis represents the variable on the x_axis @@ -382,16 +390,38 @@ def query_for_data( Examples: """ + rbac_permissions = rbac_permissions or [] - has_access_to_non_public_data: bool = validate_permissions_for_non_public( - theme=theme, - sub_theme=sub_theme, - topic=topic, - metric=metric, - geography=geography, - geography_type=geography_type, - rbac_permissions=rbac_permissions, - ) + + has_access_to_non_public_data: bool + + if jwt_permissions: + # Check JWT permissions first (new authorization takes precedence) + has_global_access = jwt_permissions.get("has_global_access", False) + + if has_global_access: + has_access_to_non_public_data = True + else: + has_access_to_non_public_data = check_any_permissions_allow_access( + jwt_permissions=jwt_permissions, + theme=theme, + sub_theme=sub_theme, + topic=topic, + metric=metric, + geography_type=geography_type, + geography=geography, + ) + else: + # Legacy RBAC permissions (not in use) (to be removed in a future release) + has_access_to_non_public_data = validate_permissions_for_non_public( + theme=theme, + sub_theme=sub_theme, + topic=topic, + metric=metric, + geography=geography, + geography_type=geography_type, + rbac_permissions=rbac_permissions, + ) if has_access_to_non_public_data: queryset = self.get_queryset().get_all_headlines_released_from_embargo( diff --git a/metrics/data/managers/core_models/time_series.py b/metrics/data/managers/core_models/time_series.py index aed8df7f6..2578429f6 100644 --- a/metrics/data/managers/core_models/time_series.py +++ b/metrics/data/managers/core_models/time_series.py @@ -18,6 +18,9 @@ validate_permissions_for_non_public, ) from metrics.data.models import RBACPermission +from metrics.utils.permissions import ( + check_any_permissions_allow_access, +) ALLOWABLE_METRIC_VALUE_RANGE_TYPE = tuple[str | float | int, str | float | int] @@ -533,6 +536,7 @@ def query_for_data( sub_theme: str = "", metric_value_ranges: list[str | float | int] | None = None, rbac_permissions: Iterable[RBACPermission] | None = None, + jwt_permissions: dict | None = None, ) -> CoreTimeSeriesQuerySet: """Filters for a 2-item object by the given params. Slices all values older than the `date_from`. @@ -579,11 +583,14 @@ def query_for_data( i.e. to filter for all record with values between 0 -> 80 AND 90 -> 100, this can be provided as `[(0, 80), (90, 100)]`. - rbac_permissions: The RBAC permissions available - to the given request. This dictates whether the given - request is permitted access to non-public data or not. - - Notes: + rbac_permissions: The RBAC permissions available + to the given request. This dictates whether the given + request is permitted access to non-public data or not. + jwt_permissions: JWT permissions dict extracted from Cognito token. + Contains 'has_global_access' (bool) and 'permission_set_hierarchy' (list). + Used for new JWT-based authorization (takes precedence over RBAC permissions). + + Notes: If we have the following input `queryset`: ---------------------------------------- | 2023-01-01 | 2023-01-02 | 2023-01-03 | @@ -611,16 +618,38 @@ def query_for_data( ]>` """ + rbac_permissions: Iterable[RBACPermission] = rbac_permissions or [] - has_access_to_non_public_data: bool = validate_permissions_for_non_public( - theme=theme, - sub_theme=sub_theme, - topic=topic, - metric=metric, - geography_type=geography_type, - geography=geography, - rbac_permissions=rbac_permissions, - ) + + has_access_to_non_public_data: bool + + if jwt_permissions: + # Check JWT permissions first (new authorization takes precedence) + has_global_access = jwt_permissions.get("has_global_access", False) + + if has_global_access: + has_access_to_non_public_data = True + else: + has_access_to_non_public_data = check_any_permissions_allow_access( + jwt_permissions=jwt_permissions, + theme=theme, + sub_theme=sub_theme, + topic=topic, + metric=metric, + geography_type=geography_type, + geography=geography, + ) + else: + # Legacy RBAC permissions (not in use) (to be removed in a future release) + has_access_to_non_public_data = validate_permissions_for_non_public( + theme=theme, + sub_theme=sub_theme, + topic=topic, + metric=metric, + geography_type=geography_type, + geography=geography, + rbac_permissions=rbac_permissions, + ) return self.get_queryset().query_for_data( fields_to_export=fields_to_export, diff --git a/metrics/interfaces/plots/access.py b/metrics/interfaces/plots/access.py index 6e4eac34d..35ddfaf93 100644 --- a/metrics/interfaces/plots/access.py +++ b/metrics/interfaces/plots/access.py @@ -7,6 +7,7 @@ from django.db.models import Manager, QuerySet from pydantic import BaseModel +from common.auth.cognito_jwt.user_manager import extract_jwt_permissions from metrics.api.settings import auth from metrics.data.models.core_models import CoreTimeSeries, Topic from metrics.domain.common.utils import ChartAxisFields @@ -162,7 +163,9 @@ def get_queryset_from_core_model_manager( plot_params["fields_to_export"].append("lower_confidence") return self.core_model_manager.query_for_data( - **plot_params, rbac_permissions=self.chart_request_params.rbac_permissions + **plot_params, + rbac_permissions=self.chart_request_params.rbac_permissions, # legacy permissions to be removed + jwt_permissions=extract_jwt_permissions(request=self.chart_request_params.request), # new permissions ) def build_plot_data_from_parameters_with_complete_queryset( diff --git a/metrics/utils/permission_hierarchy.py b/metrics/utils/permission_hierarchy.py index c5a182204..cefed1e78 100644 --- a/metrics/utils/permission_hierarchy.py +++ b/metrics/utils/permission_hierarchy.py @@ -21,6 +21,54 @@ ) +def convert_permission_set_into_hierarchy(raw_permission_sets: dict) -> dict: + """ + Convert a "permission_set" back into a "permission_set_hierarchy" again + (the NormalizedPermission class does it the other way round) + + @param {dict} raw_permission_sets, eg: + { + "permission_sets": [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "215", "name": "MMR1"}, + } + ], + "summary": { + "has_global_access": False + }, + } + + @return {dict}, eg: + { + "permission_set_hierarchy": [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "215", "name": "MMR1"}, + } + ], + "has_global_access": False, + } + """ + + permission_set_hierarchy = raw_permission_sets.get("permission_set_hierarchy") + if permission_set_hierarchy is None: + permission_set_hierarchy = raw_permission_sets.get("permission_sets", []) + + has_global_access = raw_permission_sets.get("has_global_access") + if has_global_access is None: + has_global_access = raw_permission_sets.get("summary", {}).get( + "has_global_access", False + ) + + return { + "permission_set_hierarchy": permission_set_hierarchy, + "has_global_access": bool(has_global_access), + } + + @dataclass class NormalizedPermission: """ diff --git a/metrics/utils/permissions.py b/metrics/utils/permissions.py new file mode 100644 index 000000000..6c5329dd0 --- /dev/null +++ b/metrics/utils/permissions.py @@ -0,0 +1,292 @@ +""" +Non-public permission validation, filtering and matching functions +""" + +from typing import Literal, NotRequired, TypedDict + +from metrics.data.models.core_models.supporting import ( + Geography, + GeographyType, + Metric, + Topic, +) + +# Our permission resources, e.g. anything that you can filter by permission +PermissionFilterResource = Literal[ + "theme", + "sub_theme", + "topic", + "metric", + "geography_type", + "geography", +] + + +class PermissionSetLevel(TypedDict): + """ + The "id" is the actual permission. + The "name" is just some blurb describing it. + """ + + id: str # can be "-1" and therefore a string + name: str + + +class PermissionHierarchy(TypedDict): + """ + Our permission resources, e.g. anything that you can filter by permission. + Each one of them is optional, but at least one of them has to be provided. + """ + + theme: NotRequired[PermissionSetLevel] + sub_theme: NotRequired[PermissionSetLevel] + topic: NotRequired[PermissionSetLevel] + metric: NotRequired[PermissionSetLevel] + geography_type: NotRequired[PermissionSetLevel] + geography: NotRequired[PermissionSetLevel] + + +def check_any_permissions_allow_access( + *, + jwt_permissions: dict, + theme: str = "", + sub_theme: str = "", + topic: str, + metric: str, + geography_type: str, + geography: str, +) -> bool: + """ + This is our CORE PERMISSION-CHECKING function. + + Resolve request names to IDs and return whether any + of the passed API filters are being satisfied through + the user's permissions defined in the jwt_permissions. + + Any of the 6 of the passed API filters are optional. If + none of them is passed, return false, cause at this point + we already know user has_global_access=False + + @param {dict} jwt_permissions, eg: + { + "permission_set_hierarchy": [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + } + ] + } + + @param {str} theme, eg: + "immunisation" + + @param {str} sub_theme, eg: + "childhood-vaccines" + + @param {str} topic, eg: + "MMR1" + + @return {bool} + """ + + topic_record = None + if topic: + if theme and sub_theme: + topic_record = Topic.objects.filter( + name=topic, + sub_theme__name=sub_theme, + sub_theme__theme__name=theme, + ).first() + if topic_record is None: + topic_record = Topic.objects.filter(name=topic).first() + + metric_record = None + if metric: + metric_record = Metric.objects.filter(name=metric).first() + + geography_type_record = None + if geography_type: + geography_type_record = GeographyType.objects.filter( + name=geography_type + ).first() + + geography_record = None + if geography: + geography_queryset = Geography.objects.filter(name=geography) + if geography_type_record: + geography_queryset = geography_queryset.filter( + geography_type_id=geography_type_record.id + ) + geography_record = geography_queryset.first() + + requested_filters: dict[PermissionFilterResource, str | None] = { + "theme": str(topic_record.sub_theme.theme_id) if topic_record else None, + "sub_theme": str(topic_record.sub_theme_id) if topic_record else None, + "topic": str(topic_record.id) if topic_record else None, + "metric": str(metric_record.id) if metric_record else None, + "geography_type": ( + str(geography_type_record.id) if geography_type_record else None + ), + "geography": ( + str(geography_record.geography_code) + if geography_record and geography_record.geography_code + else None + ), + } + + # Don't pass on any rubbish + normalized_requested_filters: dict[PermissionFilterResource, str] = { + key: str(value) + for key, value in requested_filters.items() + if _is_permission_value_valid(value) + } + + matching_permissions = filter_permissions( + jwt_permissions=jwt_permissions, + requested_filters=normalized_requested_filters, + ) + + return bool(matching_permissions) + + +def filter_permissions( + *, + jwt_permissions: dict, + requested_filters: dict[PermissionFilterResource, str], +) -> list[dict[str, str]]: + """ + Filter permissions that match the requested IDs. + Returns a list of matching permissions containing IDs (without wildcards "-1"). + + @param {dict} jwt_permissions, eg: + { + "permission_set_hierarchy": [ + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + } + ] + } + + @param {dict} requested_filters, eg: + {"theme": "100", "sub_theme": "200", "topic": "300"} + + @return {list}, eg: + [{"theme": "100", "sub_theme": "200"}] + """ + + # Gotta have both of them to match permissions, cause at + # this point we already know user has_global_access=False + permission_set_hierarchy = jwt_permissions.get( + "permission_set_hierarchy" + ) + if not isinstance(permission_set_hierarchy, list): + return [] + if not requested_filters: + return [] + + matching_permissions: list[dict[str, str]] = [] + + for permission in permission_set_hierarchy: + if not isinstance(permission, dict) or not permission: + continue + + concrete_filters: dict[str, str] = {} + + # All the requested filters must be present in the permission row or "-1" + for filter_key, requested_value in requested_filters.items(): + permission_value = _extract_permission_id( + permission=permission, + filter_key=filter_key, + ) + + if permission_value == requested_value: + # Exact permission match grants access + concrete_filters[filter_key] = requested_value + continue + elif permission_value == "-1": + # Wildcard also grants access + concrete_filters[filter_key] = requested_value + continue + + # This permission row did not satisfy all the requested filters + break + else: + # Only runs if loop didn't break -> full match + matching_permissions.append(concrete_filters) + + return matching_permissions + + +def _extract_permission_id( + *, + permission: dict[str, object], + filter_key: PermissionFilterResource +) -> str | None: + """ + Extract a permission id for one filter key from a permission entry. + + @param {dict} permission, eg: + { + "theme": {"id": "100", "name": "immunisation"}, + "sub_theme": {"id": "200", "name": "childhood-vaccines"}, + "topic": {"id": "-1", "name": "* (All)"}, + } + + @param {str} filter_key, eg: + "sub_theme" + + @return {str | None}, eg: + "200" + """ + + # Do we offer permissions on this resource? + if filter_key not in permission: + return None + + # Permission resource present? + permission_resource = permission.get(filter_key) + if not permission_resource: + return None + + # Permission resource malformed? + if not isinstance(permission_resource, dict): + return None + + # Permission value meaningful? + permission_id = permission_resource.get("id") + if not _is_permission_value_valid(permission_id): + return None + + # If all is fine up to here, return the ID + return str(permission_id) + + +def _is_permission_value_valid(value: object) -> bool: + """ + Is value a valid permission ID? + + A valid permission value is an integer-like value except 0. + Integers that exist as string types are fine, eg "5" instead of 5. + The wildcard "-1" remains also valid. + """ + + if value is None: + return False + + if isinstance(value, bool): + return False + + value_as_text = str(value).strip() + + if not value_as_text: + return False + + try: + numeric_value = int(value_as_text) + except (TypeError, ValueError): + return False + + return numeric_value != 0