From 1f62a1d3b3758815daa69c7de183dd618652ae5c Mon Sep 17 00:00:00 2001 From: SWhyteAnswer Date: Wed, 25 Mar 2026 15:21:09 +0000 Subject: [PATCH] [PRMP-1465] Implement user restriction creation functionality and associated tests --- .../base-lambdas-reusable-deploy-all.yml | 14 + lambdas/enums/lambda_error.py | 53 +- lambdas/enums/logging_app_interaction.py | 1 + .../create_user_restriction_handler.py | 125 +++++ lambdas/models/user_restrictions/__init__.py | 0 lambdas/services/authoriser_service.py | 6 + .../create_user_restriction_service.py | 74 +++ .../search_user_restriction_service.py | 2 +- .../user_restriction_dynamo_service.py | 64 ++- lambdas/tests/unit/conftest.py | 8 +- lambdas/tests/unit/handlers/conftest.py | 75 +++ .../test_create_user_restriction_handler.py | 495 ++++++++++++++++++ .../test_user_restrictions.py | 68 +-- .../test_create_user_restriction_service.py | 161 ++++++ .../test_search_user_restriction_service.py | 4 +- .../test_user_restriction_dynamo_service.py | 80 ++- lambdas/tests/unit/utils/test_ods_utils.py | 49 +- lambdas/utils/exceptions.py | 12 +- lambdas/utils/lambda_exceptions.py | 4 + 19 files changed, 1208 insertions(+), 87 deletions(-) create mode 100644 lambdas/handlers/user_restrictions/create_user_restriction_handler.py create mode 100644 lambdas/models/user_restrictions/__init__.py create mode 100644 lambdas/services/user_restrictions/create_user_restriction_service.py create mode 100644 lambdas/tests/unit/handlers/test_create_user_restriction_handler.py create mode 100644 lambdas/tests/unit/services/test_create_user_restriction_service.py diff --git a/.github/workflows/base-lambdas-reusable-deploy-all.yml b/.github/workflows/base-lambdas-reusable-deploy-all.yml index dc55354310..4d0c2c81d1 100644 --- a/.github/workflows/base-lambdas-reusable-deploy-all.yml +++ b/.github/workflows/base-lambdas-reusable-deploy-all.yml @@ -940,3 +940,17 @@ jobs: secrets: AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }} + deploy_create_user_restriction_lambda: + name: Deploy Create User Restriction Lambda + uses: ./.github/workflows/base-lambdas-reusable-deploy.yml + with: + environment: ${{ inputs.environment }} + python_version: ${{ inputs.python_version }} + build_branch: ${{ inputs.build_branch }} + sandbox: ${{ inputs.sandbox }} + lambda_handler_path: user_restrictions + lambda_handler_name: create_user_restriction_handler + lambda_aws_name: CreateUserRestriction + lambda_layer_names: "core_lambda_layer" + secrets: + AWS_ASSUME_ROLE: ${{ secrets.AWS_ASSUME_ROLE }} \ No newline at end of file diff --git a/lambdas/enums/lambda_error.py b/lambdas/enums/lambda_error.py index 5a6e4f2633..33d311e54b 100644 --- a/lambdas/enums/lambda_error.py +++ b/lambdas/enums/lambda_error.py @@ -14,6 +14,7 @@ class ErrorMessage(StrEnum): FAILED_TO_VALIDATE = "Failed to validate data" FAILED_TO_UPDATE_DYNAMO = "Failed to update DynamoDB" FAILED_TO_CREATE_TRANSACTION = "Failed to create transaction" + INVALID_REQUEST_BODY = "Invalid request body" class LambdaError(Enum): @@ -503,7 +504,7 @@ def create_error_body( } UpdateUploadStateInvalidBody = { "err_code": "US_4005", - "message": "Invalid request body", + "message": ErrorMessage.INVALID_REQUEST_BODY, } UpdateUploadStateFieldType = { "err_code": "US_4006", @@ -565,7 +566,7 @@ def create_error_body( DocumentReviewInvalidBody = { "err_code": "DRV_4004", - "message": "Invalid request body", + "message": ErrorMessage.INVALID_REQUEST_BODY, } DocumentReviewInvalidNhsNumber = { @@ -755,6 +756,54 @@ def create_error_body( "message": "Failed to parse SQS event", } + """ + Errors for UserRestriction lambda + """ + CreateRestrictionMissingBody = { + "err_code": "UR_4001", + "message": "Missing request body", + } + CreateRestrictionMissingFields = { + "err_code": "UR_4002", + "message": "Missing required fields", + } + CreateRestrictionPatientIdMismatch = { + "err_code": "UR_4003", + "message": "patientId does not match nhs_number", + } + CreateRestrictionMissingContext = { + "err_code": "UR_4004", + "message": "Missing user context", + } + CreateRestrictionInvalidWorker = { + "err_code": "UR_4005", + "message": "Invalid Worker", + } + CreateRestrictionPractitionerModelError = { + "err_code": "UR_4006", + "message": "Unable to process restricted user information", + } + CreateRestrictionSelfRestriction = { + "err_code": "UR_4007", + "message": "You cannot create a restriction for yourself", + } + CreateRestrictionAlreadyExists = { + "err_code": "UR_4009", + "message": "A restriction already exists for this user and patient", + } + CreateRestrictionInvalidBody = { + "err_code": "UR_4008", + "message": ErrorMessage.INVALID_REQUEST_BODY, + } + CreateRestrictionPatientOdsMismatch = { + "err_code": "UR_4010", + "message": "Patient's general practice ODS does not match request context ODS", + } + CreateRestrictionPatientNotFound = { + "err_code": "UR_4011", + "message": "Patient not found in PDS", + } + MockError = { "message": "Client error", "err_code": "AB_XXXX", diff --git a/lambdas/enums/logging_app_interaction.py b/lambdas/enums/logging_app_interaction.py index cbba2f313b..f23387eb97 100644 --- a/lambdas/enums/logging_app_interaction.py +++ b/lambdas/enums/logging_app_interaction.py @@ -35,3 +35,4 @@ class LoggingAppInteraction(Enum): MANIFEST_JOB = "Manifest job" RESTRICTION_SOFT_DELETE = "User restrictions - soft-delete" SEARCH_HISTORY = "Search document reference history" + USER_RESTRICTION = "User restriction" diff --git a/lambdas/handlers/user_restrictions/create_user_restriction_handler.py b/lambdas/handlers/user_restrictions/create_user_restriction_handler.py new file mode 100644 index 0000000000..844df18ae5 --- /dev/null +++ b/lambdas/handlers/user_restrictions/create_user_restriction_handler.py @@ -0,0 +1,125 @@ +import json + +from enums.feature_flags import FeatureFlags +from enums.lambda_error import LambdaError +from enums.logging_app_interaction import LoggingAppInteraction +from services.feature_flags_service import FeatureFlagService +from services.user_restrictions.create_user_restriction_service import ( + CreateUserRestrictionService, +) +from utils.audit_logging_setup import LoggingService +from utils.decorators.ensure_env_var import ensure_environment_variables +from utils.decorators.handle_lambda_exceptions import handle_lambda_exceptions +from utils.decorators.override_error_check import override_error_check +from utils.decorators.set_audit_arg import set_request_context_for_logging +from utils.decorators.validate_patient_id import ( + extract_nhs_number_from_event, + validate_patient_id, +) +from utils.exceptions import ( + HealthcareWorkerAPIException, + HealthcareWorkerPractitionerModelException, + OdsErrorException, + UserRestrictionAlreadyExistsException, +) +from utils.lambda_exceptions import LambdaException +from utils.lambda_response import ApiGatewayResponse +from utils.ods_utils import extract_creator_and_ods_code_from_request_context +from utils.request_context import request_context + +logger = LoggingService(__name__) + + +def parse_body(body: str | None) -> tuple[str, str]: + if not body: + logger.error("Missing request body") + raise LambdaException( + 400, + LambdaError.CreateRestrictionMissingBody, + ) + + payload = json.loads(body) + + restricted_smartcard_id = payload.get("smartcardId") + nhs_number = payload.get("nhsNumber") + if not restricted_smartcard_id or not nhs_number: + logger.error("Missing required fields") + raise LambdaException( + 400, + LambdaError.CreateRestrictionMissingFields, + ) + + return restricted_smartcard_id, nhs_number + + +@set_request_context_for_logging +@override_error_check +@ensure_environment_variables( + names=[ + "RESTRICTIONS_TABLE_NAME", + "HEALTHCARE_WORKER_API_URL", + ], +) +@handle_lambda_exceptions +@validate_patient_id +def lambda_handler(event, context): + request_context.app_interaction = LoggingAppInteraction.USER_RESTRICTION.value + + feature_flag_service = FeatureFlagService() + feature_flag_service.validate_feature_flag( + FeatureFlags.USER_RESTRICTION_ENABLED, + ) + logger.info("Starting create user restriction process") + + restricted_smartcard_id, nhs_number = parse_body(event.get("body")) + request_context.patient_nhs_no = nhs_number + + patient_id = extract_nhs_number_from_event(event) + if patient_id != nhs_number: + logger.error("patientId query param does not match nhs_number in request body") + raise LambdaException( + 400, + LambdaError.PatientIdMismatch, + ) + + try: + creator, ods_code = extract_creator_and_ods_code_from_request_context() + except OdsErrorException: + logger.error("Missing user context") + raise LambdaException( + 400, + LambdaError.CreateRestrictionMissingContext, + ) + + service = CreateUserRestrictionService() + try: + restriction_id = service.create_restriction( + restricted_smartcard_id=restricted_smartcard_id, + nhs_number=nhs_number, + custodian=ods_code, + creator=creator, + ) + except UserRestrictionAlreadyExistsException as exc: + logger.error(exc) + raise LambdaException( + 409, + LambdaError.CreateRestrictionAlreadyExists, + ) + except HealthcareWorkerAPIException as exc: + logger.error(exc) + raise LambdaException( + 400, + LambdaError.CreateRestrictionInvalidWorker, + ) + except HealthcareWorkerPractitionerModelException as exc: + logger.error(exc) + raise LambdaException( + 400, + LambdaError.CreateRestrictionPractitionerModelError, + ) + + return ApiGatewayResponse( + 201, + json.dumps({"id": restriction_id}), + "POST", + ).create_api_gateway_response() diff --git a/lambdas/models/user_restrictions/__init__.py b/lambdas/models/user_restrictions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lambdas/services/authoriser_service.py b/lambdas/services/authoriser_service.py index 4e2c731325..e21e36dfa9 100644 --- a/lambdas/services/authoriser_service.py +++ b/lambdas/services/authoriser_service.py @@ -148,6 +148,12 @@ def deny_access_policy(self, path, http_verb, user_role, nhs_number: str = None) case "/DocumentReview": deny_resource = False + case "/UserRestriction": + if http_verb == HttpVerb.POST: + deny_resource = not patient_access_is_allowed + else: + deny_resource = False + case "/UploadState": deny_resource = ( not patient_access_is_allowed or is_user_gp_clinical or is_user_pcse diff --git a/lambdas/services/user_restrictions/create_user_restriction_service.py b/lambdas/services/user_restrictions/create_user_restriction_service.py new file mode 100644 index 0000000000..e1f75b8f32 --- /dev/null +++ b/lambdas/services/user_restrictions/create_user_restriction_service.py @@ -0,0 +1,74 @@ +from enums.lambda_error import LambdaError +from models.user_restrictions.user_restrictions import UserRestriction +from services.user_restrictions.user_restriction_dynamo_service import ( + UserRestrictionDynamoService, +) +from services.user_restrictions.utilities import get_healthcare_worker_api_service +from utils.audit_logging_setup import LoggingService +from utils.exceptions import ( + UserRestrictionAlreadyExistsException, +) +from utils.lambda_exceptions import LambdaException +from utils.utilities import get_pds_service + +logger = LoggingService(__name__) + + +class CreateUserRestrictionService: + def __init__(self): + self.dynamo_service = UserRestrictionDynamoService() + self.healthcare_service = get_healthcare_worker_api_service() + self.pds_service = get_pds_service() + + def create_restriction( + self, + restricted_smartcard_id: str, + nhs_number: str, + custodian: str, + creator: str, + ) -> str: + if restricted_smartcard_id == creator: + logger.error("You cannot create a restriction for yourself") + raise LambdaException( + 400, + LambdaError.CreateRestrictionSelfRestriction, + ) + + patient = self.pds_service.fetch_patient_details(nhs_number) + if not patient: + logger.error("Patient not found in PDS") + raise LambdaException( + 404, + LambdaError.SearchPatientNoPDS, + ) + if patient.general_practice_ods != custodian: + logger.error( + "Patient's general practice ODS does not match request context ODS", + ) + raise LambdaException( + 403, + LambdaError.SearchPatientNoAuth, + ) + + existing = self.dynamo_service.get_active_restriction( + nhs_number=nhs_number, + restricted_user=restricted_smartcard_id, + ) + if existing: + raise UserRestrictionAlreadyExistsException( + "A restriction already exists for this user and patient", + ) + + self.healthcare_service.get_practitioner(restricted_smartcard_id) + + restriction = UserRestriction( + restricted_user=restricted_smartcard_id, + nhs_number=nhs_number, + custodian=custodian, + creator=creator, + ) + + self.dynamo_service.create_restriction_item(restriction) + + logger.info("Created user restriction") + return restriction.id diff --git a/lambdas/services/user_restrictions/search_user_restriction_service.py b/lambdas/services/user_restrictions/search_user_restriction_service.py index ba53f7dd2c..a9f718171d 100644 --- a/lambdas/services/user_restrictions/search_user_restriction_service.py +++ b/lambdas/services/user_restrictions/search_user_restriction_service.py @@ -43,7 +43,7 @@ def process_request( logger.info(f"Querying user restrictions for ODS code {ods_code}") restrictions, next_token = self.dynamo_service.query_restrictions( ods_code=ods_code, - smart_card_id=smartcard_id, + smartcard_id=smartcard_id, nhs_number=nhs_number, limit=limit, start_key=next_page_token, diff --git a/lambdas/services/user_restrictions/user_restriction_dynamo_service.py b/lambdas/services/user_restrictions/user_restriction_dynamo_service.py index 47ebf7a8ba..ed49971dc0 100644 --- a/lambdas/services/user_restrictions/user_restriction_dynamo_service.py +++ b/lambdas/services/user_restrictions/user_restriction_dynamo_service.py @@ -2,6 +2,7 @@ from datetime import datetime, timezone from enum import StrEnum +from boto3.dynamodb.conditions import Attr from botocore.exceptions import ClientError from pydantic import ValidationError @@ -36,10 +37,35 @@ def __init__(self): self.dynamo_service = DynamoDBService() self.table_name = os.environ["RESTRICTIONS_TABLE_NAME"] + def create_restriction_item(self, restriction: UserRestriction) -> None: + self.dynamo_service.create_item( + table_name=self.table_name, + item=restriction.model_dump(by_alias=True, exclude_none=True), + key_name=UserRestrictionsFields.ID.value, + ) + + def get_active_restriction( + self, + nhs_number: str, + restricted_user: str, + ) -> dict | None: + query_filter = Attr(UserRestrictionsFields.RESTRICTED_USER).eq( + restricted_user, + ) & Attr(UserRestrictionsFields.IS_ACTIVE).eq(True) + + results = self.dynamo_service.query_table( + table_name=self.table_name, + index_name=UserRestrictionIndexes.NHS_NUMBER_INDEX, + search_key=UserRestrictionsFields.NHS_NUMBER, + search_condition=nhs_number, + query_filter=query_filter, + ) + return results[0] if results else None + def query_restrictions( self, ods_code: str, - smart_card_id: str | None = None, + smartcard_id: str | None = None, nhs_number: str | None = None, limit: int = DEFAULT_LIMIT, start_key: str | None = None, @@ -48,22 +74,28 @@ def query_restrictions( filter_expression, expression_attribute_names, expression_attribute_values = ( self._build_query_filter( - smart_card_id=smart_card_id, + smartcard_id=smartcard_id, nhs_number=nhs_number, ) ) - response = self.dynamo_service.query_table_with_paginator( - table_name=self.table_name, - index_name=UserRestrictionIndexes.CUSTODIAN_INDEX, - key=UserRestrictionsFields.CUSTODIAN, - condition=ods_code, - filter_expression=filter_expression, - expression_attribute_names=expression_attribute_names, - expression_attribute_values=expression_attribute_values, - limit=limit, - start_key=start_key, - ) + try: + response = self.dynamo_service.query_table_with_paginator( + table_name=self.table_name, + index_name=UserRestrictionIndexes.CUSTODIAN_INDEX, + key=UserRestrictionsFields.CUSTODIAN, + condition=ods_code, + filter_expression=filter_expression, + expression_attribute_names=expression_attribute_names, + expression_attribute_values=expression_attribute_values, + limit=limit, + start_key=start_key, + ) + except ClientError as e: + logger.error(f"DynamoDB ClientError when querying restrictions: {e}") + raise UserRestrictionValidationException( + f"Failed to query user restrictions from DynamoDB: {e}", + ) from e items = response.get("Items", []) restrictions = self._validate_restrictions(items) @@ -115,7 +147,7 @@ def update_restriction_inactive( @staticmethod def _build_query_filter( - smart_card_id: str | None, + smartcard_id: str | None, nhs_number: str | None, ) -> tuple[str, dict, dict]: conditions = [ @@ -125,12 +157,12 @@ def _build_query_filter( "value": True, }, ] - if smart_card_id: + if smartcard_id: conditions.append( { "field": UserRestrictionsFields.RESTRICTED_USER, "operator": ConditionOperator.EQUAL.value, - "value": smart_card_id, + "value": smartcard_id, }, ) if nhs_number: diff --git a/lambdas/tests/unit/conftest.py b/lambdas/tests/unit/conftest.py index eff76cdd91..83b3922e68 100644 --- a/lambdas/tests/unit/conftest.py +++ b/lambdas/tests/unit/conftest.py @@ -84,15 +84,17 @@ MOCK_STATISTICS_REPORT_BUCKET_NAME = "test_statistics_report_bucket" REVIEW_SQS_QUEUE_URL = "test_review_queue" TEST_NHS_NUMBER = "9000000009" -TEST_UUID = "550e8400-e29b-41d4-a716-446655440000" -TEST_SMART_CARD_ID = "123456789120" -TEST_NEXT_PAGE_TOKEN = "some-next-token" +TEST_UUID = "12345678-1234-5678-1234-567812345678" +TEST_SMART_CARD_ID = "test-smartcard-id-9012" +TEST_NEXT_PAGE_TOKEN = "test-next-page-token" TEST_FILE_KEY = "test_file_key" TEST_FILE_NAME = "test.pdf" TEST_FILE_SIZE = 24000 TEST_VIRUS_SCANNER_RESULT = "not_scanned" TEST_DOCUMENT_LOCATION = f"s3://{MOCK_BUCKET}/{TEST_FILE_KEY}" TEST_CURRENT_GP_ODS = "Y12345" +MOCK_SMART_CARD_ID = "smartcard-uuid-1234" +MOCK_CREATOR_ID = "creator-uuid-5678" AUTH_STATE_TABLE_NAME = "test_state_table" AUTH_SESSION_TABLE_NAME = "test_session_table" diff --git a/lambdas/tests/unit/handlers/conftest.py b/lambdas/tests/unit/handlers/conftest.py index 196584c56e..97eea9e107 100755 --- a/lambdas/tests/unit/handlers/conftest.py +++ b/lambdas/tests/unit/handlers/conftest.py @@ -1,8 +1,18 @@ +import json + import pytest + from enums.feature_flags import FeatureFlags from enums.report_distribution_action import ReportDistributionAction +from models.pds_models import PatientDetails from repositories.reporting.reporting_dynamo_repository import ReportingDynamoRepository from services.feature_flags_service import FeatureFlagService +from tests.unit.conftest import ( + MOCK_CREATOR_ID, + MOCK_SMART_CARD_ID, + TEST_CURRENT_GP_ODS, + TEST_NHS_NUMBER, +) @pytest.fixture @@ -321,3 +331,68 @@ def mock_report_orchestration_wiring(mocker): "mock_window": mock_window, "mock_report_date": mock_report_date, } + + +@pytest.fixture +def mock_user_restriction_enabled(mocker): + mock_function = mocker.patch.object(FeatureFlagService, "get_feature_flags_by_flag") + mock_feature_flag = mock_function.return_value = { + FeatureFlags.USER_RESTRICTION_ENABLED: True, + FeatureFlags.USE_SMARTCARD_AUTH: False, + } + yield mock_feature_flag + + +@pytest.fixture +def mock_user_restriction_disabled(mocker): + mock_function = mocker.patch.object(FeatureFlagService, "get_feature_flags_by_flag") + mock_feature_flag = mock_function.return_value = { + FeatureFlags.USER_RESTRICTION_ENABLED: False, + } + yield mock_feature_flag + + +@pytest.fixture +def valid_create_restriction_event(): + yield { + "httpMethod": "POST", + "headers": {"Authorization": "test_token"}, + "queryStringParameters": {"patientId": TEST_NHS_NUMBER}, + "body": json.dumps( + {"smartcardId": MOCK_SMART_CARD_ID, "nhsNumber": TEST_NHS_NUMBER}, + ), + } + + +@pytest.fixture +def mock_request_context(mocker): + mock_context = mocker.patch("utils.ods_utils.request_context") + mock_context.authorization = { + "nhs_user_id": MOCK_CREATOR_ID, + "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, + } + yield mock_context + + +@pytest.fixture +def mock_pds_service_with_matching_ods(mocker): + """ + Mock PDS service to return patient with ODS code matching TEST_CURRENT_GP_ODS. + Use this fixture in tests that need ODS validation to pass. + """ + mock_patient_details = PatientDetails( + nhsNumber=TEST_NHS_NUMBER, + givenName=["Jane"], + familyName="Smith", + birthDate="2010-10-22", + postalCode="LS1 6AE", + superseded=False, + restricted=False, + generalPracticeOds=TEST_CURRENT_GP_ODS, # Y12345 - matches request context + active=True, + ) + mock_service = mocker.patch( + "services.user_restrictions.create_user_restriction_service.get_pds_service", + ) + mock_service.return_value.fetch_patient_details.return_value = mock_patient_details + yield mock_service diff --git a/lambdas/tests/unit/handlers/test_create_user_restriction_handler.py b/lambdas/tests/unit/handlers/test_create_user_restriction_handler.py new file mode 100644 index 0000000000..c31f8b57ed --- /dev/null +++ b/lambdas/tests/unit/handlers/test_create_user_restriction_handler.py @@ -0,0 +1,495 @@ +import json + +import pytest + +from enums.lambda_error import LambdaError +from lambdas.handlers.user_restrictions.create_user_restriction_handler import ( + lambda_handler, + parse_body, +) +from tests.unit.conftest import ( + MOCK_CREATOR_ID, + MOCK_INTERACTION_ID, + MOCK_SMART_CARD_ID, + TEST_CURRENT_GP_ODS, + TEST_NHS_NUMBER, + TEST_UUID, +) +from utils.exceptions import ( + HealthcareWorkerAPIException, + HealthcareWorkerPractitionerModelException, + UserRestrictionAlreadyExistsException, +) +from utils.lambda_exceptions import LambdaException +from utils.lambda_response import ApiGatewayResponse + + +@pytest.fixture +def mock_service(set_env, mocker): + mock = mocker.patch( + "lambdas.handlers.user_restrictions.create_user_restriction_handler.CreateUserRestrictionService", + ) + yield mock.return_value + + +def test_lambda_handler_returns_201_on_success( + valid_create_restriction_event, + context, + mock_service, + mock_request_context, + mock_pds_service_with_matching_ods, + mock_user_restriction_enabled, +): + mock_service.create_restriction.return_value = TEST_UUID + + expected = ApiGatewayResponse( + 201, + json.dumps({"id": TEST_UUID}), + "POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_calls_service_with_correct_args( + valid_create_restriction_event, + context, + mock_service, + mock_request_context, + mock_pds_service_with_matching_ods, + mock_user_restriction_enabled, +): + mock_service.create_restriction.return_value = TEST_UUID + + lambda_handler(valid_create_restriction_event, context) + + mock_service.create_restriction.assert_called_once_with( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + +def test_lambda_handler_returns_400_when_body_missing( + context, + set_env, + mock_request_context, + mock_user_restriction_enabled, +): + event = { + "httpMethod": "POST", + "headers": {}, + "queryStringParameters": {"patientId": TEST_NHS_NUMBER}, + } + + body = { + "message": LambdaError.CreateRestrictionMissingBody.value["message"], + "err_code": LambdaError.CreateRestrictionMissingBody.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_when_smart_card_id_missing( + context, + set_env, + mock_request_context, + mock_user_restriction_enabled, +): + event = { + "httpMethod": "POST", + "headers": {}, + "queryStringParameters": {"patientId": TEST_NHS_NUMBER}, + "body": json.dumps({"nhsNumber": TEST_NHS_NUMBER}), + } + + body = { + "message": LambdaError.CreateRestrictionMissingFields.value["message"], + "err_code": LambdaError.CreateRestrictionMissingFields.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_when_nhs_number_missing( + context, + set_env, + mock_request_context, + mock_user_restriction_enabled, +): + event = { + "httpMethod": "POST", + "headers": {}, + "queryStringParameters": {"patientId": TEST_NHS_NUMBER}, + "body": json.dumps({"smartcardId": MOCK_SMART_CARD_ID}), + } + + body = { + "message": LambdaError.CreateRestrictionMissingFields.value["message"], + "err_code": LambdaError.CreateRestrictionMissingFields.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_when_creator_missing( + valid_create_restriction_event, + context, + set_env, + mocker, + mock_user_restriction_enabled, +): + mock_ctx = mocker.patch("utils.ods_utils.request_context") + mock_ctx.authorization = { + "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, + } + + body = { + "message": LambdaError.CreateRestrictionMissingContext.value["message"], + "err_code": LambdaError.CreateRestrictionMissingContext.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_when_ods_code_missing( + valid_create_restriction_event, + context, + set_env, + mocker, + mock_user_restriction_enabled, +): + mock_ctx = mocker.patch("utils.ods_utils.request_context") + mock_ctx.authorization = {"nhs_user_id": MOCK_CREATOR_ID} + + body = { + "message": LambdaError.CreateRestrictionMissingContext.value["message"], + "err_code": LambdaError.CreateRestrictionMissingContext.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_409_when_restriction_already_exists( + valid_create_restriction_event, + context, + mock_service, + mock_request_context, + mock_pds_service_with_matching_ods, + mock_user_restriction_enabled, +): + mock_service.create_restriction.side_effect = UserRestrictionAlreadyExistsException( + "A restriction already exists for this user and patient", + ) + + body = { + "message": LambdaError.CreateRestrictionAlreadyExists.value["message"], + "err_code": LambdaError.CreateRestrictionAlreadyExists.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=409, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_on_value_error( + valid_create_restriction_event, + context, + set_env, + mocker, + mock_user_restriction_enabled, +): + # Mock request context where creator's ID equals the restricted smartcard ID + mock_ctx = mocker.patch("utils.ods_utils.request_context") + mock_ctx.authorization = { + "nhs_user_id": MOCK_SMART_CARD_ID, # Same as restricted_smartcard_id + "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, + } + + body = { + "message": LambdaError.CreateRestrictionSelfRestriction.value["message"], + "err_code": LambdaError.CreateRestrictionSelfRestriction.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_correct_status_on_healthcare_worker_api_exception( + valid_create_restriction_event, + context, + mock_service, + mock_request_context, + mock_pds_service_with_matching_ods, + mock_user_restriction_enabled, +): + mock_service.create_restriction.side_effect = HealthcareWorkerAPIException( + status_code=404, + ) + + body = { + "message": LambdaError.CreateRestrictionInvalidWorker.value["message"], + "err_code": LambdaError.CreateRestrictionInvalidWorker.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_when_patient_id_does_not_match_nhs_number( + context, + set_env, + mock_request_context, + mock_user_restriction_enabled, +): + event = { + "httpMethod": "POST", + "headers": {"Authorization": "test_token"}, + "queryStringParameters": {"patientId": "9000000017"}, + "body": json.dumps( + {"smartcardId": MOCK_SMART_CARD_ID, "nhsNumber": TEST_NHS_NUMBER}, + ), + } + + body = { + "message": LambdaError.PatientIdMismatch.value["message"], + "err_code": LambdaError.PatientIdMismatch.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(event, context) + + assert actual == expected + + +def test_lambda_handler_returns_400_on_practitioner_model_exception( + valid_create_restriction_event, + context, + mock_service, + mock_request_context, + mock_pds_service_with_matching_ods, + mock_user_restriction_enabled, +): + mock_service.create_restriction.side_effect = ( + HealthcareWorkerPractitionerModelException() + ) + + body = { + "message": LambdaError.CreateRestrictionPractitionerModelError.value["message"], + "err_code": LambdaError.CreateRestrictionPractitionerModelError.value[ + "err_code" + ], + "interaction_id": MOCK_INTERACTION_ID, + } + expected = ApiGatewayResponse( + status_code=400, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + + assert actual == expected + + +def test_lambda_handler_returns_404_feature_flag_disabled( + valid_create_restriction_event, + context, + mock_user_restriction_disabled, + set_env, +): + body = { + "message": LambdaError.FeatureFlagDisabled.value["message"], + "err_code": LambdaError.FeatureFlagDisabled.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + + expected = ApiGatewayResponse( + status_code=404, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + assert actual == expected + + +def test_lambda_handler_returns_404_when_patient_not_found_in_pds( + valid_create_restriction_event, + context, + mock_request_context, + mock_user_restriction_enabled, + mocker, + set_env, +): + mock_pds = mocker.patch( + "services.user_restrictions.create_user_restriction_service.get_pds_service", + ) + mock_pds.return_value.fetch_patient_details.return_value = None + + body = { + "message": LambdaError.SearchPatientNoPDS.value["message"], + "err_code": LambdaError.SearchPatientNoPDS.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + + expected = ApiGatewayResponse( + status_code=404, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + assert actual == expected + + +def test_lambda_handler_returns_403_when_patient_ods_does_not_match_requester_ods( + valid_create_restriction_event, + context, + mock_request_context, + mock_user_restriction_enabled, + mocker, + set_env, +): + from models.pds_models import PatientDetails + + # Mock patient with different ODS code + mismatched_patient = PatientDetails( + nhsNumber=TEST_NHS_NUMBER, + givenName=["Jane"], + familyName="Smith", + birthDate="2010-10-22", + postalCode="LS1 6AE", + superseded=False, + restricted=False, + generalPracticeOds="X9999", # Different ODS than TEST_CURRENT_GP_ODS (Y12345) + active=True, + ) + + mock_pds = mocker.patch( + "services.user_restrictions.create_user_restriction_service.get_pds_service", + ) + mock_pds.return_value.fetch_patient_details.return_value = mismatched_patient + + body = { + "message": LambdaError.SearchPatientNoAuth.value["message"], + "err_code": LambdaError.SearchPatientNoAuth.value["err_code"], + "interaction_id": MOCK_INTERACTION_ID, + } + + expected = ApiGatewayResponse( + status_code=403, + body=json.dumps(body), + methods="POST", + ).create_api_gateway_response() + + actual = lambda_handler(valid_create_restriction_event, context) + assert actual == expected + + +# --- parse_body unit tests --- + + +def test_parse_body_returns_fields_on_valid_input(): + body = json.dumps( + {"smartcardId": MOCK_SMART_CARD_ID, "nhsNumber": TEST_NHS_NUMBER}, + ) + + result = parse_body(body) + + assert result == (MOCK_SMART_CARD_ID, TEST_NHS_NUMBER) + + +def test_parse_body_raises_when_body_is_none(): + with pytest.raises(LambdaException) as exc_info: + parse_body(None) + assert ( + exc_info.value.err_code + == LambdaError.CreateRestrictionMissingBody.value["err_code"] + ) + + +def test_parse_body_raises_when_smart_card_id_missing(): + with pytest.raises(LambdaException) as exc_info: + parse_body(json.dumps({"nhsNumber": TEST_NHS_NUMBER})) + assert ( + exc_info.value.err_code + == LambdaError.CreateRestrictionMissingFields.value["err_code"] + ) + + +def test_parse_body_raises_when_nhs_number_missing(): + with pytest.raises(LambdaException) as exc_info: + parse_body(json.dumps({"smartcardId": MOCK_SMART_CARD_ID})) + assert ( + exc_info.value.err_code + == LambdaError.CreateRestrictionMissingFields.value["err_code"] + ) diff --git a/lambdas/tests/unit/models/user_restrictions/test_user_restrictions.py b/lambdas/tests/unit/models/user_restrictions/test_user_restrictions.py index e9da414d84..291c4fed60 100644 --- a/lambdas/tests/unit/models/user_restrictions/test_user_restrictions.py +++ b/lambdas/tests/unit/models/user_restrictions/test_user_restrictions.py @@ -1,34 +1,34 @@ -from datetime import datetime, timezone - -import freezegun - -from models.user_restrictions.user_restrictions import UserRestriction -from tests.e2e.mns.mns_helper import TEST_NHS_NUMBER -from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_UUID - - -@freezegun.freeze_time("2024-01-01T12:00:00Z") -def test_model_dump_camel_case(mock_uuid): - restriction = UserRestriction( - restricted_user="123456789012", - nhs_number=TEST_NHS_NUMBER, - custodian=TEST_CURRENT_GP_ODS, - creator="223456789022", - ) - - created_timestamp = int(datetime.now(timezone.utc).timestamp()) - - expected = { - "id": TEST_UUID, - "nhsNumber": TEST_NHS_NUMBER, - "custodian": TEST_CURRENT_GP_ODS, - "creator": "223456789022", - "restrictedUser": "123456789012", - "created": created_timestamp, - "isActive": True, - "lastUpdated": created_timestamp, - "removedBy": None, - } - - actual = restriction.model_dump_camel_case() - assert actual == expected +from datetime import datetime, timezone + +import freezegun + +from models.user_restrictions.user_restrictions import UserRestriction +from tests.e2e.mns.mns_helper import TEST_NHS_NUMBER +from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_UUID + + +@freezegun.freeze_time("2024-01-01T12:00:00Z") +def test_model_dump_camel_case(mock_uuid): + restriction = UserRestriction( + restricted_user="123456789012", + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator="223456789022", + ) + + created_timestamp = int(datetime.now(timezone.utc).timestamp()) + + expected = { + "id": TEST_UUID, + "nhsNumber": TEST_NHS_NUMBER, + "custodian": TEST_CURRENT_GP_ODS, + "creator": "223456789022", + "restrictedUser": "123456789012", + "created": created_timestamp, + "isActive": True, + "lastUpdated": created_timestamp, + "removedBy": None, + } + + actual = restriction.model_dump_camel_case() + assert actual == expected diff --git a/lambdas/tests/unit/services/test_create_user_restriction_service.py b/lambdas/tests/unit/services/test_create_user_restriction_service.py new file mode 100644 index 0000000000..9ff32f5620 --- /dev/null +++ b/lambdas/tests/unit/services/test_create_user_restriction_service.py @@ -0,0 +1,161 @@ +from unittest.mock import MagicMock + +import pytest +from freezegun import freeze_time + +from models.pds_models import PatientDetails +from models.user_restrictions.practitioner import Practitioner +from services.user_restrictions.create_user_restriction_service import ( + CreateUserRestrictionService, +) +from tests.unit.conftest import ( + MOCK_CREATOR_ID, + MOCK_SMART_CARD_ID, + TEST_CURRENT_GP_ODS, + TEST_NHS_NUMBER, +) +from utils.exceptions import ( + HealthcareWorkerAPIException, + HealthcareWorkerPractitionerModelException, + UserRestrictionAlreadyExistsException, +) + +MOCK_PRACTITIONER = Practitioner( + smartcard_id=MOCK_SMART_CARD_ID, + first_name="Jane", + last_name="Doe", +) + +MOCK_PATIENT = PatientDetails( + nhsNumber=TEST_NHS_NUMBER, + givenName=["John"], + familyName="Doe", + birthDate="1990-01-01", + postalCode="LS1 6AE", + superseded=False, + restricted=False, + generalPracticeOds=TEST_CURRENT_GP_ODS, + active=True, +) + + +@pytest.fixture +def mock_service(set_env, mocker): + mocker.patch( + "services.user_restrictions.create_user_restriction_service.UserRestrictionDynamoService", + ) + mocker.patch( + "services.user_restrictions.create_user_restriction_service.get_healthcare_worker_api_service", + ) + mocker.patch( + "services.user_restrictions.create_user_restriction_service.get_pds_service", + ) + + service = CreateUserRestrictionService() + service.dynamo_service = MagicMock() + service.dynamo_service.get_active_restriction.return_value = None + service.healthcare_service = MagicMock() + service.pds_service = MagicMock() + service.pds_service.fetch_patient_details.return_value = MOCK_PATIENT + + yield service + + +@freeze_time("2024-01-01 12:00:00") +def test_create_restriction_happy_path(mock_service): + mock_service.healthcare_service.get_practitioner.return_value = MOCK_PRACTITIONER + + result = mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + assert isinstance(result, str) + assert len(result) > 0 + + +@freeze_time("2024-01-01 12:00:00") +def test_create_restriction_calls_get_practitioner(mock_service): + mock_service.healthcare_service.get_practitioner.return_value = MOCK_PRACTITIONER + + mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + mock_service.healthcare_service.get_practitioner.assert_called_once_with( + MOCK_SMART_CARD_ID, + ) + + +@freeze_time("2024-01-01 12:00:00") +def test_create_restriction_writes_to_dynamo(mock_service): + mock_service.healthcare_service.get_practitioner.return_value = MOCK_PRACTITIONER + + mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + mock_service.dynamo_service.create_restriction_item.assert_called_once() + restriction = mock_service.dynamo_service.create_restriction_item.call_args.args[0] + assert restriction.restricted_user == MOCK_SMART_CARD_ID + assert restriction.nhs_number == TEST_NHS_NUMBER + + +def test_create_restriction_raises_when_restriction_already_exists(mock_service): + mock_service.dynamo_service.get_active_restriction.return_value = { + "ID": "existing-id", + } + + with pytest.raises( + UserRestrictionAlreadyExistsException, + match="A restriction already exists for this user and patient", + ): + mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + mock_service.healthcare_service.get_practitioner.assert_not_called() + mock_service.dynamo_service.create_restriction_item.assert_not_called() + + +def test_create_restriction_propagates_healthcare_worker_api_exception(mock_service): + mock_service.healthcare_service.get_practitioner.side_effect = ( + HealthcareWorkerAPIException(status_code=404) + ) + + with pytest.raises(HealthcareWorkerAPIException): + mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + mock_service.dynamo_service.create_restriction_item.assert_not_called() + + +def test_create_restriction_propagates_practitioner_model_exception(mock_service): + mock_service.healthcare_service.get_practitioner.side_effect = ( + HealthcareWorkerPractitionerModelException() + ) + + with pytest.raises(HealthcareWorkerPractitionerModelException): + mock_service.create_restriction( + restricted_smartcard_id=MOCK_SMART_CARD_ID, + nhs_number=TEST_NHS_NUMBER, + custodian=TEST_CURRENT_GP_ODS, + creator=MOCK_CREATOR_ID, + ) + + mock_service.dynamo_service.create_restriction_item.assert_not_called() diff --git a/lambdas/tests/unit/services/user_restriction/test_search_user_restriction_service.py b/lambdas/tests/unit/services/user_restriction/test_search_user_restriction_service.py index 8d94e35eda..8a72909258 100644 --- a/lambdas/tests/unit/services/user_restriction/test_search_user_restriction_service.py +++ b/lambdas/tests/unit/services/user_restriction/test_search_user_restriction_service.py @@ -73,7 +73,7 @@ def test_process_request_calls_query_restrictions_and_enriches(mock_service, moc mock_query.assert_called_once_with( ods_code=TEST_CURRENT_GP_ODS, - smart_card_id=None, + smartcard_id=None, nhs_number=None, limit=DEFAULT_LIMIT, start_key=None, @@ -97,7 +97,7 @@ def test_process_request_passes_next_page_token_as_start_key(mock_service, mocke mock_query.assert_called_once_with( ods_code=TEST_CURRENT_GP_ODS, - smart_card_id=None, + smartcard_id=None, nhs_number=None, limit=DEFAULT_LIMIT, start_key=TEST_NEXT_PAGE_TOKEN, diff --git a/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py b/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py index 8fb8499887..4d69a1a867 100644 --- a/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py +++ b/lambdas/tests/unit/services/user_restriction/test_user_restriction_dynamo_service.py @@ -11,17 +11,20 @@ from services.user_restrictions.user_restriction_dynamo_service import ( UserRestrictionDynamoService, ) -from tests.unit.conftest import TEST_CURRENT_GP_ODS, TEST_NHS_NUMBER, TEST_UUID +from tests.unit.conftest import ( + TEST_CURRENT_GP_ODS, + TEST_NEXT_PAGE_TOKEN, + TEST_NHS_NUMBER, + TEST_SMART_CARD_ID, + TEST_UUID, +) from tests.unit.services.user_restriction.conftest import MOCK_IDENTIFIER from utils.exceptions import ( UserRestrictionConditionCheckFailedException, UserRestrictionValidationException, ) -TEST_ODS_CODE = "Y12345" -TEST_SMART_CARD_ID = "SC001" MOCK_USER_RESTRICTION_TABLE = "test_user_restriction_table" -TEST_NEXT_TOKEN = "some-opaque-next-token" MOCK_TIME_STAMP = 1704110400 @@ -36,6 +39,18 @@ "LastUpdated": 1704067200, } +MOCK_RESTRICTION = { + "ID": TEST_UUID, + "RestrictedSmartcard": TEST_SMART_CARD_ID, + "NhsNumber": TEST_NHS_NUMBER, + "Custodian": TEST_CURRENT_GP_ODS, + "Created": 1700000000, + "CreatorSmartcard": "SC002", + "RemoverSmartCard": None, + "IsActive": True, + "LastUpdated": 1700000001, +} + MOCK_DYNAMO_RESPONSE_WITH_ITEM = {"Items": [MOCK_RESTRICTION_ITEM]} MOCK_DYNAMO_RESPONSE_EMPTY = {"Items": []} @@ -57,89 +72,96 @@ def mock_dynamo_service(mock_service): def test_query_restrictions_calls_paginator_with_correct_key_and_index(mock_service): - mock_service.query_restrictions(ods_code=TEST_ODS_CODE) + mock_service.query_restrictions(ods_code=TEST_CURRENT_GP_ODS) call_kwargs = ( mock_service.dynamo_service.query_table_with_paginator.call_args.kwargs ) assert call_kwargs["key"] == UserRestrictionsFields.CUSTODIAN - assert call_kwargs["condition"] == TEST_ODS_CODE + assert call_kwargs["condition"] == TEST_CURRENT_GP_ODS assert call_kwargs["index_name"] == UserRestrictionIndexes.CUSTODIAN_INDEX def test_query_restrictions_by_ods_code_uses_active_filter(mock_service): - mock_service.query_restrictions(ods_code=TEST_ODS_CODE) + mock_service.query_restrictions(ods_code=TEST_CURRENT_GP_ODS) call_kwargs = ( mock_service.dynamo_service.query_table_with_paginator.call_args.kwargs ) - assert "IsActive" in call_kwargs["filter_expression"] + assert UserRestrictionsFields.IS_ACTIVE in call_kwargs["filter_expression"] assert ( UserRestrictionsFields.RESTRICTED_USER not in call_kwargs["filter_expression"] ) assert UserRestrictionsFields.NHS_NUMBER not in call_kwargs["filter_expression"] -def test_query_restrictions_by_smart_card_id_applies_smartcard_filter(mock_service): +def test_query_restrictions_by_smartcard_id_applies_smartcard_filter(mock_service): mock_service.query_restrictions( - ods_code=TEST_ODS_CODE, - smart_card_id=TEST_SMART_CARD_ID, + ods_code=TEST_CURRENT_GP_ODS, + smartcard_id=TEST_SMART_CARD_ID, ) call_kwargs = ( mock_service.dynamo_service.query_table_with_paginator.call_args.kwargs ) - assert "IsActive" in call_kwargs["filter_expression"] + assert UserRestrictionsFields.IS_ACTIVE in call_kwargs["filter_expression"] assert UserRestrictionsFields.RESTRICTED_USER in call_kwargs["filter_expression"] assert ( - call_kwargs["expression_attribute_values"][":RestrictedSmartcard_condition_val"] + call_kwargs["expression_attribute_values"][ + f":{UserRestrictionsFields.RESTRICTED_USER}_condition_val" + ] == TEST_SMART_CARD_ID ) def test_query_restrictions_by_nhs_number_applies_nhs_number_filter(mock_service): - mock_service.query_restrictions(ods_code=TEST_ODS_CODE, nhs_number=TEST_NHS_NUMBER) + mock_service.query_restrictions( + ods_code=TEST_CURRENT_GP_ODS, + nhs_number=TEST_NHS_NUMBER, + ) call_kwargs = ( mock_service.dynamo_service.query_table_with_paginator.call_args.kwargs ) - assert "IsActive" in call_kwargs["filter_expression"] + assert UserRestrictionsFields.IS_ACTIVE in call_kwargs["filter_expression"] assert UserRestrictionsFields.NHS_NUMBER in call_kwargs["filter_expression"] assert ( - call_kwargs["expression_attribute_values"][":NhsNumber_condition_val"] + call_kwargs["expression_attribute_values"][ + f":{UserRestrictionsFields.NHS_NUMBER}_condition_val" + ] == TEST_NHS_NUMBER ) def test_query_restrictions_passes_limit_and_start_key(mock_service): mock_service.query_restrictions( - ods_code=TEST_ODS_CODE, + ods_code=TEST_CURRENT_GP_ODS, limit=5, - start_key=TEST_NEXT_TOKEN, + start_key=TEST_NEXT_PAGE_TOKEN, ) call_kwargs = ( mock_service.dynamo_service.query_table_with_paginator.call_args.kwargs ) assert call_kwargs["limit"] == 5 - assert call_kwargs["start_key"] == TEST_NEXT_TOKEN + assert call_kwargs["start_key"] == TEST_NEXT_PAGE_TOKEN def test_query_restrictions_returns_next_token(mock_service): mock_service.dynamo_service.query_table_with_paginator.return_value = { - "Items": [MOCK_RESTRICTION_ITEM], - "NextToken": TEST_NEXT_TOKEN, + "Items": [MOCK_RESTRICTION], + "NextToken": TEST_NEXT_PAGE_TOKEN, } - _, next_token = mock_service.query_restrictions(ods_code=TEST_ODS_CODE) + _, next_token = mock_service.query_restrictions(ods_code=TEST_CURRENT_GP_ODS) - assert next_token == TEST_NEXT_TOKEN + assert next_token == TEST_NEXT_PAGE_TOKEN def test_query_restrictions_returns_empty_list_when_no_items(mock_service): mock_service.dynamo_service.query_table_with_paginator.return_value = {"Items": []} - results, next_token = mock_service.query_restrictions(ods_code=TEST_ODS_CODE) + results, next_token = mock_service.query_restrictions(ods_code=TEST_CURRENT_GP_ODS) assert results == [] assert next_token is None @@ -150,6 +172,16 @@ def test_validate_restrictions_raises_for_invalid_items(): UserRestrictionDynamoService._validate_restrictions([{"invalid": "data"}]) +def test_query_restrictions_raises_validation_exception_on_client_error(mock_service): + mock_service.dynamo_service.query_table_with_paginator.side_effect = ClientError( + {"Error": {"Code": "500", "Message": "DynamoDB error"}}, + "query", + ) + + with pytest.raises(UserRestrictionValidationException): + mock_service.query_restrictions(ods_code=TEST_CURRENT_GP_ODS) + + @freeze_time("2024-01-01 12:00:00") def test_soft_delete_user_restriction(mock_service): mock_service.update_restriction_inactive( diff --git a/lambdas/tests/unit/utils/test_ods_utils.py b/lambdas/tests/unit/utils/test_ods_utils.py index cc55e2b1db..2469aae769 100644 --- a/lambdas/tests/unit/utils/test_ods_utils.py +++ b/lambdas/tests/unit/utils/test_ods_utils.py @@ -1,9 +1,10 @@ import pytest from enums.patient_ods_inactive_status import PatientOdsInactiveStatus -from tests.unit.conftest import TEST_CURRENT_GP_ODS +from tests.unit.conftest import MOCK_CREATOR_ID, TEST_CURRENT_GP_ODS from utils.exceptions import OdsErrorException from utils.ods_utils import ( + extract_creator_and_ods_code_from_request_context, extract_ods_code_from_request_context, extract_ods_role_code_with_r_prefix_from_role_codes_string, is_ods_code_active, @@ -20,6 +21,16 @@ def mocked_request_context_with_ods(mocker): yield mocker.patch("utils.ods_utils.request_context", mocked_context) +@pytest.fixture() +def mocked_request_context_with_creator_and_ods(mocker): + mocked_context = mocker.MagicMock() + mocked_context.authorization = { + "nhs_user_id": MOCK_CREATOR_ID, + "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, + } + yield mocker.patch("utils.ods_utils.request_context", mocked_context) + + @pytest.mark.parametrize( "ods_code,expected", [ @@ -81,3 +92,39 @@ def test_is_valid_ods_code(value, expected): actual = is_valid_ods_code(value) assert actual == expected + + +def test_extract_creator_and_ods_code_returns_both( + mocked_request_context_with_creator_and_ods, +): + creator, ods_code = extract_creator_and_ods_code_from_request_context() + + assert creator == MOCK_CREATOR_ID + assert ods_code == TEST_CURRENT_GP_ODS + + +def test_extract_creator_and_ods_code_raises_when_creator_missing(mocker): + mocked_context = mocker.MagicMock() + mocked_context.authorization = { + "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, + } + mocker.patch("utils.ods_utils.request_context", mocked_context) + + with pytest.raises(OdsErrorException): + extract_creator_and_ods_code_from_request_context() + + +def test_extract_creator_and_ods_code_raises_when_ods_code_missing(mocker): + mocked_context = mocker.MagicMock() + mocked_context.authorization = {"nhs_user_id": MOCK_CREATOR_ID} + mocker.patch("utils.ods_utils.request_context", mocked_context) + + with pytest.raises(OdsErrorException): + extract_creator_and_ods_code_from_request_context() + + +def test_extract_creator_and_ods_code_raises_when_no_auth(mocker): + mocker.patch("utils.ods_utils.request_context", {}) + + with pytest.raises(OdsErrorException): + extract_creator_and_ods_code_from_request_context() diff --git a/lambdas/utils/exceptions.py b/lambdas/utils/exceptions.py index 160228ef8f..5792c46bfd 100644 --- a/lambdas/utils/exceptions.py +++ b/lambdas/utils/exceptions.py @@ -58,10 +58,6 @@ class DocumentServiceException(Exception): pass -class UserRestrictionValidationException(Exception): - pass - - class DocumentReviewException(Exception): pass @@ -231,6 +227,14 @@ class HealthcareWorkerPractitionerModelException(Exception): pass +class UserRestrictionAlreadyExistsException(Exception): + pass + + +class UserRestrictionValidationException(Exception): + pass + + class MigrationUnrecoverableException(Exception): def __init__(self, message: str, item_id: str): super().__init__(message) diff --git a/lambdas/utils/lambda_exceptions.py b/lambdas/utils/lambda_exceptions.py index 23fd8b6eee..25908468da 100644 --- a/lambdas/utils/lambda_exceptions.py +++ b/lambdas/utils/lambda_exceptions.py @@ -144,3 +144,7 @@ class UserRestrictionsException(LambdaException): class SearchUserRestrictionException(LambdaException): pass + + +class CreateUserRestrictionLambdaException(LambdaException): + pass