Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/eligibility_signposting_api/config/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@
CONSUMER_ID = "NHSE-Product-ID"
ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"]
CONSUMER_MAPPING_FILE_NAME = "consumer_mapping_config.json"
RESERVED_TEST_CONSUMER_IDS = {"test-consumer-1", "test-consumer-2", "test-consumer-3"}

TTL = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could just be env vars

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g.

import os
from typing import Literal

URL_PREFIX = "patient-check"
RULE_STOP_DEFAULT = False
NHS_NUMBER_HEADER = "nhs-login-nhs-number"
CONSUMER_ID = "NHSE-Product-ID"
ALLOWED_CONDITIONS = Literal["COVID", "FLU", "MMR", "RSV"]
CONSUMER_MAPPING_FILE_NAME = "consumer_mapping_config.json"
RESERVED_TEST_CONSUMER_IDS = {"test-consumer-1", "test-consumer-2", "test-consumer-3"}

CACHE_TTL_SECONDS = int(os.getenv("CONFIG_CACHE_TTL_SECONDS", "1800"))

"test": 300,
"dev": 300,
"preprod": 300,
"prod": 300,
}
55 changes: 51 additions & 4 deletions src/eligibility_signposting_api/repos/campaign_repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using TTLCache would simplify this a bit e.g.

import json
import logging
from collections.abc import Generator
from typing import Annotated, NewType

from aws_xray_sdk.core import xray_recorder
from botocore.client import BaseClient
from cachetools import TTLCache
from wireup import Inject, service

from eligibility_signposting_api.config.constants import CACHE_TTL_SECONDS, RESERVED_TEST_CONSUMER_IDS
from eligibility_signposting_api.model.campaign_config import CampaignConfig, Rules

BucketName = NewType("BucketName", str)

# Fix: the suggestion read `logging.getLogger(name)` (a NameError) — the
# markdown rendering stripped the dunder underscores from __name__.
logger = logging.getLogger(__name__)

# Module-level cache holding a single entry (the full campaign-config list),
# evicted automatically once CACHE_TTL_SECONDS has elapsed.
campaign_config_cache: TTLCache[str, list[CampaignConfig]] = TTLCache(maxsize=1, ttl=CACHE_TTL_SECONDS)


@service
class CampaignRepo:
    """Repository class for Campaign Rules, which we can use to calculate a person's eligibility for vaccination.

    These rules are stored as JSON files in AWS S3."""

    def __init__(
        self,
        s3_client: Annotated[BaseClient, Inject(qualifier="s3")],
        bucket_name: Annotated[BucketName, Inject(param="rules_bucket_name")],
    ) -> None:
        super().__init__()
        self.s3_client = s3_client
        self.bucket_name = bucket_name

    def get_campaign_configs(self, consumer_id: str) -> Generator[CampaignConfig, None, None]:
        """Yield all campaign configs, serving from the TTL cache when possible.

        Reserved test consumer IDs always bypass the cache (and do not
        populate it) so integration tests observe fresh S3 state.
        """
        bypass = consumer_id in RESERVED_TEST_CONSUMER_IDS
        cache_key = "all_campaigns"
        cached = None if bypass else campaign_config_cache.get(cache_key)

        with xray_recorder.in_subsegment("CampaignRepo.get_campaign_configs"):
            if cached is not None:
                logger.info("Using cached campaign configs")
                yield from cached
                return

            logger.info(
                "Refreshing campaign configs from S3 (consumer_id=%s, ttl_seconds=%s)",
                consumer_id,
                CACHE_TTL_SECONDS,
            )
            configs = self._load_campaign_configs_from_s3()

            if not bypass:
                campaign_config_cache[cache_key] = configs

            yield from configs

    def _load_campaign_configs_from_s3(self) -> list[CampaignConfig]:
        """Read every object in the rules bucket and parse each into a CampaignConfig."""
        campaign_configs: list[CampaignConfig] = []

        with xray_recorder.in_subsegment("CampaignRepo.load_campaign_configs_from_s3"):
            with xray_recorder.in_subsegment("list_objects"):
                campaign_objects = self.s3_client.list_objects(Bucket=self.bucket_name)

            with xray_recorder.in_subsegment("get_objects"):
                # .get("Contents", []) tolerates an empty bucket, where the
                # "Contents" key is absent from the list_objects response.
                for campaign_object in campaign_objects.get("Contents", []):
                    response = self.s3_client.get_object(
                        Bucket=self.bucket_name,
                        Key=f"{campaign_object['Key']}",
                    )
                    body = response["Body"].read()
                    campaign_configs.append(
                        Rules.model_validate(json.loads(body)).campaign_config
                    )

        return campaign_configs

import logging
import os
import time
from collections.abc import Generator
from typing import Annotated, NewType

Expand All @@ -7,9 +10,11 @@
from wireup import Inject, service

from eligibility_signposting_api.model.campaign_config import CampaignConfig, Rules
from eligibility_signposting_api.config.constants import TTL, RESERVED_TEST_CONSUMER_IDS

BucketName = NewType("BucketName", str)

logger = logging.getLogger(__name__)

@service
class CampaignRepo:
Expand All @@ -25,13 +30,55 @@ def __init__(
super().__init__()
self.s3_client = s3_client
self.bucket_name = bucket_name
self._campaign_configs_cache: list[CampaignConfig] | None = None
self._cache_expiry_epoch: float = 0.0
self._cache_ttl_seconds: int = int(TTL.get(os.getenv("ENVIRONMENT"), 0))

def get_campaign_configs(self, consumer_id: str) -> Generator[CampaignConfig, None, None]:
now = time.time()
cache_enabled = self._cache_ttl_seconds > 0
cache_valid = (
cache_enabled
and consumer_id not in RESERVED_TEST_CONSUMER_IDS
and self._campaign_configs_cache is not None
and now < self._cache_expiry_epoch
)

def get_campaign_configs(self) -> Generator[CampaignConfig]:
with xray_recorder.in_subsegment("CampaignRepo.get_campaign_configs"):
if cache_valid:
logger.info("Using cached campaign configs")
yield from self._campaign_configs_cache
return

logger.info(
"Refreshing campaign configs from S3 (consumer_id=%s, ttl_seconds=%s)",
consumer_id,
self._cache_ttl_seconds,
)
campaign_configs = self._load_campaign_configs_from_s3()

if cache_enabled and consumer_id not in RESERVED_TEST_CONSUMER_IDS:
self._campaign_configs_cache = campaign_configs
self._cache_expiry_epoch = now + self._cache_ttl_seconds

yield from campaign_configs

def _load_campaign_configs_from_s3(self) -> list[CampaignConfig]:
campaign_configs: list[CampaignConfig] = []

with xray_recorder.in_subsegment("CampaignRepo.load_campaign_configs_from_s3"):
with xray_recorder.in_subsegment("list_objects"):
campaign_objects = self.s3_client.list_objects(Bucket=self.bucket_name)

with xray_recorder.in_subsegment("get_objects"):
for campaign_object in campaign_objects["Contents"]:
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{campaign_object['Key']}")
for campaign_object in campaign_objects.get("Contents", []):
response = self.s3_client.get_object(
Bucket=self.bucket_name,
Key=f"{campaign_object['Key']}",
)
body = response["Body"].read()
yield Rules.model_validate(json.loads(body)).campaign_config
campaign_configs.append(
Rules.model_validate(json.loads(body)).campaign_config
)

return campaign_configs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def get_eligibility_status(
except NotFoundError as e:
raise UnknownPersonError from e
else:
campaign_configs: list[CampaignConfig] = list(self.campaign_repo.get_campaign_configs())
campaign_configs: list[CampaignConfig] = list(
self.campaign_repo.get_campaign_configs(consumer_id)
)
permitted_campaign_configs = self.__collect_permitted_campaign_configs(
campaign_configs, ConsumerId(consumer_id)
)
Expand Down
1 change: 0 additions & 1 deletion src/eligibility_signposting_api/views/eligibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
Status.not_actionable: eligibility_response.Status.not_actionable,
Status.not_eligible: eligibility_response.Status.not_eligible,
}

logger = logging.getLogger(__name__)

eligibility_blueprint = Blueprint("eligibility", __name__)
Expand Down
1 change: 1 addition & 0 deletions tests/docker-compose.mock_aws.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ services:
- SECRET_MANAGER_ENDPOINT=http://moto-server:5000
- FIREHOSE_ENDPOINT=http://moto-server:5000
- LOG_LEVEL=INFO
#- ENVIRONMENT=dev
entrypoint: /bin/sh
command:
- "-c"
Expand Down
69 changes: 69 additions & 0 deletions tests/integration/lambda/test_app_running_as_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,72 @@ def test_status_end_point(invoke_with_mock_apigw_request):
)
),
)

# Runnable only if ENVIRONMENT=dev is uncommented in tests/docker-compose.mock_aws.yml;
# demonstrates that in a non-local environment a reserved test consumer can bypass the cache.

# def test_cache_bypass( # noqa: PLR0913
# lambda_client: BaseClient, # noqa:ARG001
# persisted_person: NHSNumber,
# rsv_campaign_config: CampaignConfig,
# consumer_to_active_rsv_campaign_mapping: ConsumerMapping, # noqa: ARG001
# consumer_id: ConsumerId,
# s3_client: BaseClient,
# audit_bucket: BucketName,
# invoke_with_mock_apigw_request,
# lambda_logs: Callable[[], list[str]],
# secretsmanager_client: BaseClient, # noqa:ARG001
# ):
# # Given
# invoke_path = f"/patient-check/{persisted_person}"
# headers = {
# "nhs-login-nhs-number": str(persisted_person),
# "x_request_id": "x_request_id",
# "x_correlation_id": "x_correlation_id",
# "nhsd_end_user_organisation_ods": "nhsd_end_user_organisation_ods",
# "nhsd-application-id": "nhsd-application-id",
# "NHSE-Product-ID": consumer_id,
# }
# params = {"includeActions": "Y"}
#
# objects = s3_client.list_objects_v2(Bucket="test-rules-bucket").get("Contents", [])
# assert_that(objects, is_not(equal_to([])))
# config_key = objects[0]["Key"]
# original = s3_client.get_object(Bucket="test-rules-bucket", Key=config_key)
# original_payload = json.loads(original["Body"].read())
# print(original_payload)
#
# # When
# response = invoke_with_mock_apigw_request(path=invoke_path, headers=headers, params=params)
#
# # Then
# assert_that(
# response,
# is_response().with_status_code(HTTPStatus.OK).and_body(is_json_that(has_key("processedSuggestions"))),
# )
#
# original_payload["CampaignConfig"]["Target"] = "RSV_CHANGED_FOR_BYPASS_TEST"
# s3_client.put_object(
# Bucket="test-rules-bucket",
# Key=config_key,
# Body=json.dumps(original_payload),
# ContentType="application/json",
# )
#
# # Second request without bypass header should still use cached config
# response_2 = invoke_with_mock_apigw_request(invoke_path, headers)
# assert_that(
# response_2,
# is_response().with_status_code(HTTPStatus.OK).and_body(is_json_that(has_key("processedSuggestions"))),
# )
#
# # Third request with bypass header should re-read S3 and reflect the change
# bypass_headers = {
# **headers,
# "X-Bypass-Campaign-Config-Cache": "true",
# }
# response_3 = invoke_with_mock_apigw_request(invoke_path, bypass_headers)
# assert_that(
# response_3,
# is_response().with_status_code(500),
# )
2 changes: 1 addition & 1 deletion tests/integration/repo/test_campaign_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_get_campaign_config(s3_client: BaseClient, rules_bucket: BucketName, ca
repo = CampaignRepo(s3_client, rules_bucket)

# When
actual = list(repo.get_campaign_configs())
actual = list(repo.get_campaign_configs("consumer_id"))

# Then
assert_that(
Expand Down
110 changes: 110 additions & 0 deletions tests/unit/repos/test_campaign_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import io
import json
from unittest.mock import MagicMock

import pytest

from eligibility_signposting_api.repos.campaign_repo import CampaignRepo, BucketName
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'd then import campaign_config_cache here

from tests.fixtures.builders.model.rule import CampaignConfigFactory


def make_s3_body(payload: dict):
    """Build a fake S3 GetObject response: *payload* JSON-encoded under "Body"."""
    encoded = json.dumps(payload).encode("utf-8")
    return {"Body": io.BytesIO(encoded)}


class TestCampaignRepo:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and can have a fixture to clear the cache e.g.

@pytest.fixture(autouse=True)
def clear_cache(self):
    campaign_config_cache.clear()

@pytest.fixture
def mock_s3_client(self):
    """Stand-in for the boto3 S3 client; per-test code sets return values and asserts calls."""
    # Indentation restored: the scraped page flattened the method body to column 0.
    return MagicMock()

@pytest.fixture
def repo(self, mock_s3_client):
    """CampaignRepo wired to the mocked S3 client and a fixed test bucket name."""
    # Indentation restored: the scraped page flattened the method body to column 0.
    return CampaignRepo(
        s3_client=mock_s3_client,
        bucket_name=BucketName("test-bucket"),
    )

@pytest.fixture
def rules_payload(self):
    """JSON-serializable rules document wrapping one factory-built campaign config."""
    # Indentation restored: the scraped page flattened the method body to column 0.
    campaign_config = CampaignConfigFactory.build()
    return {
        "campaign_config": campaign_config.model_dump(mode="json")
    }

def test_get_campaign_configs_loads_from_s3(self, repo, mock_s3_client, rules_payload):
    """First call with an empty cache must list and fetch from S3 exactly once."""
    # Indentation restored: the scraped page flattened the method body to column 0.
    mock_s3_client.list_objects.return_value = {
        "Contents": [{"Key": "rsv.json"}]
    }
    mock_s3_client.get_object.return_value = make_s3_body(rules_payload)

    result = list(repo.get_campaign_configs("consumer_id"))

    assert len(result) == 1
    assert result[0].id == rules_payload["campaign_config"]["id"]

    mock_s3_client.list_objects.assert_called_once_with(Bucket="test-bucket")
    mock_s3_client.get_object.assert_called_once_with(
        Bucket="test-bucket",
        Key="rsv.json",
    )

def test_get_campaign_configs_uses_cache_within_ttl(
    self,
    repo,
    mock_s3_client,
    monkeypatch,
):
    """A second call inside the TTL window is served from cache — no extra S3 calls.

    Reconstructed: review-comment chrome from the scraped page had been
    embedded in the middle of this function's parameter list.
    """
    # Force-enable caching regardless of the ENVIRONMENT-derived default.
    repo._cache_ttl_seconds = 60

    first_config = CampaignConfigFactory.build(version=1)

    mock_s3_client.list_objects.return_value = {
        "Contents": [{"Key": "rsv.json"}]
    }
    mock_s3_client.get_object.return_value = make_s3_body(
        {"campaign_config": first_config.model_dump(mode="json")}
    )

    # Freeze time so both calls land inside the 60-second TTL window.
    monkeypatch.setattr("time.time", lambda: 1000.0)

    first = list(repo.get_campaign_configs("consumer_id"))
    second = list(repo.get_campaign_configs("consumer_id"))

    assert first[0].version == 1
    assert second[0].version == 1
    # S3 was hit exactly once; the second read came from the cache.
    assert mock_s3_client.list_objects.call_count == 1
    assert mock_s3_client.get_object.call_count == 1

def test_get_campaign_configs_refreshes_after_ttl_expiry(
    self,
    repo,
    mock_s3_client,
    monkeypatch,
):
    """After the TTL lapses, configs are re-read from S3 and reflect new content.

    The third call also uses a reserved test consumer ID ("test-consumer-1"),
    which bypasses the cache — and, per the repo logic, does not repopulate it.
    """
    # Indentation restored: the scraped page flattened the method body to column 0.
    repo._cache_ttl_seconds = 60

    first_config = CampaignConfigFactory.build(version=1)
    second_config = CampaignConfigFactory.build(version=2)

    mock_s3_client.list_objects.return_value = {
        "Contents": [{"Key": "rsv.json"}]
    }
    # One payload per expected S3 round-trip (calls 1 and 3; call 2 is cached).
    mock_s3_client.get_object.side_effect = [
        make_s3_body({"campaign_config": first_config.model_dump(mode="json")}),
        make_s3_body({"campaign_config": second_config.model_dump(mode="json")}),
    ]

    # Controllable clock: mutate current_time["value"] to advance time.
    current_time = {"value": 1000.0}
    monkeypatch.setattr("time.time", lambda: current_time["value"])

    first = list(repo.get_campaign_configs("consumer_id"))
    current_time["value"] = 1030.0  # still within the 60s TTL
    second = list(repo.get_campaign_configs("consumer_id"))
    current_time["value"] = 1061.0  # past expiry
    third = list(repo.get_campaign_configs("test-consumer-1"))

    assert first[0].version == 1
    assert second[0].version == 1
    assert third[0].version == 2
    assert mock_s3_client.list_objects.call_count == 2
    assert mock_s3_client.get_object.call_count == 2
1 change: 1 addition & 0 deletions tests/unit/views/test_eligibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def get_eligibility_status(
_include_actions: str,
_conditions: list[str],
_category: str,
_consumer_id: str,
) -> EligibilityStatus:
raise ValueError

Expand Down
Loading