diff --git a/src/eligibility_signposting_api/config/constants.py b/src/eligibility_signposting_api/config/constants.py index e9a1f93c..784031ab 100644 --- a/src/eligibility_signposting_api/config/constants.py +++ b/src/eligibility_signposting_api/config/constants.py @@ -6,3 +6,11 @@ CONSUMER_ID = "NHSE-Product-ID" ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"] CONSUMER_MAPPING_FILE_NAME = "consumer_mapping_config.json" +RESERVED_TEST_CONSUMER_IDS = {"test-consumer-1", "test-consumer-2", "test-consumer-3"} + +TTL = { + "test": 300, + "dev": 300, + "preprod": 300, + "prod": 300, +} diff --git a/src/eligibility_signposting_api/repos/campaign_repo.py b/src/eligibility_signposting_api/repos/campaign_repo.py index 42fc2efd..dae56afa 100644 --- a/src/eligibility_signposting_api/repos/campaign_repo.py +++ b/src/eligibility_signposting_api/repos/campaign_repo.py @@ -1,4 +1,7 @@ import json +import logging +import os +import time from collections.abc import Generator from typing import Annotated, NewType @@ -7,9 +10,11 @@ from wireup import Inject, service from eligibility_signposting_api.model.campaign_config import CampaignConfig, Rules +from eligibility_signposting_api.config.constants import TTL, RESERVED_TEST_CONSUMER_IDS BucketName = NewType("BucketName", str) +logger = logging.getLogger(__name__) @service class CampaignRepo: @@ -25,13 +30,55 @@ def __init__( super().__init__() self.s3_client = s3_client self.bucket_name = bucket_name + self._campaign_configs_cache: list[CampaignConfig] | None = None + self._cache_expiry_epoch: float = 0.0 + self._cache_ttl_seconds: int = int(TTL.get(os.getenv("ENVIRONMENT"), 0)) + + def get_campaign_configs(self, consumer_id: str) -> Generator[CampaignConfig, None, None]: + now = time.time() + cache_enabled = self._cache_ttl_seconds > 0 + cache_valid = ( + cache_enabled + and consumer_id not in RESERVED_TEST_CONSUMER_IDS + and self._campaign_configs_cache is not None + and now < self._cache_expiry_epoch + ) - def get_campaign_configs(self) -> Generator[CampaignConfig]: with xray_recorder.in_subsegment("CampaignRepo.get_campaign_configs"): + if cache_valid: + logger.info("Using cached campaign configs") + yield from self._campaign_configs_cache + return + + logger.info( + "Refreshing campaign configs from S3 (consumer_id=%s, ttl_seconds=%s)", + consumer_id, + self._cache_ttl_seconds, + ) + campaign_configs = self._load_campaign_configs_from_s3() + + if cache_enabled and consumer_id not in RESERVED_TEST_CONSUMER_IDS: + self._campaign_configs_cache = campaign_configs + self._cache_expiry_epoch = now + self._cache_ttl_seconds + + yield from campaign_configs + + def _load_campaign_configs_from_s3(self) -> list[CampaignConfig]: + campaign_configs: list[CampaignConfig] = [] + + with xray_recorder.in_subsegment("CampaignRepo.load_campaign_configs_from_s3"): with xray_recorder.in_subsegment("list_objects"): campaign_objects = self.s3_client.list_objects(Bucket=self.bucket_name) + with xray_recorder.in_subsegment("get_objects"): - for campaign_object in campaign_objects["Contents"]: - response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{campaign_object['Key']}") + for campaign_object in campaign_objects.get("Contents", []): + response = self.s3_client.get_object( + Bucket=self.bucket_name, + Key=f"{campaign_object['Key']}", + ) body = response["Body"].read() - yield Rules.model_validate(json.loads(body)).campaign_config + campaign_configs.append( + Rules.model_validate(json.loads(body)).campaign_config + ) + + return campaign_configs diff --git a/src/eligibility_signposting_api/services/eligibility_services.py b/src/eligibility_signposting_api/services/eligibility_services.py index 13b701d6..4642e74b 100644 --- a/src/eligibility_signposting_api/services/eligibility_services.py +++ b/src/eligibility_signposting_api/services/eligibility_services.py @@ -50,7 +50,9 @@ def get_eligibility_status( except NotFoundError as e: raise UnknownPersonError from e else: - campaign_configs: list[CampaignConfig] = list(self.campaign_repo.get_campaign_configs()) + campaign_configs: list[CampaignConfig] = list( + self.campaign_repo.get_campaign_configs(consumer_id) + ) permitted_campaign_configs = self.__collect_permitted_campaign_configs( campaign_configs, ConsumerId(consumer_id) ) diff --git a/src/eligibility_signposting_api/views/eligibility.py b/src/eligibility_signposting_api/views/eligibility.py index b935678f..6a49d508 100644 --- a/src/eligibility_signposting_api/views/eligibility.py +++ b/src/eligibility_signposting_api/views/eligibility.py @@ -27,7 +27,6 @@ Status.not_actionable: eligibility_response.Status.not_actionable, Status.not_eligible: eligibility_response.Status.not_eligible, } - logger = logging.getLogger(__name__) eligibility_blueprint = Blueprint("eligibility", __name__) diff --git a/tests/docker-compose.mock_aws.yml b/tests/docker-compose.mock_aws.yml index a4388136..8734c753 100644 --- a/tests/docker-compose.mock_aws.yml +++ b/tests/docker-compose.mock_aws.yml @@ -32,6 +32,7 @@ services: - SECRET_MANAGER_ENDPOINT=http://moto-server:5000 - FIREHOSE_ENDPOINT=http://moto-server:5000 - LOG_LEVEL=INFO + #- ENVIRONMENT=dev entrypoint: /bin/sh command: - "-c" diff --git a/tests/integration/lambda/test_app_running_as_lambda.py b/tests/integration/lambda/test_app_running_as_lambda.py index e823aed4..247bf473 100644 --- a/tests/integration/lambda/test_app_running_as_lambda.py +++ b/tests/integration/lambda/test_app_running_as_lambda.py @@ -726,3 +726,72 @@ def test_status_end_point(invoke_with_mock_apigw_request): ) ), ) + +# Runnable if we go to tests/docker-compose.mock_aws.yml and add ENVIRONMENT=dev +# shows that in a non-local environment we can bypass cache + +# def test_cache_bypass( # noqa: PLR0913 +# lambda_client: BaseClient, # noqa:ARG001 +# persisted_person: NHSNumber, +# rsv_campaign_config: CampaignConfig, +# consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001 +# consumer_id: ConsumerId, +# s3_client: BaseClient, +# audit_bucket: BucketName, +# invoke_with_mock_apigw_request, +# lambda_logs: Callable[[], list[str]], +# secretsmanager_client: BaseClient, # noqa:ARG001 +# ): +# # Given +# invoke_path = f"/patient-check/{persisted_person}" +# headers = { +# "nhs-login-nhs-number": str(persisted_person), +# "x_request_id": "x_request_id", +# "x_correlation_id": "x_correlation_id", +# "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods", +# "nhsd-application-id": "nhsd-application-id", +# "NHSE-Product-ID": consumer_id, +# } +# params = {"includeActions": "Y"} +# +# objects = s3_client.list_objects_v2(Bucket="test-rules-bucket").get("Contents", []) +# assert_that(objects, is_not(equal_to([]))) +# config_key = objects[0]["Key"] +# original = s3_client.get_object(Bucket="test-rules-bucket", Key=config_key) +# original_payload = json.loads(original["Body"].read()) +# print(original_payload) +# +# # When +# response = invoke_with_mock_apigw_request(path=invoke_path, headers=headers, params=params) +# +# # Then +# assert_that( +# response, +# is_response().with_status_code(HTTPStatus.OK).and_body(is_json_that(has_key("processedSuggestions"))), +# ) +# +# original_payload["CampaignConfig"]["Target"] = "RSV_CHANGED_FOR_BYPASS_TEST" +# s3_client.put_object( +# Bucket="test-rules-bucket", +# Key=config_key, +# Body=json.dumps(original_payload), +# ContentType="application/json", +# ) +# +# # Second request without bypass header should still use cached config +# response_2 = invoke_with_mock_apigw_request(invoke_path, headers) +# assert_that( +# response_2, +# is_response().with_status_code(HTTPStatus.OK).and_body(is_json_that(has_key("processedSuggestions"))), +# ) +# +# # Third request with bypass header should re-read S3 and reflect the change +# bypass_headers = { +# **headers, +# "X-Bypass-Campaign-Config-Cache": "true", +# } +# response_3 = invoke_with_mock_apigw_request(invoke_path, bypass_headers) +# assert_that( +# response_3, +# is_response().with_status_code(500), +# ) diff --git a/tests/integration/repo/test_campaign_repo.py b/tests/integration/repo/test_campaign_repo.py index 96742d38..72d061b3 100644 --- a/tests/integration/repo/test_campaign_repo.py +++ b/tests/integration/repo/test_campaign_repo.py @@ -27,7 +27,7 @@ def test_get_campaign_config(s3_client: BaseClient, rules_bucket: BucketName, ca repo = CampaignRepo(s3_client, rules_bucket) # When - actual = list(repo.get_campaign_configs()) + actual = list(repo.get_campaign_configs("consumer_id")) # Then assert_that( diff --git a/tests/unit/repos/test_campaign_repo.py b/tests/unit/repos/test_campaign_repo.py new file mode 100644 index 00000000..53b97000 --- /dev/null +++ b/tests/unit/repos/test_campaign_repo.py @@ -0,0 +1,110 @@ +import io +import json +from unittest.mock import MagicMock + +import pytest + +from eligibility_signposting_api.repos.campaign_repo import CampaignRepo, BucketName +from tests.fixtures.builders.model.rule import CampaignConfigFactory + + +def make_s3_body(payload: dict): + return {"Body": io.BytesIO(json.dumps(payload).encode("utf-8"))} + + +class TestCampaignRepo: + @pytest.fixture + def mock_s3_client(self): + return MagicMock() + + @pytest.fixture + def repo(self, mock_s3_client): + return CampaignRepo( + s3_client=mock_s3_client, + bucket_name=BucketName("test-bucket"), + ) + + @pytest.fixture + def rules_payload(self): + campaign_config = CampaignConfigFactory.build() + return { + "campaign_config": campaign_config.model_dump(mode="json") + } + + def test_get_campaign_configs_loads_from_s3(self, repo, mock_s3_client, rules_payload): + mock_s3_client.list_objects.return_value = { + "Contents": [{"Key": "rsv.json"}] + } + mock_s3_client.get_object.return_value = make_s3_body(rules_payload) + + result = list(repo.get_campaign_configs("consumer_id")) + + assert len(result) == 1 + assert result[0].id == rules_payload["campaign_config"]["id"] + + mock_s3_client.list_objects.assert_called_once_with(Bucket="test-bucket") + mock_s3_client.get_object.assert_called_once_with( + Bucket="test-bucket", + Key="rsv.json", + ) + + def test_get_campaign_configs_uses_cache_within_ttl( + self, + repo, + mock_s3_client, + monkeypatch, + ): + repo._cache_ttl_seconds = 60 + + first_config = CampaignConfigFactory.build(version=1) + + mock_s3_client.list_objects.return_value = { + "Contents": [{"Key": "rsv.json"}] + } + mock_s3_client.get_object.return_value = make_s3_body( + {"campaign_config": first_config.model_dump(mode="json")} + ) + + monkeypatch.setattr("time.time", lambda: 1000.0) + + first = list(repo.get_campaign_configs("consumer_id")) + second = list(repo.get_campaign_configs("consumer_id")) + + assert first[0].version == 1 + assert second[0].version == 1 + assert mock_s3_client.list_objects.call_count == 1 + assert mock_s3_client.get_object.call_count == 1 + + def test_get_campaign_configs_refreshes_after_ttl_expiry( + self, + repo, + mock_s3_client, + monkeypatch, + ): + repo._cache_ttl_seconds = 60 + + first_config = CampaignConfigFactory.build(version=1) + second_config = CampaignConfigFactory.build(version=2) + + mock_s3_client.list_objects.return_value = { + "Contents": [{"Key": "rsv.json"}] + } + mock_s3_client.get_object.side_effect = [ + make_s3_body({"campaign_config": first_config.model_dump(mode="json")}), + make_s3_body({"campaign_config": second_config.model_dump(mode="json")}), + ] + + current_time = {"value": 1000.0} + monkeypatch.setattr("time.time", lambda: current_time["value"]) + + first = list(repo.get_campaign_configs("consumer_id")) + current_time["value"] = 1030.0 + second = list(repo.get_campaign_configs("consumer_id")) + current_time["value"] = 1061.0 + third = list(repo.get_campaign_configs("test-consumer-1")) + + assert first[0].version == 1 + assert second[0].version == 1 + assert third[0].version == 2 + assert mock_s3_client.list_objects.call_count == 2 + assert mock_s3_client.get_object.call_count == 2 diff --git a/tests/unit/views/test_eligibility.py b/tests/unit/views/test_eligibility.py index 3a0c2304..2b6b5c5f 100644 --- a/tests/unit/views/test_eligibility.py +++ b/tests/unit/views/test_eligibility.py @@ -91,6 +91,7 @@ def get_eligibility_status( _include_actions: str, _conditions: list[str], _category: str, + _consumer_id: str, ) -> EligibilityStatus: raise ValueError