diff --git a/metrics/api/settings/__init__.py b/metrics/api/settings/__init__.py index bfb41ba1f..12f4f7d5f 100644 --- a/metrics/api/settings/__init__.py +++ b/metrics/api/settings/__init__.py @@ -3,6 +3,8 @@ match config.APIENV: case "LOCAL": from .local import * + case "TEST": + from .test import * case "STANDALONE": from .standalone import * diff --git a/metrics/api/settings/test.py b/metrics/api/settings/test.py new file mode 100644 index 000000000..38f1455d7 --- /dev/null +++ b/metrics/api/settings/test.py @@ -0,0 +1,22 @@ +import os + +from metrics.api.settings import ROOT_LEVEL_BASE_DIR + +DATA_UPLOAD_MAX_NUMBER_FIELDS = None + +DEBUG = True + +DATABASES = { + "default": { + "TIME_ZONE": "Europe/London", + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(ROOT_LEVEL_BASE_DIR, "test.sqlite3"), + }, + "test": { + "TIME_ZONE": "Europe/London", + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(ROOT_LEVEL_BASE_DIR, "test.sqlite3"), + }, +} + +INTERNAL_IPS = ["127.0.0.1"] diff --git a/metrics/api/urls_construction.py b/metrics/api/urls_construction.py index 57ac14819..6dd3cd6a9 100644 --- a/metrics/api/urls_construction.py +++ b/metrics/api/urls_construction.py @@ -98,7 +98,7 @@ def construct_cms_admin_urlpatterns( ] -DEFAULT_PUBLIC_API_PREFIX = "api/public/timeseries/" +DEFAULT_PUBLIC_API_PREFIX = "api/public/" def construct_public_api_urlpatterns( diff --git a/public_api/urls.py b/public_api/urls.py index ca1b3870f..a6eba4212 100644 --- a/public_api/urls.py +++ b/public_api/urls.py @@ -9,6 +9,7 @@ GeographyTypeListViewV2, MetricListViewV2, PublicAPIRootViewV2, + SearchView, SubThemeDetailViewV2, SubThemeListViewV2, ThemeDetailViewV2, @@ -33,6 +34,9 @@ ) from public_api.views.timeseries_viewset import APITimeSeriesViewSet +TIMESERIES_PREFIX = "timeseries/" +SEARCH_PREFIX = "search/" + def construct_url_patterns_for_public_api( *, @@ -48,8 +52,12 @@ def construct_url_patterns_for_public_api( set of versioned URLS. """ urls = [] - urls.extend(_construct_version_one_urls(prefix=prefix)) - urls.extend(_construct_version_two_urls(prefix=prefix)) + # Timeseries API + urls.extend(_construct_version_one_urls(prefix=prefix + TIMESERIES_PREFIX)) + urls.extend(_construct_version_two_urls(prefix=prefix + TIMESERIES_PREFIX)) + + # Search API + urls.extend(_construct_search_urls(prefix=prefix + SEARCH_PREFIX)) if MetricsPublicAPIInterface.is_auth_enabled(): urls.append( @@ -222,3 +230,25 @@ def _construct_version_two_urls( name="timeseries-list-v2", ), ] + + +def _construct_search_urls( + *, + prefix: str, +) -> list[resolvers.URLResolver]: + """Returns a list of URLResolvers for the public search API + + Args: + prefix: The prefix to add to the start of the url paths + + Returns: + List of `URLResolver` objects each representing a + set of versioned URLS. + """ + return [ + path( + f"{prefix}v1", + SearchView.as_view(), + name="search", + ), + ] diff --git a/public_api/version_02/views/__init__.py b/public_api/version_02/views/__init__.py index aa39e1c87..913114386 100644 --- a/public_api/version_02/views/__init__.py +++ b/public_api/version_02/views/__init__.py @@ -13,3 +13,5 @@ ) from .root_view import PublicAPIRootViewV2 + +from .search import SearchView diff --git a/public_api/version_02/views/search.py b/public_api/version_02/views/search.py new file mode 100644 index 000000000..bf980a465 --- /dev/null +++ b/public_api/version_02/views/search.py @@ -0,0 +1,33 @@ +from drf_spectacular.utils import extend_schema +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView +from wagtail.api.v2.serializers import get_serializer_class +from wagtail.models import Page + +from cms.topic.models import TopicPage +from public_api.version_02.views.base import PUBLIC_API_TAG + + +@extend_schema(tags=[PUBLIC_API_TAG]) +class SearchView(APIView): + """This endpoint provides search results and could in the future provide + autocomplete suggestions etc. + + """ + + def get(self, request: Request): + search = request.GET.get("search") + limit = int(request.GET.get("limit", 0)) + fields = request.GET.get("fields", ["title", "slug"]) + meta = request.GET.get("meta", ["id"]) + topic_results = TopicPage.objects.all().search(search) + if not limit or topic_results.count() < limit: + # TODO: go get more + # results = queryset + results = topic_results + else: + results = topic_results[0:limit] + print(f"AIDAN: returning {results}") + serialized = get_serializer_class(Page, fields, meta)(results, many=True) + return Response(serialized.data) diff --git a/scripts/_cache.sh b/scripts/_cache.sh index e8350dd7b..9896e16b7 100644 --- a/scripts/_cache.sh +++ b/scripts/_cache.sh @@ -9,6 +9,7 @@ function _cache_help() { echo echo " flush-redis - flush and re-fill the redis (private api) cache" echo " flush-redis-reserved-namespace - blue-green update the reserved namespace in the redis (private api) cache" + echo " flush-search - flush and re-build the wagtail search index" return 0 } @@ -20,6 +21,7 @@ function _cache() { case $verb in "flush-redis") _cache_flush_redis $args ;; "flush-redis-reserved-namespace") _cache_flush_redis_reserved_namespace $args ;; + "flush-search") _flush_search $args ;; *) _cache_help ;; esac @@ -34,3 +36,8 @@ function _cache_flush_redis_reserved_namespace() { uhd venv activate python manage.py hydrate_private_api_cache_reserved_namespace } + +function _flush_search() { + uhd venv activate + python manage.py wagtail_update_index +} diff --git a/scripts/_tests.sh b/scripts/_tests.sh index 3fac12953..8fbd606e2 100644 --- a/scripts/_tests.sh +++ b/scripts/_tests.sh @@ -41,7 +41,10 @@ function _tests_unit() { function _tests_integration() { uhd venv activate - python -m pytest tests/integration "$@" + rm -f test.sqlite3 + uhd django migrate + uhd bootstrap all + pytest tests/integration "$@" } function _tests_system() { diff --git a/tests/factories/cms/page.py b/tests/factories/cms/page.py new file mode 100644 index 000000000..083f96e6c --- /dev/null +++ b/tests/factories/cms/page.py @@ -0,0 +1,12 @@ +import factory + +from cms.common.models import CommonPage + + +class CommonPageFactory(factory.django.DjangoModelFactory): + """ + Factory for creating `CommonPage` instances for tests + """ + + class Meta: + model = CommonPage diff --git a/tests/factories/cms/topic.py b/tests/factories/cms/topic.py new file mode 100644 index 000000000..e54a7b4d5 --- /dev/null +++ b/tests/factories/cms/topic.py @@ -0,0 +1,12 @@ +import factory + +from cms.topic.models import TopicPage + + +class TopicPageFactory(factory.django.DjangoModelFactory): + """ + Factory for creating `Topic` instances for tests + """ + + class Meta: + model = TopicPage diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 86c26cc68..2efef4442 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,6 +2,12 @@ import pytest from django.utils import timezone +from cms.common.models import UKHSAPage +from cms.topic.models import TopicPage +from tests.factories.cms.topic import TopicPageFactory +from tests.factories.cms.page import CommonPageFactory + + from metrics.data.models.core_models import ( Age, CoreHeadline, @@ -105,6 +111,44 @@ def core_timeseries_example() -> list[CoreTimeSeries]: ] +@pytest.fixture +def topics() -> list[TopicPage]: + topics = [] + for i in range(5): + title = f"title topic {i}" + if i % 2 == 0: + title += " rare" + topics = topics + [ + TopicPageFactory.create( + path=f"topic-{i}", + depth=1, + title=title, + slug=f"slug-{i}", + seo_title=f"seo_title {i}", + body=f"body text {i}", + ) + ] + return topics + + +@pytest.fixture +def other_pages() -> list[UKHSAPage]: + pages = [] + for i in range(5): + pages = pages + [ + CommonPageFactory.create( + path=f"page-{i}", + depth=1, + title=f"title other {i}", + slug=f"slug-page{i}", + seo_title=f"seo_title {i}", + body=f"body text {i}", + ) + ] + + return pages + + @pytest.fixture def patch_auth_enabled(monkeypatch): monkeypatch.setenv("AUTH_ENABLED", "1") diff --git a/tests/integration/public_api/v2/views/test_search.py b/tests/integration/public_api/v2/views/test_search.py new file mode 100644 index 000000000..90193cd2b --- /dev/null +++ b/tests/integration/public_api/v2/views/test_search.py @@ -0,0 +1,92 @@ +from http import HTTPStatus +import os +import pytest + +from rest_framework.test import RequestsClient + + +@pytest.mark.django_db +class TestSearchAPIView: + + @property + def path(self) -> str: + return "/api/public/search/v1" + + @property + def target_domain(self) -> str: + return os.environ.get("PUBLIC_API_TEST_DOMAIN", "http://testserver") + + def test_search_finds_topics_only(self, topics, other_pages): + """ + Given a string that matches all topic pages + When the API is called with that query and a limit of 5 + Then the results will include only topic pages + """ + + # Given + limit = len(topics) + search = "topic" # All topics have a title with topic in them + query = f"search={search}&limit={limit}" + + # When + client = RequestsClient() + url = f"{self.target_domain}{self.path}?{query}" + response: Response = client.get(url) + + # Then + assert response.status_code == HTTPStatus.OK + response_data: list[dict] = response.json() + assert limit == len(response_data) + for page in topics: + target = {"title": page.title, "slug": page.slug} + assert response_data.index(target) > -1 + + def test_search_finds_topics_and_pages(self, topics, other_pages): + """ + Given a string that matches 2 topic page + When the API is called with that query and no limit + Then the results will include the matching topic pages and others with the topics first + """ + + # Given + search = "rare" # All topics have a title, only even ones have rare in it + query = f"search={search}" + expected_results = [t for t in topics if "rare" in t.title] + + # When + client = RequestsClient() + url = f"{self.target_domain}{self.path}?{query}" + response: Response = client.get(url) + + # Then + assert response.status_code == HTTPStatus.OK + response_data: list[dict] = response.json() + assert len(expected_results) == len(response_data) + for page in expected_results: + target = {"title": page.title, "slug": page.slug} + assert response_data.index(target) > -1 + + def test_search_doesnt_find_unpublish_pages(self, topics, other_pages): + """ + Given a string that matches 2 topic page + When the API is called with that query and no limit + Then the results will include the matching topic pages and others with the topics first + """ + + # Given + search = "rare" # All topics have a title, only even ones have rare in it + query = f"search={search}" + expected_results = [t for t in topics if "rare" in t.title] + + # When + client = RequestsClient() + url = f"{self.target_domain}{self.path}?{query}" + response: Response = client.get(url) + + # Then + assert response.status_code == HTTPStatus.OK + response_data: list[dict] = response.json() + assert len(expected_results) == len(response_data) + for page in expected_results: + target = {"title": page.title, "slug": page.slug} + assert response_data.index(target) > -1