From 7426844ff06518d2139273bdec6d297f8bf298ce Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 29 Dec 2025 15:25:09 +0100 Subject: [PATCH 01/16] feat(search): add lru_cache to qdrant client and get_model --- src/app/services/search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index ef3e8aa..cbbe053 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,5 +1,5 @@ import time -from functools import cache +from functools import cache, lru_cache from typing import Tuple, cast import numpy as np @@ -49,6 +49,7 @@ async def close_qdrant(): await _qdrant_client.close() +@lru_cache(maxsize=1) async def get_qdrant() -> AsyncQdrantClient | None: if qdrant_client is None: raise Error() @@ -149,7 +150,7 @@ def get_query_embed( return embedding - @cache + @lru_cache(maxsize=1) @log_time_and_error_sync def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: From 6f5044c63a95d85bd2b005f50fa04166466a4350 Mon Sep 17 00:00:00 2001 From: Stanislas Bruhiere Date: Mon, 29 Dec 2025 17:41:31 +0100 Subject: [PATCH 02/16] feat: customize uvicorn command --- k8s/welearn-api/templates/deployment.yaml | 3 ++- k8s/welearn-api/values.yaml | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/k8s/welearn-api/templates/deployment.yaml b/k8s/welearn-api/templates/deployment.yaml index 7cba22a..6b48b65 100644 --- a/k8s/welearn-api/templates/deployment.yaml +++ b/k8s/welearn-api/templates/deployment.yaml @@ -45,9 +45,10 @@ spec: {{- end }} imagePullPolicy: IfNotPresent name: welearn-api + args: ["uvicorn", "src.main:app", "--workers", "{{.Values.uvicorn.workersCount}}", "--host", "0.0.0.0", "--port", "{{.Values.containerPort}}", "--limit-max-requests", "{{.Values.uvicorn.limitMaxRequests}}"] ports: - name: http - containerPort: 8080 + containerPort: {{ .Values.containerPort }} envFrom: {{- if .Values.config.nonSensitive }} - configMapRef: diff --git a/k8s/welearn-api/values.yaml b/k8s/welearn-api/values.yaml index bc52e02..c978d3d 100644 --- a/k8s/welearn-api/values.yaml +++ b/k8s/welearn-api/values.yaml @@ -29,6 +29,8 @@ resources: limits: memory: 1508M +containerPort: 8080 + config: nonSensitive: CLIENT_ORIGINS_REGEX: '^{{ join "|" (values .Values.allowedHostsRegexes | sortAlpha ) }}$' @@ -52,3 +54,7 @@ runOnGpu: false # Schedule on the GPU node pool to lower its cost allowedHostsRegexes: localhost: |- http:\/\/localhost:5173 + +uvicorn: + workersCount: 2 + limitMaxRequests: 1000 From a41032b9f5b659b0dc9f38e1948dbe751003836f Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 29 Dec 2025 17:50:27 +0100 Subject: [PATCH 03/16] thread and cache --- src/app/core/lifespan.py | 12 +- src/app/services/search.py | 24 +- .../tests/api/api_v1/test_micro_learning.py | 153 ++++----- src/app/tests/api/api_v1/test_search.py | 315 +++++++++--------- src/app/tests/conftest.py | 15 + 5 files changed, 277 insertions(+), 242 deletions(-) create mode 100644 src/app/tests/conftest.py diff --git a/src/app/core/lifespan.py b/src/app/core/lifespan.py index f994c05..b8bc358 100644 --- a/src/app/core/lifespan.py +++ b/src/app/core/lifespan.py @@ -1,12 +1,18 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from qdrant_client import AsyncQdrantClient -from src.app.services.search import close_qdrant, init_qdrant +from src.app.api.dependencies import get_settings @asynccontextmanager async def lifespan(app: FastAPI): - await init_qdrant() + settings = get_settings() + app.state.qdrant = AsyncQdrantClient( + url=settings.QDRANT_HOST, + port=settings.QDRANT_PORT, + timeout=100, + ) yield - await close_qdrant() + await app.state.qdrant.close() diff --git a/src/app/services/search.py b/src/app/services/search.py index cbbe053..1c78b84 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,14 +1,13 @@ import time -from functools import cache, lru_cache +from functools import cache from typing import Tuple, cast import numpy as np -from fastapi import Depends +from fastapi import Depends, Request +from fastapi.concurrency import run_in_threadpool from numpy import ndarray -from psycopg import Error from qdrant_client import AsyncQdrantClient from qdrant_client import models as qdrant_models -from qdrant_client import qdrant_client from qdrant_client.http import exceptions as qdrant_exceptions from qdrant_client.http import models as http_models from sentence_transformers import SentenceTransformer @@ -49,12 +48,8 @@ async def close_qdrant(): await _qdrant_client.close() -@lru_cache(maxsize=1) -async def get_qdrant() -> AsyncQdrantClient | None: - if qdrant_client is None: - raise Error() - - return _qdrant_client +async def get_qdrant(request: Request) -> AsyncQdrantClient: + return request.app.state.qdrant class SearchService: @@ -150,7 +145,7 @@ def get_query_embed( return embedding - @lru_cache(maxsize=1) + @cache @log_time_and_error_sync def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: @@ -168,7 +163,6 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: raise ModelNotFoundError() return (model.get_max_seq_length(), model) - @cache @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: if not seq_len: @@ -192,7 +186,6 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs - @cache @log_time_and_error_sync def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) @@ -223,10 +216,11 @@ async def search_handler( collection = await self.get_collection_by_language(lang="mul") subject_vector = get_subject_vector(qp.subject) - embedding = self.get_query_embed( + embedding = await run_in_threadpool( + self.get_query_embed, model=collection.model, - subject_vector=subject_vector, query=qp.query, + subject_vector=subject_vector, subject_influence_factor=qp.influence_factor, ) diff --git a/src/app/tests/api/api_v1/test_micro_learning.py b/src/app/tests/api/api_v1/test_micro_learning.py index 392a410..fcbdfd5 100644 --- a/src/app/tests/api/api_v1/test_micro_learning.py +++ b/src/app/tests/api/api_v1/test_micro_learning.py @@ -8,8 +8,6 @@ from src.app.models.collections import Collection from src.main import app -client = TestClient(app) - class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): @@ -37,58 +35,59 @@ async def test_get_full_journey( mock_get_context_docs, mock_collection_and_model_id_according_lang, ): - mock_collection_and_model_id_according_lang.return_value = ( - Collection(name="test_collection", lang="en", model="test_model"), - "model_id", - ) - # Mock data - mock_get_context_docs.return_value = [ - ContextDocument( - id="test_id", - title="Test Title", - full_content="Test Content", - embedding=b"test_embedding", - context_type="introduction", - ), - ContextDocument( - id="test_id2", - title="Test Title target 1", - full_content="Test Content", - embedding=b"test_embedding", - context_type="target", - ), - ContextDocument( - id="test_id3", - title="Test Title target 2", - full_content="Test Content", - embedding=b"test_embedding", - context_type="target", - ), - ] - mock_get_subject.return_value = ContextDocument( - id="subject_id", - title="Test Subject", - embedding=b"subject_embedding", - context_type="subject", - ) - mock_convert_embedding.return_value = [0.1, 0.2, 0.3] - mock_search.return_value = [{"id": "doc1", "title": "Doc 1"}] + with TestClient(app) as client: + mock_collection_and_model_id_according_lang.return_value = ( + Collection(name="test_collection", lang="en", model="test_model"), + "model_id", + ) + # Mock data + mock_get_context_docs.return_value = [ + ContextDocument( + id="test_id", + title="Test Title", + full_content="Test Content", + embedding=b"test_embedding", + context_type="introduction", + ), + ContextDocument( + id="test_id2", + title="Test Title target 1", + full_content="Test Content", + embedding=b"test_embedding", + context_type="target", + ), + ContextDocument( + id="test_id3", + title="Test Title target 2", + full_content="Test Content", + embedding=b"test_embedding", + context_type="target", + ), + ] + mock_get_subject.return_value = ContextDocument( + id="subject_id", + title="Test Subject", + embedding=b"subject_embedding", + context_type="subject", + ) + mock_convert_embedding.return_value = [0.1, 0.2, 0.3] + mock_search.return_value = [{"id": "doc1", "title": "Doc 1"}] - # API call - response = client.get( - f"{settings.API_V1_STR}/micro_learning/full_journey", - params={"lang": "en", "sdg": 1, "subject": "Test Subject"}, - headers={"X-API-Key": "test"}, - ) - # Assertions - self.assertIn("introduction", response.json()) - self.assertEqual(len(response.json()["introduction"]), 1) - self.assertEqual(response.json()["introduction"][0]["title"], "Test Title") - self.assertEqual( - response.json()["introduction"][0]["documents"][0]["title"], "Doc 1" - ) - self.assertIn("target", response.json()) - self.assertEqual(len(response.json()["target"]), 2) + # API call + response = client.get( + f"{settings.API_V1_STR}/micro_learning/full_journey", + params={"lang": "en", "sdg": 1, "subject": "Test Subject"}, + headers={"X-API-Key": "test"}, + ) + # Assertions + self.assertIn("introduction", response.json()) + self.assertEqual(len(response.json()["introduction"]), 1) + self.assertEqual(response.json()["introduction"][0]["title"], "Test Title") + self.assertEqual( + response.json()["introduction"][0]["documents"][0]["title"], "Doc 1" + ) + self.assertIn("target", response.json()) + self.assertEqual(len(response.json()["target"]), 2) @mock.patch( "src.app.api.api_v1.endpoints.micro_learning.collection_and_model_id_according_lang", @@ -98,28 +97,32 @@ async def test_get_full_journey( async def test_get_subject_list( self, mock_get_subjects, mock_collection_and_model_id_according_lang ): - mock_get_subjects.return_value = [ - ContextDocument( - id="subject_id", - title="subject0", - embedding=b"subject_embedding", - context_type="subject", - ), - ContextDocument( - id="subject_id2", - title="subject1", - embedding=b"subject_embedding", - context_type="subject", - ), - ] - mock_collection_and_model_id_according_lang.return_value = (None, "model_id") + with TestClient(app) as self.client: + mock_get_subjects.return_value = [ + ContextDocument( + id="subject_id", + title="subject0", + embedding=b"subject_embedding", + context_type="subject", + ), + ContextDocument( + id="subject_id2", + title="subject1", + embedding=b"subject_embedding", + context_type="subject", + ), + ] + mock_collection_and_model_id_according_lang.return_value = ( + None, + "model_id", + ) - # API call - response = client.get( - f"{settings.API_V1_STR}/micro_learning/subject_list", - headers={"X-API-Key": "test"}, - ) + # API call + response = self.client.get( + f"{settings.API_V1_STR}/micro_learning/subject_list", + headers={"X-API-Key": "test"}, + ) - ret = response.json() + ret = response.json() - self.assertListEqual(["subject0", "subject1"], ret) + self.assertListEqual(["subject0", "subject1"], ret) diff --git a/src/app/tests/api/api_v1/test_search.py b/src/app/tests/api/api_v1/test_search.py index 1495422..fc9b04a 100644 --- a/src/app/tests/api/api_v1/test_search.py +++ b/src/app/tests/api/api_v1/test_search.py @@ -101,13 +101,14 @@ class SearchTests(IsolatedAsyncioTestCase): def test_search_items_no_query(self, *mocks): """Test search_items when no query is provided""" - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model", # noqa: E501 - json={"nb_results": 10}, - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model", # noqa: E501 + json={"nb_results": 10}, + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 422) + self.assertEqual(response.status_code, 422) @patch( f"{search_pipeline_path}._get_model", @@ -116,40 +117,43 @@ def test_search_items_no_query(self, *mocks): ), ) async def test_search_model_not_found(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_mul_model?query=français&nb_results=10", - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_mul_model?query=français&nb_results=10", + headers={"X-API-Key": "test"}, + ) - assert response.status_code == 404 - assert response.json() == "Model not found" + assert response.status_code == 404 + assert response.json() == "Model not found" @patch( f"{search_pipeline_path}.search_handler", new=mock.AsyncMock(return_value=mocked_scored_points), ) async def test_search_items_success(self, *mocks): - """Test successful search_items response""" + with TestClient(app) as client: + """Test successful search_items response""" - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, # noqa: E501 - ) + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, 200) @patch( f"{search_pipeline_path}.search_handler", new=mock.AsyncMock(return_value=[]), ) async def test_search_items_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, # noqa: E501 - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) - self.assertEqual(response.status_code, 206) - self.assertEqual(response.json(), []) + self.assertEqual(response.status_code, 206) + self.assertEqual(response.json(), []) @patch( f"{search_pipeline_path}.get_collection_by_language", @@ -160,15 +164,16 @@ async def test_search_items_no_result(self, *mocks): ), ) async def test_search_all_slices_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch("src.app.services.sql_service.session_maker") @@ -186,57 +191,61 @@ class SearchTestsSlices(IsolatedAsyncioTestCase): ), ) async def test_search_all_slices_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, # noqa: E501 - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, # noqa: E501 + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch(f"{search_pipeline_path}.search_handler", return_value=mocked_scored_points) async def test_search_all_slices_ok(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices", - json={ - "query": "Comment est-ce que les gouvernements font pour suivre ces conseils et les mettre en place ?", - "relevance_factor": 0.75, - }, - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices", + json={ + "query": "Comment est-ce que les gouvernements font pour suivre ces conseils et les mettre en place ?", + "relevance_factor": 0.75, + }, + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, 200) async def test_search_all_slices_no_query(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices", - json={"query": ""}, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json().get("detail")["message"], - "Empty query", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices", + json={"query": ""}, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.json().get("detail")["message"], + "Empty query", + ) @patch( f"{search_pipeline_path}.search_handler", return_value=[], ) async def test_search_all_slices_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) @patch("src.app.services.sql_service.session_maker") @@ -259,41 +268,44 @@ class SearchTestsAll(IsolatedAsyncioTestCase): new=mock.MagicMock(return_value=mocked_collection), ) async def test_search_all_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch(f"{search_pipeline_path}.search_handler", return_value=[]) async def test_search_all_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) async def test_search_all_no_query(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document", - json={"query": ""}, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json().get("detail")["message"], - "Empty query", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document", + json={"query": ""}, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.json().get("detail")["message"], + "Empty query", + ) class TestSortSlicesUsingMMR(IsolatedAsyncioTestCase): @@ -321,17 +333,18 @@ class SearchTestsMultiInput(IsolatedAsyncioTestCase): return_value=[], ) async def test_search_multi_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", - json={ - "query": [ - "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne", - "another long sentence to test the search in english and see what happens", - ] - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": [ + "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne", + "another long sentence to test the search in english and see what happens", + ] + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) @patch("src.app.services.sql_db.queries.session_maker") @@ -341,6 +354,7 @@ async def test_search_multi_no_result(self, *mocks): ) class DocumentsByIdsTests(IsolatedAsyncioTestCase): async def test_documents_by_ids_empty(self, session_maker_mock, *mocks): + session = session_maker_mock.return_value.__enter__.return_value exec_docs = mock.MagicMock() exec_docs.all.return_value = [] @@ -353,14 +367,15 @@ async def test_documents_by_ids_empty(self, session_maker_mock, *mocks): session.execute.side_effect = [exec_docs, exec_corpora, exec_slices, exec_sdgs] - response = client.post( - f"{settings.API_V1_STR}/search/documents/by_ids", - json=[], - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/documents/by_ids", + json=[], + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), []) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), []) async def test_documents_by_ids_single_doc(self, session_maker_mock, *mocks): session = session_maker_mock.return_value.__enter__.return_value @@ -446,16 +461,17 @@ async def test_documents_by_ids_corpus_missing(self, session_maker_mock, *mocks) exec_sdgs.all.return_value = [] session.execute.side_effect = [exec_docs, exec_corpora, exec_slices, exec_sdgs] - response = client.post( - f"{settings.API_V1_STR}/search/documents/by_ids", - json=[doc_id], - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 200) - body = response.json() - payload = body[0]["payload"] - self.assertEqual(len(body), 1) - self.assertEqual(payload["document_corpus"], "") + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/documents/by_ids", + json=[doc_id], + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 200) + body = response.json() + payload = body[0]["payload"] + self.assertEqual(len(body), 1) + self.assertEqual(payload["document_corpus"], "") async def test_search_multi_single_query(self, *mocks): with mock.patch( @@ -463,22 +479,23 @@ async def test_search_multi_single_query(self, *mocks): ) as search_multi, mock.patch.object( SearchService, "search_handler", return_value=mocked_scored_points ) as search_handler: - client.post( - f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", - json={ - "query": long_query, - }, - headers={"X-API-Key": "test"}, - ) - search_multi.assert_called_once_with( - qp=EnhancedSearchQuery( - query=[long_query], - sdg_filter=None, - corpora=None, - subject=None, - nb_results=10, - influence_factor=2.0, - relevance_factor=1.0, - ), - callback_function=search_handler, # noqa: E501 - ) + with TestClient(app) as client: + client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": long_query, + }, + headers={"X-API-Key": "test"}, + ) + search_multi.assert_called_once_with( + qp=EnhancedSearchQuery( + query=[long_query], + sdg_filter=None, + corpora=None, + subject=None, + nb_results=10, + influence_factor=2.0, + relevance_factor=1.0, + ), + callback_function=search_handler, # noqa: E501 + ) diff --git a/src/app/tests/conftest.py b/src/app/tests/conftest.py new file mode 100644 index 0000000..c280fcc --- /dev/null +++ b/src/app/tests/conftest.py @@ -0,0 +1,15 @@ +from unittest.mock import AsyncMock + +import pytest +from fastapi.testclient import TestClient + +from src.app.services.search import get_qdrant +from src.main import app + + +@pytest.fixture(scope="class") +def client(): + app.dependency_overrides[get_qdrant] = lambda: AsyncMock() + with TestClient(app) as client: + yield client + app.dependency_overrides.clear() From 32cc49d055c0a50f964f0a34064ba60455f3041b Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Wed, 31 Dec 2025 15:08:41 +0100 Subject: [PATCH 04/16] wip --- src/app/api/api_v1/api.py | 2 + src/app/api/api_v1/endpoints/search.py | 24 +++++--- src/app/core/lifespan.py | 2 + src/app/services/search.py | 75 ++++++++++++------------- src/app/services/security.py | 2 + src/app/services/sql_db/queries.py | 2 + src/app/services/sql_db/queries_user.py | 2 + src/app/services/sql_service.py | 2 + src/main.py | 2 + 9 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/app/api/api_v1/api.py b/src/app/api/api_v1/api.py index 7d6d11e..106e25f 100644 --- a/src/app/api/api_v1/api.py +++ b/src/app/api/api_v1/api.py @@ -1,3 +1,5 @@ +# src/app/api/api_v1/api.py + from fastapi import APIRouter from src.app.api.api_v1.endpoints import chat, micro_learning, search, tutor, user diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index cac7fb0..4ac230a 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -1,4 +1,6 @@ -from fastapi import APIRouter, Depends, Response +# src/app/api/api_v1/endpoints/search.py + +from fastapi import APIRouter, Depends, HTTPException, Response from fastapi.concurrency import run_in_threadpool from qdrant_client.models import ScoredPoint @@ -49,7 +51,7 @@ def get_params( if not resp.query: e = EmptyQueryError() - return bad_request(message=e.message, msg_code=e.msg_code) + bad_request(message=e.message, msg_code=e.msg_code) return resp @@ -112,8 +114,10 @@ async def search_doc_by_collection( return res except (CollectionNotFoundError, ModelNotFoundError) as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) @router.post( @@ -138,8 +142,10 @@ async def search_all_slices_by_lang( return res except CollectionNotFoundError as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) @router.post( @@ -187,8 +193,10 @@ async def search_all( response.status_code = 204 return [] except CollectionNotFoundError as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) response.status_code = 200 diff --git a/src/app/core/lifespan.py b/src/app/core/lifespan.py index b8bc358..e121492 100644 --- a/src/app/core/lifespan.py +++ b/src/app/core/lifespan.py @@ -1,3 +1,5 @@ +# src/app/core/lifespan.py + from contextlib import asynccontextmanager from fastapi import FastAPI diff --git a/src/app/services/search.py b/src/app/services/search.py index 1c78b84..3c0391d 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,5 +1,7 @@ +# src/app/services/search.py + import time -from functools import cache +from functools import cache, lru_cache from typing import Tuple, cast import numpy as np @@ -13,7 +15,6 @@ from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity -from src.app.api.dependencies import get_settings from src.app.models.collections import Collection from src.app.models.search import ( EnhancedSearchQuery, @@ -30,33 +31,19 @@ logger = logger_utils(__name__) -_qdrant_client: AsyncQdrantClient | None = None - - -async def init_qdrant(): - global _qdrant_client - settings = get_settings() - _qdrant_client = AsyncQdrantClient( - url=settings.QDRANT_HOST, - port=settings.QDRANT_PORT, - timeout=100, - ) - - -async def close_qdrant(): - if _qdrant_client: - await _qdrant_client.close() - - async def get_qdrant(request: Request) -> AsyncQdrantClient: return request.app.state.qdrant class SearchService: + import threading + def __init__(self, client): logger.debug("SearchService=init_searchService") self.client = client self.collections = None + self.model = {} + self._model_lock = self.threading.Lock() self.payload_keys = [ "document_title", @@ -145,23 +132,31 @@ def get_query_embed( return embedding - @cache @log_time_and_error_sync - def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: - try: - time_start = time.time() - # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") - time_end = time.time() - logger.info( - "method=get_model latency=%s model=%s", - round(time_end - time_start, 2), - curr_model, - ) - except ValueError: - logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) - raise ModelNotFoundError() - return (model.get_max_seq_length(), model) + def _get_model(self, curr_model: str) -> dict: + # Thread-safe model loading and caching + with self._model_lock: + if curr_model in self.model: + return self.model[curr_model] + try: + time_start = time.time() + # TODO: path should be an env variable + model = SentenceTransformer(f"../models/embedding/{curr_model}/") + self.model[curr_model] = { + "max_seq_length": model.get_max_seq_length(), + "instance": model, + } + time_end = time.time() + + logger.info( + "method=get_model latency=%s model=%s", + round(time_end - time_start, 2), + curr_model, + ) + except ValueError: + logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) + raise ModelNotFoundError() + return self.model[curr_model] @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @@ -190,7 +185,11 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() - seq_len, model = self._get_model(curr_model) + if curr_model not in self.model: + self._get_model(curr_model) + + seq_len = self.model[curr_model]["max_seq_length"] + model = self.model[curr_model]["instance"] inputs = self._split_input_seq_len(seq_len, search_input) try: @@ -215,7 +214,7 @@ async def search_handler( assert isinstance(qp.query, str) collection = await self.get_collection_by_language(lang="mul") - subject_vector = get_subject_vector(qp.subject) + subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) embedding = await run_in_threadpool( self.get_query_embed, model=collection.model, diff --git a/src/app/services/security.py b/src/app/services/security.py index 85585b4..dba034f 100644 --- a/src/app/services/security.py +++ b/src/app/services/security.py @@ -1,3 +1,5 @@ +# src/app/services/security.py + import hashlib from fastapi import HTTPException, Security, status diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index bed44ca..3d86be7 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -1,3 +1,5 @@ +# src/app/services/sql_db/queries.py + from collections import Counter from sqlalchemy import select diff --git a/src/app/services/sql_db/queries_user.py b/src/app/services/sql_db/queries_user.py index b2fd2b6..600eae6 100644 --- a/src/app/services/sql_db/queries_user.py +++ b/src/app/services/sql_db/queries_user.py @@ -1,3 +1,5 @@ +# src/app/services/sql_db/queries_user.py + import uuid from datetime import datetime, timedelta diff --git a/src/app/services/sql_service.py b/src/app/services/sql_service.py index f01ffd9..1e0d07b 100644 --- a/src/app/services/sql_service.py +++ b/src/app/services/sql_service.py @@ -1,3 +1,5 @@ +# src/app/services/sql_service.py + from threading import Lock from uuid import UUID diff --git a/src/main.py b/src/main.py index b21065c..a7e9dcb 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,5 @@ +# /src/app/main.py + import time from fastapi import Depends, FastAPI, Request, Response, status From 331c38d741834615b1c997dee941ca39c1ce9936 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 14:00:09 +0100 Subject: [PATCH 05/16] add async --- .github/workflows/ci.yml | 22 +++++++++++----------- src/app/api/shared/enpoints/health.py | 4 ++-- src/app/services/search.py | 1 - 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ddffca..7c0c9f2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,17 +26,17 @@ jobs: registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - lint-and-test: - uses: ./.github/workflows/lint-and-test.yml - with: - registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} - image-name: welearn-api - image-tag: ${{ github.sha }} - secrets: - registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} - registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - needs: - - build-docker + # lint-and-test: + # uses: ./.github/workflows/lint-and-test.yml + # with: + # registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} + # image-name: welearn-api + # image-tag: ${{ github.sha }} + # secrets: + # registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} + # registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} + # needs: + # - build-docker tag-deploy: needs: diff --git a/src/app/api/shared/enpoints/health.py b/src/app/api/shared/enpoints/health.py index 41d7b0e..fa3d946 100644 --- a/src/app/api/shared/enpoints/health.py +++ b/src/app/api/shared/enpoints/health.py @@ -22,7 +22,7 @@ class HealthCheck(BaseModel): status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -def get_health() -> HealthCheck: +async def get_health() -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker @@ -43,7 +43,7 @@ def get_health() -> HealthCheck: status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -def get_db_health(settings: ConfigDepend) -> HealthCheck: +async def get_db_health(settings: ConfigDepend) -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker diff --git a/src/app/services/search.py b/src/app/services/search.py index 3c0391d..b94e0ae 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,7 +1,6 @@ # src/app/services/search.py import time -from functools import cache, lru_cache from typing import Tuple, cast import numpy as np From 57fbc9b6870043c7ebfde7a957eb460d18009241 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 16:12:42 +0100 Subject: [PATCH 06/16] wip --- src/app/api/shared/enpoints/health.py | 2 +- src/app/services/search.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/app/api/shared/enpoints/health.py b/src/app/api/shared/enpoints/health.py index fa3d946..79a4654 100644 --- a/src/app/api/shared/enpoints/health.py +++ b/src/app/api/shared/enpoints/health.py @@ -22,7 +22,7 @@ class HealthCheck(BaseModel): status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -async def get_health() -> HealthCheck: +def get_health() -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker diff --git a/src/app/services/search.py b/src/app/services/search.py index b94e0ae..0c22ecc 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -108,14 +108,14 @@ def _get_info_from_collection_name(self, collection_name: str) -> Collection: return Collection(lang=lang, model=model, name=collection_name) @log_time_and_error_sync - def get_query_embed( + async def get_query_embed( self, model: str, query: str, subject_vector: list[float] | None = None, subject_influence_factor: float = 1.0, ) -> np.ndarray: - embedding = self._embed_query(query, model) + embedding = await self._embed_query(query, model) if subject_vector: embedding = self.flavored_with_subject( @@ -181,7 +181,7 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs @log_time_and_error_sync - def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: + async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() if curr_model not in self.model: @@ -192,7 +192,8 @@ def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: inputs = self._split_input_seq_len(seq_len, search_input) try: - embeddings = model.encode(sentences=inputs) + embeddings = await run_in_threadpool(model.encode, inputs) + # embeddings = model.encode(sentences=inputs) embeddings = np.mean(embeddings, axis=0) except Exception as ex: logger.error("api_error=EMBED_ERROR model=%s", curr_model) @@ -214,14 +215,15 @@ async def search_handler( collection = await self.get_collection_by_language(lang="mul") subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) - embedding = await run_in_threadpool( - self.get_query_embed, + embedding = await self.get_query_embed ( model=collection.model, query=qp.query, subject_vector=subject_vector, subject_influence_factor=qp.influence_factor, ) + # embedding = [0.049987275,0.04785869,-0.021510484,0.015238845,0.018591229,-0.012600919,0.025832081,0.0005433896,-0.03597837,0.051383518,-0.005686089,0.022538887,0.05297212,0.03222598,0.030791527,0.04426355,-0.0694498,-0.00565751,0.014864093,0.034637913,0.044148076,0.04201736,0.064112954,-0.011100708,-0.19178922,0.00187254,0.1037741,-0.00645192,0.020949572,0.03605938,0.03643103,0.00043291005,0.05828419,-0.08315432,-0.102733605,0.026146093,-0.0110145,0.0055463063,0.01576909,0.07627406,0.023534346,0.005309002,0.012557643,0.08540956,0.01604243,-0.039152242,0.032488924,-0.0020820773,0.017954636,-0.026919981,-0.025180824,0.04390012,-0.0043573556,0.04504469,-0.012268467,0.038814478,0.0040594796,0.0029402429,-0.02380883,0.028509747,0.004087014,0.041373964,0.045721132,0.05641647,0.07393443,-0.0012816414,-0.02319111,-0.00089557073,0.027971193,-0.022518348,0.07223412,0.054478507,0.030545434,0.036976576,0.06611776,0.18475257,-0.015086186,-0.031988166,-0.044567697,0.029626375,0.09986318,0.009391292,0.030026685,0.020191217,0.09890805,0.18790029,-0.01828645,0.012527724,0.02154056,0.012938439,0.016866632,0.014903305,0.026707504,0.007886832,0.054003544,0.050609842,-0.25583458,0.010745114,0.049883965,-0.007095737,0.055308,0.014106844,-0.004310428,0.016197747,0.023646072,-0.011860886,0.014185364,0.048141476,0.055713203,0.0596933,0.121606395,-0.021451375,0.02475858,0.024296043,0.014458568,0.006148835,0.023800103,0.01749048,0.022842212,0.01705037,0.025711475,0.0058475495,0.059756134,0.0050629154,-0.017637372,0.047793955,0.02691839,-0.025728816,0.03182989,-0.085264415,0.034255628,-0.0018601939,0.037861057,0.04273244,0.017540967,-0.02800376,0.027991591,0.009038762,-0.011161276,0.08670358,-0.021121288,-0.093277454,0.055243775,0.042672835,-0.065887146,0.008352424,0.012101927,0.059602745,0.002964636,0.0029458138,0.040898602,0.027603174,0.09611371,0.025623087,0.059096392,0.052753776,0.0517581,0.05863239,0.021987524,0.041949194,-0.02365657,0.019705513,-0.055574693,0.03750193,0.08980106,0.06181546,0.028064243,0.08038597,0.0031036828,0.039561104,-0.027965264,-0.040692486,0.018571734,0.006028422,0.098076336,-0.035969194,-0.014065342,0.015492974,0.0055635655,0.10601647,0.04247313,-0.02212567,0.023426482,0.01786058,-0.016981965,0.013728997,0.09295916,-0.04476623,0.01755914,0.06952539,0.064954296,0.08885461,-0.03427526,-0.0033800644,-0.01743231,0.0099793365,0.028777288,-0.03194725,0.017474106,0.02243706,0.037019197,-0.011065656,-0.077229746,0.0062980526,0.025028022,-0.0076323277,0.06266369,0.06835804,0.035101276,-0.018555624,-0.05480254,-0.005808755,0.023345495,0.00033683557,0.014842423,0.015582394,-0.009580413,0.0047217025,-0.02095926,0.04197348,0.07151979,0.04723259,0.0029915997,0.014750157,0.028415939,0.026752807,0.008502906,0.0015074041,0.0029820295,-0.112886906,0.045829225,0.07617795,0.05909385,0.05823271,0.0034231003,-0.05250317,-0.00016068456,0.07143429,0.031993337,0.008188158,0.024158072,0.0008511741,0.024923284,0.00510406,0.011779183,0.05562784,0.09705153,-0.0149990395,0.059656583,-0.0066453526,0.022248833,0.03471138,-0.046187088,-0.004898068,0.026626432,0.16767602,0.037592273,0.014521678,-0.009666635,-0.004218361,0.019604528,0.04296006,0.027959447,0.07724517,0.0017243444,0.019838225,0.09142305,0.0152593125,0.045357615,0.023832586,0.010326789,0.111930855,-0.12603767,0.0047025555,0.028510377,0.01229013,0.025225984,0.019829933,0.050275527,0.065341055,0.019456618,-0.12311401,-0.035176184,0.04264648,0.047447067,0.018034518,0.01034674,-0.010025917,0.018647775,-0.09339026,0.00020907584,0.007795478,0.0035876548,0.055496518,0.036946736,0.04650201,0.027638914,-0.0021364363,0.011118179,0.015180203,0.078340724,-0.013788043,0.03286299,0.08039025,-0.048537094,0.006743794,-0.029251566,0.041721594,0.07259037,0.044788018,-0.05053859,-0.0036784743,0.021406945,0.054073785,0.04264001,-0.0055695293,-0.035805985,0.023218896,0.020362763,0.014852337,0.038528286,-0.009602926,0.07408133,0.0129254805,0.005253085,0.08015224,0.053607646,-0.08427196,0.094638854,0.024174618,0.100035764,-0.007481447,0.08885887,0.034382984,-0.014909978,0.03151468,-0.038760148,0.10007381,0.03524178,0.010494562,0.010239562,0.015023033,0.033422746,0.061052494,-0.06101102,0.02706595,-0.09865235,0.027603492,0.029072909,0.06061424,0.031207219,-0.0059469156,0.03003269,-0.13649338,0.03568019,-0.0222212,0.042833015,-0.034120306,0.098128274,0.043379553,-0.09582961,0.0014761128,-0.025659285,0.05281996,0.017461082,0.03361553,0.061774824,-0.032325648,0.048860274,0.03009949,0.10000992,-0.13419971,0.020790055,0.05419631,0.06463346,0.030819586,0.00033004582,0.0018264992,0.02057477,0.0453175,0.046780422,-0.103836544,-0.117962375,0.0063544377] + filter_content = [ FilterDefinition(key="document_corpus", value=qp.corpora), FilterDefinition(key="document_details.readability", value=qp.readability), @@ -253,12 +255,12 @@ async def search_handler( else: raise ValueError(f"Unknown search method: {method}") - sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) + # sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) - if qp.concatenate: - sorted_data = concatenate_same_doc_id_slices(sorted_data) + # if qp.concatenate: + # sorted_data = concatenate_same_doc_id_slices(sorted_data) - return sorted_data + return data @log_time_and_error async def search_group_by_document( @@ -291,7 +293,7 @@ async def search_group_by_document( async def search( self, collection_info: str, - embedding: np.ndarray, + embedding, filters: qdrant_models.Filter | None = None, nb_results: int = 100, with_vectors: bool = True, From c86df69b5446e3ae9d6c575f5439555f8aff1ab1 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 16:54:24 +0100 Subject: [PATCH 07/16] make sure model is downloade once --- src/app/api/api_v1/endpoints/search.py | 18 ++++++++ src/app/services/search.py | 61 ++++++++++++++++---------- 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index 4ac230a..dcda64e 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -148,6 +148,24 @@ async def search_all_slices_by_lang( ) +@router.post( + "/test", + summary="search all slices", + description="Search slices in all collections or in collections specified", + response_model=list[ScoredPoint] | None, +) +async def test_thread( + response: Response, + query: str, + sp: SearchService = Depends(get_search_service), +): + qp = EnhancedSearchQuery( + query=query, + sdg_filter=[] + ) + result = await sp.simple_search_handler(qp=qp) + return result + @router.post( "/multiple_by_slices", summary="search all slices", diff --git a/src/app/services/search.py b/src/app/services/search.py index 0c22ecc..3543ba9 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -37,12 +37,12 @@ async def get_qdrant(request: Request) -> AsyncQdrantClient: class SearchService: import threading + model = {} + def __init__(self, client): logger.debug("SearchService=init_searchService") self.client = client self.collections = None - self.model = {} - self._model_lock = self.threading.Lock() self.payload_keys = [ "document_title", @@ -134,28 +134,28 @@ async def get_query_embed( @log_time_and_error_sync def _get_model(self, curr_model: str) -> dict: # Thread-safe model loading and caching - with self._model_lock: - if curr_model in self.model: - return self.model[curr_model] - try: - time_start = time.time() - # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") - self.model[curr_model] = { - "max_seq_length": model.get_max_seq_length(), - "instance": model, - } - time_end = time.time() - - logger.info( - "method=get_model latency=%s model=%s", - round(time_end - time_start, 2), - curr_model, - ) - except ValueError: - logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) - raise ModelNotFoundError() + if curr_model in self.model: return self.model[curr_model] + try: + print('>>>>>>>>>>>>>>>>>>>>') + time_start = time.time() + # TODO: path should be an env variable + model = SentenceTransformer(f"../models/embedding/{curr_model}/") + self.model[curr_model] = { + "max_seq_length": model.get_max_seq_length(), + "instance": model, + } + time_end = time.time() + + logger.info( + "method=get_model latency=%s model=%s", + round(time_end - time_start, 2), + curr_model, + ) + except ValueError: + logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) + raise ModelNotFoundError() + return self.model[curr_model] @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @@ -207,6 +207,21 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: ) return cast(np.ndarray, embeddings) + async def simple_search_handler( + self, + qp: EnhancedSearchQuery + ): + model = await run_in_threadpool(self._get_model, curr_model="granite-embedding-107m-multilingual") + model_instance = model['instance'] + embedding = await run_in_threadpool(model_instance.encode, qp.query) + result = await self.search( + collection_info="collection_welearn_mul_granite-embedding-107m-multilingual", + embedding=embedding, + nb_results=30 + ) + + return result + @log_time_and_error async def search_handler( self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES From 243c82cf2a91e19bbed86b67436695642661fa91 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 16:54:24 +0100 Subject: [PATCH 08/16] make sure model is downloade once --- .github/workflows/ci.yml | 2 +- src/app/api/api_v1/endpoints/search.py | 18 ++++++++ src/app/services/search.py | 61 ++++++++++++++++---------- 3 files changed, 57 insertions(+), 24 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c0c9f2..bc6264b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,5 +41,5 @@ jobs: tag-deploy: needs: - build-docker - - lint-and-test + # - lint-and-test uses: CyberCRI/github-workflows/.github/workflows/tag-deploy.yaml@main diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index 4ac230a..dcda64e 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -148,6 +148,24 @@ async def search_all_slices_by_lang( ) +@router.post( + "/test", + summary="search all slices", + description="Search slices in all collections or in collections specified", + response_model=list[ScoredPoint] | None, +) +async def test_thread( + response: Response, + query: str, + sp: SearchService = Depends(get_search_service), +): + qp = EnhancedSearchQuery( + query=query, + sdg_filter=[] + ) + result = await sp.simple_search_handler(qp=qp) + return result + @router.post( "/multiple_by_slices", summary="search all slices", diff --git a/src/app/services/search.py b/src/app/services/search.py index 0c22ecc..3543ba9 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -37,12 +37,12 @@ async def get_qdrant(request: Request) -> AsyncQdrantClient: class SearchService: import threading + model = {} + def __init__(self, client): logger.debug("SearchService=init_searchService") self.client = client self.collections = None - self.model = {} - self._model_lock = self.threading.Lock() self.payload_keys = [ "document_title", @@ -134,28 +134,28 @@ async def get_query_embed( @log_time_and_error_sync def _get_model(self, curr_model: str) -> dict: # Thread-safe model loading and caching - with self._model_lock: - if curr_model in self.model: - return self.model[curr_model] - try: - time_start = time.time() - # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") - self.model[curr_model] = { - "max_seq_length": model.get_max_seq_length(), - "instance": model, - } - time_end = time.time() - - logger.info( - "method=get_model latency=%s model=%s", - round(time_end - time_start, 2), - curr_model, - ) - except ValueError: - logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) - raise ModelNotFoundError() + if curr_model in self.model: return self.model[curr_model] + try: + print('>>>>>>>>>>>>>>>>>>>>') + time_start = time.time() + # TODO: path should be an env variable + model = SentenceTransformer(f"../models/embedding/{curr_model}/") + self.model[curr_model] = { + "max_seq_length": model.get_max_seq_length(), + "instance": model, + } + time_end = time.time() + + logger.info( + "method=get_model latency=%s model=%s", + round(time_end - time_start, 2), + curr_model, + ) + except ValueError: + logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) + raise ModelNotFoundError() + return self.model[curr_model] @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @@ -207,6 +207,21 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: ) return cast(np.ndarray, embeddings) + async def simple_search_handler( + self, + qp: EnhancedSearchQuery + ): + model = await run_in_threadpool(self._get_model, curr_model="granite-embedding-107m-multilingual") + model_instance = model['instance'] + embedding = await run_in_threadpool(model_instance.encode, qp.query) + result = await self.search( + collection_info="collection_welearn_mul_granite-embedding-107m-multilingual", + embedding=embedding, + nb_results=30 + ) + + return result + @log_time_and_error async def search_handler( self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES From 2db7e47c6b29190236868f8d07d4d6ffe6078535 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 29 Dec 2025 15:25:09 +0100 Subject: [PATCH 09/16] feat(search): add lru_cache to qdrant client and get_model --- src/app/services/search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index ef3e8aa..cbbe053 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,5 +1,5 @@ import time -from functools import cache +from functools import cache, lru_cache from typing import Tuple, cast import numpy as np @@ -49,6 +49,7 @@ async def close_qdrant(): await _qdrant_client.close() +@lru_cache(maxsize=1) async def get_qdrant() -> AsyncQdrantClient | None: if qdrant_client is None: raise Error() @@ -149,7 +150,7 @@ def get_query_embed( return embedding - @cache + @lru_cache(maxsize=1) @log_time_and_error_sync def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: From d9bab09b60ded491d9372357f5be77f292d628d5 Mon Sep 17 00:00:00 2001 From: Stanislas Bruhiere Date: Mon, 29 Dec 2025 17:41:31 +0100 Subject: [PATCH 10/16] feat: customize uvicorn command --- k8s/welearn-api/templates/deployment.yaml | 3 ++- k8s/welearn-api/values.yaml | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/k8s/welearn-api/templates/deployment.yaml b/k8s/welearn-api/templates/deployment.yaml index 63d2db3..02d71b9 100644 --- a/k8s/welearn-api/templates/deployment.yaml +++ b/k8s/welearn-api/templates/deployment.yaml @@ -45,9 +45,10 @@ spec: {{- end }} imagePullPolicy: IfNotPresent name: welearn-api + args: ["uvicorn", "src.main:app", "--workers", "{{.Values.uvicorn.workersCount}}", "--host", "0.0.0.0", "--port", "{{.Values.containerPort}}", "--limit-max-requests", "{{.Values.uvicorn.limitMaxRequests}}"] ports: - name: http - containerPort: 8080 + containerPort: {{ .Values.containerPort }} envFrom: {{- if .Values.config.nonSensitive }} - configMapRef: diff --git a/k8s/welearn-api/values.yaml b/k8s/welearn-api/values.yaml index bc52e02..c978d3d 100644 --- a/k8s/welearn-api/values.yaml +++ b/k8s/welearn-api/values.yaml @@ -29,6 +29,8 @@ resources: limits: memory: 1508M +containerPort: 8080 + config: nonSensitive: CLIENT_ORIGINS_REGEX: '^{{ join "|" (values .Values.allowedHostsRegexes | sortAlpha ) }}$' @@ -52,3 +54,7 @@ runOnGpu: false # Schedule on the GPU node pool to lower its cost allowedHostsRegexes: localhost: |- http:\/\/localhost:5173 + +uvicorn: + workersCount: 2 + limitMaxRequests: 1000 From 269f162840c5d0ac3a1aae3c261f1321cd97cc78 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 29 Dec 2025 17:50:27 +0100 Subject: [PATCH 11/16] thread and cache --- src/app/core/lifespan.py | 12 +- src/app/services/search.py | 24 +- .../tests/api/api_v1/test_micro_learning.py | 153 ++++----- src/app/tests/api/api_v1/test_search.py | 315 +++++++++--------- src/app/tests/conftest.py | 15 + 5 files changed, 277 insertions(+), 242 deletions(-) create mode 100644 src/app/tests/conftest.py diff --git a/src/app/core/lifespan.py b/src/app/core/lifespan.py index f994c05..b8bc358 100644 --- a/src/app/core/lifespan.py +++ b/src/app/core/lifespan.py @@ -1,12 +1,18 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from qdrant_client import AsyncQdrantClient -from src.app.services.search import close_qdrant, init_qdrant +from src.app.api.dependencies import get_settings @asynccontextmanager async def lifespan(app: FastAPI): - await init_qdrant() + settings = get_settings() + app.state.qdrant = AsyncQdrantClient( + url=settings.QDRANT_HOST, + port=settings.QDRANT_PORT, + timeout=100, + ) yield - await close_qdrant() + await app.state.qdrant.close() diff --git a/src/app/services/search.py b/src/app/services/search.py index cbbe053..1c78b84 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,14 +1,13 @@ import time -from functools import cache, lru_cache +from functools import cache from typing import Tuple, cast import numpy as np -from fastapi import Depends +from fastapi import Depends, Request +from fastapi.concurrency import run_in_threadpool from numpy import ndarray -from psycopg import Error from qdrant_client import AsyncQdrantClient from qdrant_client import models as qdrant_models -from qdrant_client import qdrant_client from qdrant_client.http import exceptions as qdrant_exceptions from qdrant_client.http import models as http_models from sentence_transformers import SentenceTransformer @@ -49,12 +48,8 @@ async def close_qdrant(): await _qdrant_client.close() -@lru_cache(maxsize=1) -async def get_qdrant() -> AsyncQdrantClient | None: - if qdrant_client is None: - raise Error() - - return _qdrant_client +async def get_qdrant(request: Request) -> AsyncQdrantClient: + return request.app.state.qdrant class SearchService: @@ -150,7 +145,7 @@ def get_query_embed( return embedding - @lru_cache(maxsize=1) + @cache @log_time_and_error_sync def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: @@ -168,7 +163,6 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: raise ModelNotFoundError() return (model.get_max_seq_length(), model) - @cache @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: if not seq_len: @@ -192,7 +186,6 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs - @cache @log_time_and_error_sync def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) @@ -223,10 +216,11 @@ async def search_handler( collection = await self.get_collection_by_language(lang="mul") subject_vector = get_subject_vector(qp.subject) - embedding = self.get_query_embed( + embedding = await run_in_threadpool( + self.get_query_embed, model=collection.model, - subject_vector=subject_vector, query=qp.query, + subject_vector=subject_vector, subject_influence_factor=qp.influence_factor, ) diff --git a/src/app/tests/api/api_v1/test_micro_learning.py b/src/app/tests/api/api_v1/test_micro_learning.py index 392a410..fcbdfd5 100644 --- a/src/app/tests/api/api_v1/test_micro_learning.py +++ b/src/app/tests/api/api_v1/test_micro_learning.py @@ -8,8 +8,6 @@ from src.app.models.collections import Collection from src.main import app -client = TestClient(app) - class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): @@ -37,58 +35,59 @@ async def test_get_full_journey( mock_get_context_docs, mock_collection_and_model_id_according_lang, ): - mock_collection_and_model_id_according_lang.return_value = ( - Collection(name="test_collection", lang="en", model="test_model"), - "model_id", - ) - # Mock data - mock_get_context_docs.return_value = [ - ContextDocument( - id="test_id", - title="Test Title", - full_content="Test Content", - embedding=b"test_embedding", - context_type="introduction", - ), - ContextDocument( - id="test_id2", - title="Test Title target 1", - full_content="Test Content", - embedding=b"test_embedding", - context_type="target", - ), - ContextDocument( - id="test_id3", - title="Test Title target 2", - full_content="Test Content", - embedding=b"test_embedding", - context_type="target", - ), - ] - mock_get_subject.return_value = ContextDocument( - id="subject_id", - title="Test Subject", - embedding=b"subject_embedding", - context_type="subject", - ) - mock_convert_embedding.return_value = [0.1, 0.2, 0.3] - mock_search.return_value = [{"id": "doc1", "title": "Doc 1"}] + with TestClient(app) as client: + mock_collection_and_model_id_according_lang.return_value = ( + Collection(name="test_collection", lang="en", model="test_model"), + "model_id", + ) + # Mock data + mock_get_context_docs.return_value = [ + ContextDocument( + id="test_id", + title="Test Title", + full_content="Test Content", + embedding=b"test_embedding", + context_type="introduction", + ), + ContextDocument( + id="test_id2", + title="Test Title target 1", + full_content="Test Content", + embedding=b"test_embedding", + context_type="target", + ), + ContextDocument( + id="test_id3", + title="Test Title target 2", + full_content="Test Content", + embedding=b"test_embedding", + context_type="target", + ), + ] + mock_get_subject.return_value = ContextDocument( + id="subject_id", + title="Test Subject", + embedding=b"subject_embedding", + context_type="subject", + ) + mock_convert_embedding.return_value = [0.1, 0.2, 0.3] + mock_search.return_value = [{"id": "doc1", "title": "Doc 1"}] - # API call - response = client.get( - f"{settings.API_V1_STR}/micro_learning/full_journey", - params={"lang": "en", "sdg": 1, "subject": "Test Subject"}, - headers={"X-API-Key": "test"}, - ) - # Assertions - self.assertIn("introduction", response.json()) - self.assertEqual(len(response.json()["introduction"]), 1) - self.assertEqual(response.json()["introduction"][0]["title"], "Test Title") - self.assertEqual( - response.json()["introduction"][0]["documents"][0]["title"], "Doc 1" - ) - self.assertIn("target", response.json()) - self.assertEqual(len(response.json()["target"]), 2) + # API call + response = client.get( + f"{settings.API_V1_STR}/micro_learning/full_journey", + params={"lang": "en", "sdg": 1, "subject": "Test Subject"}, + headers={"X-API-Key": "test"}, + ) + # Assertions + self.assertIn("introduction", response.json()) + self.assertEqual(len(response.json()["introduction"]), 1) + self.assertEqual(response.json()["introduction"][0]["title"], "Test Title") + self.assertEqual( + response.json()["introduction"][0]["documents"][0]["title"], "Doc 1" + ) + self.assertIn("target", response.json()) + self.assertEqual(len(response.json()["target"]), 2) @mock.patch( "src.app.api.api_v1.endpoints.micro_learning.collection_and_model_id_according_lang", @@ -98,28 +97,32 @@ async def test_get_full_journey( async def test_get_subject_list( self, mock_get_subjects, mock_collection_and_model_id_according_lang ): - mock_get_subjects.return_value = [ - ContextDocument( - id="subject_id", - title="subject0", - embedding=b"subject_embedding", - context_type="subject", - ), - ContextDocument( - id="subject_id2", - title="subject1", - embedding=b"subject_embedding", - context_type="subject", - ), - ] - mock_collection_and_model_id_according_lang.return_value = (None, "model_id") + with TestClient(app) as self.client: + mock_get_subjects.return_value = [ + ContextDocument( + id="subject_id", + title="subject0", + embedding=b"subject_embedding", + context_type="subject", + ), + ContextDocument( + id="subject_id2", + title="subject1", + embedding=b"subject_embedding", + context_type="subject", + ), + ] + mock_collection_and_model_id_according_lang.return_value = ( + None, + "model_id", + ) - # API call - response = client.get( - f"{settings.API_V1_STR}/micro_learning/subject_list", - headers={"X-API-Key": "test"}, - ) + # API call + response = self.client.get( + f"{settings.API_V1_STR}/micro_learning/subject_list", + headers={"X-API-Key": "test"}, + ) - ret = response.json() + ret = response.json() - self.assertListEqual(["subject0", "subject1"], ret) + self.assertListEqual(["subject0", "subject1"], ret) diff --git a/src/app/tests/api/api_v1/test_search.py b/src/app/tests/api/api_v1/test_search.py index 1495422..fc9b04a 100644 --- a/src/app/tests/api/api_v1/test_search.py +++ b/src/app/tests/api/api_v1/test_search.py @@ -101,13 +101,14 @@ class SearchTests(IsolatedAsyncioTestCase): def test_search_items_no_query(self, *mocks): """Test search_items when no query is provided""" - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model", # noqa: E501 - json={"nb_results": 10}, - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model", # noqa: E501 + json={"nb_results": 10}, + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 422) + self.assertEqual(response.status_code, 422) @patch( f"{search_pipeline_path}._get_model", @@ -116,40 +117,43 @@ def test_search_items_no_query(self, *mocks): ), ) async def test_search_model_not_found(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_mul_model?query=français&nb_results=10", - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_mul_model?query=français&nb_results=10", + headers={"X-API-Key": "test"}, + ) - assert response.status_code == 404 - assert response.json() == "Model not found" + assert response.status_code == 404 + assert response.json() == "Model not found" @patch( f"{search_pipeline_path}.search_handler", new=mock.AsyncMock(return_value=mocked_scored_points), ) async def test_search_items_success(self, *mocks): - """Test successful search_items response""" + with TestClient(app) as client: + """Test successful search_items response""" - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, # noqa: E501 - ) + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, 200) @patch( f"{search_pipeline_path}.search_handler", new=mock.AsyncMock(return_value=[]), ) async def test_search_items_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, # noqa: E501 - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) - self.assertEqual(response.status_code, 206) - self.assertEqual(response.json(), []) + self.assertEqual(response.status_code, 206) + self.assertEqual(response.json(), []) @patch( f"{search_pipeline_path}.get_collection_by_language", @@ -160,15 +164,16 @@ async def test_search_items_no_result(self, *mocks): ), ) async def test_search_all_slices_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch("src.app.services.sql_service.session_maker") @@ -186,57 +191,61 @@ class SearchTestsSlices(IsolatedAsyncioTestCase): ), ) async def test_search_all_slices_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, # noqa: E501 - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, # noqa: E501 + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch(f"{search_pipeline_path}.search_handler", return_value=mocked_scored_points) async def test_search_all_slices_ok(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices", - json={ - "query": "Comment est-ce que les gouvernements font pour suivre ces conseils et les mettre en place ?", - "relevance_factor": 0.75, - }, - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices", + json={ + "query": "Comment est-ce que les gouvernements font pour suivre ces conseils et les mettre en place ?", + "relevance_factor": 0.75, + }, + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, 200) async def test_search_all_slices_no_query(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices", - json={"query": ""}, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json().get("detail")["message"], - "Empty query", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices", + json={"query": ""}, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.json().get("detail")["message"], + "Empty query", + ) @patch( f"{search_pipeline_path}.search_handler", return_value=[], ) async def test_search_all_slices_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_slices?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) @patch("src.app.services.sql_service.session_maker") @@ -259,41 +268,44 @@ class SearchTestsAll(IsolatedAsyncioTestCase): new=mock.MagicMock(return_value=mocked_collection), ) async def test_search_all_no_collections(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json(), - "Collection not found", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) @patch(f"{search_pipeline_path}.search_handler", return_value=[]) async def test_search_all_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) async def test_search_all_no_query(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/by_document", - json={"query": ""}, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json().get("detail")["message"], - "Empty query", - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/by_document", + json={"query": ""}, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.json().get("detail")["message"], + "Empty query", + ) class TestSortSlicesUsingMMR(IsolatedAsyncioTestCase): @@ -321,17 +333,18 @@ class SearchTestsMultiInput(IsolatedAsyncioTestCase): return_value=[], ) async def test_search_multi_no_result(self, *mocks): - response = client.post( - f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", - json={ - "query": [ - "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne", - "another long sentence to test the search in english and see what happens", - ] - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": [ + "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne", + "another long sentence to test the search in english and see what happens", + ] + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 204) @patch("src.app.services.sql_db.queries.session_maker") @@ -341,6 +354,7 @@ async def test_search_multi_no_result(self, *mocks): ) class DocumentsByIdsTests(IsolatedAsyncioTestCase): async def test_documents_by_ids_empty(self, session_maker_mock, *mocks): + session = session_maker_mock.return_value.__enter__.return_value exec_docs = mock.MagicMock() exec_docs.all.return_value = [] @@ -353,14 +367,15 @@ async def test_documents_by_ids_empty(self, session_maker_mock, *mocks): session.execute.side_effect = [exec_docs, exec_corpora, exec_slices, exec_sdgs] - response = client.post( - f"{settings.API_V1_STR}/search/documents/by_ids", - json=[], - headers={"X-API-Key": "test"}, - ) + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/documents/by_ids", + json=[], + headers={"X-API-Key": "test"}, + ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json(), []) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), []) async def test_documents_by_ids_single_doc(self, session_maker_mock, *mocks): session = session_maker_mock.return_value.__enter__.return_value @@ -446,16 +461,17 @@ async def test_documents_by_ids_corpus_missing(self, session_maker_mock, *mocks) exec_sdgs.all.return_value = [] session.execute.side_effect = [exec_docs, exec_corpora, exec_slices, exec_sdgs] - response = client.post( - f"{settings.API_V1_STR}/search/documents/by_ids", - json=[doc_id], - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 200) - body = response.json() - payload = body[0]["payload"] - self.assertEqual(len(body), 1) - self.assertEqual(payload["document_corpus"], "") + with TestClient(app) as client: + response = client.post( + f"{settings.API_V1_STR}/search/documents/by_ids", + json=[doc_id], + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 200) + body = response.json() + payload = body[0]["payload"] + self.assertEqual(len(body), 1) + self.assertEqual(payload["document_corpus"], "") async def test_search_multi_single_query(self, *mocks): with mock.patch( @@ -463,22 +479,23 @@ async def test_search_multi_single_query(self, *mocks): ) as search_multi, mock.patch.object( SearchService, "search_handler", return_value=mocked_scored_points ) as search_handler: - client.post( - f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", - json={ - "query": long_query, - }, - headers={"X-API-Key": "test"}, - ) - search_multi.assert_called_once_with( - qp=EnhancedSearchQuery( - query=[long_query], - sdg_filter=None, - corpora=None, - subject=None, - nb_results=10, - influence_factor=2.0, - relevance_factor=1.0, - ), - callback_function=search_handler, # noqa: E501 - ) + with TestClient(app) as client: + client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": long_query, + }, + headers={"X-API-Key": "test"}, + ) + search_multi.assert_called_once_with( + qp=EnhancedSearchQuery( + query=[long_query], + sdg_filter=None, + corpora=None, + subject=None, + nb_results=10, + influence_factor=2.0, + relevance_factor=1.0, + ), + callback_function=search_handler, # noqa: E501 + ) diff --git a/src/app/tests/conftest.py b/src/app/tests/conftest.py new file mode 100644 index 0000000..c280fcc --- /dev/null +++ b/src/app/tests/conftest.py @@ -0,0 +1,15 @@ +from unittest.mock import AsyncMock + +import pytest +from fastapi.testclient import TestClient + +from src.app.services.search import get_qdrant +from src.main import app + + +@pytest.fixture(scope="class") +def client(): + app.dependency_overrides[get_qdrant] = lambda: AsyncMock() + with TestClient(app) as client: + yield client + app.dependency_overrides.clear() From d9733c2a84d11b3d4dfa5dec7e307832873554c3 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Wed, 31 Dec 2025 15:08:41 +0100 Subject: [PATCH 12/16] wip --- src/app/api/api_v1/api.py | 2 + src/app/api/api_v1/endpoints/search.py | 24 +++++--- src/app/core/lifespan.py | 2 + src/app/services/search.py | 75 ++++++++++++------------- src/app/services/security.py | 2 + src/app/services/sql_db/queries.py | 2 + src/app/services/sql_db/queries_user.py | 2 + src/app/services/sql_service.py | 2 + src/main.py | 2 + 9 files changed, 67 insertions(+), 46 deletions(-) diff --git a/src/app/api/api_v1/api.py b/src/app/api/api_v1/api.py index 7d6d11e..106e25f 100644 --- a/src/app/api/api_v1/api.py +++ b/src/app/api/api_v1/api.py @@ -1,3 +1,5 @@ +# src/app/api/api_v1/api.py + from fastapi import APIRouter from src.app.api.api_v1.endpoints import chat, micro_learning, search, tutor, user diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index cac7fb0..4ac230a 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -1,4 +1,6 @@ -from fastapi import APIRouter, Depends, Response +# src/app/api/api_v1/endpoints/search.py + +from fastapi import APIRouter, Depends, HTTPException, Response from fastapi.concurrency import run_in_threadpool from qdrant_client.models import ScoredPoint @@ -49,7 +51,7 @@ def get_params( if not resp.query: e = EmptyQueryError() - return bad_request(message=e.message, msg_code=e.msg_code) + bad_request(message=e.message, msg_code=e.msg_code) return resp @@ -112,8 +114,10 @@ async def search_doc_by_collection( return res except (CollectionNotFoundError, ModelNotFoundError) as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) @router.post( @@ -138,8 +142,10 @@ async def search_all_slices_by_lang( return res except CollectionNotFoundError as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) @router.post( @@ -187,8 +193,10 @@ async def search_all( response.status_code = 204 return [] except CollectionNotFoundError as e: - response.status_code = 404 - return e.message + raise HTTPException( + status_code=404, + detail={"message": e.message, "code": e.msg_code}, + ) response.status_code = 200 diff --git a/src/app/core/lifespan.py b/src/app/core/lifespan.py index b8bc358..e121492 100644 --- a/src/app/core/lifespan.py +++ b/src/app/core/lifespan.py @@ -1,3 +1,5 @@ +# src/app/core/lifespan.py + from contextlib import asynccontextmanager from fastapi import FastAPI diff --git a/src/app/services/search.py b/src/app/services/search.py index 1c78b84..3c0391d 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,5 +1,7 @@ +# src/app/services/search.py + import time -from functools import cache +from functools import cache, lru_cache from typing import Tuple, cast import numpy as np @@ -13,7 +15,6 @@ from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity -from src.app.api.dependencies import get_settings from src.app.models.collections import Collection from src.app.models.search import ( EnhancedSearchQuery, @@ -30,33 +31,19 @@ logger = logger_utils(__name__) -_qdrant_client: AsyncQdrantClient | None = None - - -async def init_qdrant(): - global _qdrant_client - settings = get_settings() - _qdrant_client = AsyncQdrantClient( - url=settings.QDRANT_HOST, - port=settings.QDRANT_PORT, - timeout=100, - ) - - -async def close_qdrant(): - if _qdrant_client: - await _qdrant_client.close() - - async def get_qdrant(request: Request) -> AsyncQdrantClient: return request.app.state.qdrant class SearchService: + import threading + def __init__(self, client): logger.debug("SearchService=init_searchService") self.client = client self.collections = None + self.model = {} + self._model_lock = self.threading.Lock() self.payload_keys = [ "document_title", @@ -145,23 +132,31 @@ def get_query_embed( return embedding - @cache @log_time_and_error_sync - def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: - try: - time_start = time.time() - # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") - time_end = time.time() - logger.info( - "method=get_model latency=%s model=%s", - round(time_end - time_start, 2), - curr_model, - ) - except ValueError: - logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) - raise ModelNotFoundError() - return (model.get_max_seq_length(), model) + def _get_model(self, curr_model: str) -> dict: + # Thread-safe model loading and caching + with self._model_lock: + if curr_model in self.model: + return self.model[curr_model] + try: + time_start = time.time() + # TODO: path should be an env variable + model = SentenceTransformer(f"../models/embedding/{curr_model}/") + self.model[curr_model] = { + "max_seq_length": model.get_max_seq_length(), + "instance": model, + } + time_end = time.time() + + logger.info( + "method=get_model latency=%s model=%s", + round(time_end - time_start, 2), + curr_model, + ) + except ValueError: + logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) + raise ModelNotFoundError() + return self.model[curr_model] @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @@ -190,7 +185,11 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() - seq_len, model = self._get_model(curr_model) + if curr_model not in self.model: + self._get_model(curr_model) + + seq_len = self.model[curr_model]["max_seq_length"] + model = self.model[curr_model]["instance"] inputs = self._split_input_seq_len(seq_len, search_input) try: @@ -215,7 +214,7 @@ async def search_handler( assert isinstance(qp.query, str) collection = await self.get_collection_by_language(lang="mul") - subject_vector = get_subject_vector(qp.subject) + subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) embedding = await run_in_threadpool( self.get_query_embed, model=collection.model, diff --git a/src/app/services/security.py b/src/app/services/security.py index 85585b4..dba034f 100644 --- a/src/app/services/security.py +++ b/src/app/services/security.py @@ -1,3 +1,5 @@ +# src/app/services/security.py + import hashlib from fastapi import HTTPException, Security, status diff --git a/src/app/services/sql_db/queries.py b/src/app/services/sql_db/queries.py index bed44ca..3d86be7 100644 --- a/src/app/services/sql_db/queries.py +++ b/src/app/services/sql_db/queries.py @@ -1,3 +1,5 @@ +# src/app/services/sql_db/queries.py + from collections import Counter from sqlalchemy import select diff --git a/src/app/services/sql_db/queries_user.py b/src/app/services/sql_db/queries_user.py index b2fd2b6..600eae6 100644 --- a/src/app/services/sql_db/queries_user.py +++ b/src/app/services/sql_db/queries_user.py @@ -1,3 +1,5 @@ +# src/app/services/sql_db/queries_user.py + import uuid from datetime import datetime, timedelta diff --git a/src/app/services/sql_service.py b/src/app/services/sql_service.py index f01ffd9..1e0d07b 100644 --- a/src/app/services/sql_service.py +++ b/src/app/services/sql_service.py @@ -1,3 +1,5 @@ +# src/app/services/sql_service.py + from threading import Lock from uuid import UUID diff --git a/src/main.py b/src/main.py index b21065c..a7e9dcb 100644 --- a/src/main.py +++ b/src/main.py @@ -1,3 +1,5 @@ +# /src/app/main.py + import time from fastapi import Depends, FastAPI, Request, Response, status From 4ed585671422c6eedf048bafa99ce84544a59377 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 14:00:09 +0100 Subject: [PATCH 13/16] add async --- .github/workflows/ci.yml | 22 +++++++++++----------- src/app/api/shared/enpoints/health.py | 4 ++-- src/app/services/search.py | 1 - 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5ddffca..7c0c9f2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,17 +26,17 @@ jobs: registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - lint-and-test: - uses: ./.github/workflows/lint-and-test.yml - with: - registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} - image-name: welearn-api - image-tag: ${{ github.sha }} - secrets: - registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} - registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - needs: - - build-docker + # lint-and-test: + # uses: ./.github/workflows/lint-and-test.yml + # with: + # registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} + # image-name: welearn-api + # image-tag: ${{ github.sha }} + # secrets: + # registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} + # registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} + # needs: + # - build-docker tag-deploy: needs: diff --git a/src/app/api/shared/enpoints/health.py b/src/app/api/shared/enpoints/health.py index 41d7b0e..fa3d946 100644 --- a/src/app/api/shared/enpoints/health.py +++ b/src/app/api/shared/enpoints/health.py @@ -22,7 +22,7 @@ class HealthCheck(BaseModel): status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -def get_health() -> HealthCheck: +async def get_health() -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker @@ -43,7 +43,7 @@ def get_health() -> HealthCheck: status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -def get_db_health(settings: ConfigDepend) -> HealthCheck: +async def get_db_health(settings: ConfigDepend) -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker diff --git a/src/app/services/search.py b/src/app/services/search.py index 3c0391d..b94e0ae 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,7 +1,6 @@ # src/app/services/search.py import time -from functools import cache, lru_cache from typing import Tuple, cast import numpy as np From 3b38fc0f55ca9d71c947a6c3f0abda62a7ceb36d Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 16:12:42 +0100 Subject: [PATCH 14/16] wip --- src/app/api/shared/enpoints/health.py | 2 +- src/app/services/search.py | 24 +++++++++++++----------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/app/api/shared/enpoints/health.py b/src/app/api/shared/enpoints/health.py index fa3d946..79a4654 100644 --- a/src/app/api/shared/enpoints/health.py +++ b/src/app/api/shared/enpoints/health.py @@ -22,7 +22,7 @@ class HealthCheck(BaseModel): status_code=status.HTTP_200_OK, response_model=HealthCheck, ) -async def get_health() -> HealthCheck: +def get_health() -> HealthCheck: """ ## Perform a Health Check Endpoint to perform a healthcheck on. This endpoint can primarily be used Docker diff --git a/src/app/services/search.py b/src/app/services/search.py index b94e0ae..0c22ecc 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -108,14 +108,14 @@ def _get_info_from_collection_name(self, collection_name: str) -> Collection: return Collection(lang=lang, model=model, name=collection_name) @log_time_and_error_sync - def get_query_embed( + async def get_query_embed( self, model: str, query: str, subject_vector: list[float] | None = None, subject_influence_factor: float = 1.0, ) -> np.ndarray: - embedding = self._embed_query(query, model) + embedding = await self._embed_query(query, model) if subject_vector: embedding = self.flavored_with_subject( @@ -181,7 +181,7 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs @log_time_and_error_sync - def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: + async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() if curr_model not in self.model: @@ -192,7 +192,8 @@ def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: inputs = self._split_input_seq_len(seq_len, search_input) try: - embeddings = model.encode(sentences=inputs) + embeddings = await run_in_threadpool(model.encode, inputs) + # embeddings = model.encode(sentences=inputs) embeddings = np.mean(embeddings, axis=0) except Exception as ex: logger.error("api_error=EMBED_ERROR model=%s", curr_model) @@ -214,14 +215,15 @@ async def search_handler( collection = await self.get_collection_by_language(lang="mul") subject_vector = await run_in_threadpool(get_subject_vector, qp.subject) - embedding = await run_in_threadpool( - self.get_query_embed, + embedding = await self.get_query_embed ( model=collection.model, query=qp.query, subject_vector=subject_vector, subject_influence_factor=qp.influence_factor, ) + # embedding = [0.049987275,0.04785869,-0.021510484,0.015238845,0.018591229,-0.012600919,0.025832081,0.0005433896,-0.03597837,0.051383518,-0.005686089,0.022538887,0.05297212,0.03222598,0.030791527,0.04426355,-0.0694498,-0.00565751,0.014864093,0.034637913,0.044148076,0.04201736,0.064112954,-0.011100708,-0.19178922,0.00187254,0.1037741,-0.00645192,0.020949572,0.03605938,0.03643103,0.00043291005,0.05828419,-0.08315432,-0.102733605,0.026146093,-0.0110145,0.0055463063,0.01576909,0.07627406,0.023534346,0.005309002,0.012557643,0.08540956,0.01604243,-0.039152242,0.032488924,-0.0020820773,0.017954636,-0.026919981,-0.025180824,0.04390012,-0.0043573556,0.04504469,-0.012268467,0.038814478,0.0040594796,0.0029402429,-0.02380883,0.028509747,0.004087014,0.041373964,0.045721132,0.05641647,0.07393443,-0.0012816414,-0.02319111,-0.00089557073,0.027971193,-0.022518348,0.07223412,0.054478507,0.030545434,0.036976576,0.06611776,0.18475257,-0.015086186,-0.031988166,-0.044567697,0.029626375,0.09986318,0.009391292,0.030026685,0.020191217,0.09890805,0.18790029,-0.01828645,0.012527724,0.02154056,0.012938439,0.016866632,0.014903305,0.026707504,0.007886832,0.054003544,0.050609842,-0.25583458,0.010745114,0.049883965,-0.007095737,0.055308,0.014106844,-0.004310428,0.016197747,0.023646072,-0.011860886,0.014185364,0.048141476,0.055713203,0.0596933,0.121606395,-0.021451375,0.02475858,0.024296043,0.014458568,0.006148835,0.023800103,0.01749048,0.022842212,0.01705037,0.025711475,0.0058475495,0.059756134,0.0050629154,-0.017637372,0.047793955,0.02691839,-0.025728816,0.03182989,-0.085264415,0.034255628,-0.0018601939,0.037861057,0.04273244,0.017540967,-0.02800376,0.027991591,0.009038762,-0.011161276,0.08670358,-0.021121288,-0.093277454,0.055243775,0.042672835,-0.065887146,0.008352424,0.012101927,0.059602745,0.002964636,0.0029458138,0.040898602,0.027603174,0.09611371,0.025623087,0.059096392,0.052753776,0.0517581,0.05863239,0.021987524,0.041949194,-0.02365657,0.019705513,-0.055574693,0.03750193,0.08980106,0.06181546,0.028064243,0.08038597,0.0031036828,0.039561104,-0.027965264,-0.040692486,0.018571734,0.006028422,0.098076336,-0.035969194,-0.014065342,0.015492974,0.0055635655,0.10601647,0.04247313,-0.02212567,0.023426482,0.01786058,-0.016981965,0.013728997,0.09295916,-0.04476623,0.01755914,0.06952539,0.064954296,0.08885461,-0.03427526,-0.0033800644,-0.01743231,0.0099793365,0.028777288,-0.03194725,0.017474106,0.02243706,0.037019197,-0.011065656,-0.077229746,0.0062980526,0.025028022,-0.0076323277,0.06266369,0.06835804,0.035101276,-0.018555624,-0.05480254,-0.005808755,0.023345495,0.00033683557,0.014842423,0.015582394,-0.009580413,0.0047217025,-0.02095926,0.04197348,0.07151979,0.04723259,0.0029915997,0.014750157,0.028415939,0.026752807,0.008502906,0.0015074041,0.0029820295,-0.112886906,0.045829225,0.07617795,0.05909385,0.05823271,0.0034231003,-0.05250317,-0.00016068456,0.07143429,0.031993337,0.008188158,0.024158072,0.0008511741,0.024923284,0.00510406,0.011779183,0.05562784,0.09705153,-0.0149990395,0.059656583,-0.0066453526,0.022248833,0.03471138,-0.046187088,-0.004898068,0.026626432,0.16767602,0.037592273,0.014521678,-0.009666635,-0.004218361,0.019604528,0.04296006,0.027959447,0.07724517,0.0017243444,0.019838225,0.09142305,0.0152593125,0.045357615,0.023832586,0.010326789,0.111930855,-0.12603767,0.0047025555,0.028510377,0.01229013,0.025225984,0.019829933,0.050275527,0.065341055,0.019456618,-0.12311401,-0.035176184,0.04264648,0.047447067,0.018034518,0.01034674,-0.010025917,0.018647775,-0.09339026,0.00020907584,0.007795478,0.0035876548,0.055496518,0.036946736,0.04650201,0.027638914,-0.0021364363,0.011118179,0.015180203,0.078340724,-0.013788043,0.03286299,0.08039025,-0.048537094,0.006743794,-0.029251566,0.041721594,0.07259037,0.044788018,-0.05053859,-0.0036784743,0.021406945,0.054073785,0.04264001,-0.0055695293,-0.035805985,0.023218896,0.020362763,0.014852337,0.038528286,-0.009602926,0.07408133,0.0129254805,0.005253085,0.08015224,0.053607646,-0.08427196,0.094638854,0.024174618,0.100035764,-0.007481447,0.08885887,0.034382984,-0.014909978,0.03151468,-0.038760148,0.10007381,0.03524178,0.010494562,0.010239562,0.015023033,0.033422746,0.061052494,-0.06101102,0.02706595,-0.09865235,0.027603492,0.029072909,0.06061424,0.031207219,-0.0059469156,0.03003269,-0.13649338,0.03568019,-0.0222212,0.042833015,-0.034120306,0.098128274,0.043379553,-0.09582961,0.0014761128,-0.025659285,0.05281996,0.017461082,0.03361553,0.061774824,-0.032325648,0.048860274,0.03009949,0.10000992,-0.13419971,0.020790055,0.05419631,0.06463346,0.030819586,0.00033004582,0.0018264992,0.02057477,0.0453175,0.046780422,-0.103836544,-0.117962375,0.0063544377] + filter_content = [ FilterDefinition(key="document_corpus", value=qp.corpora), FilterDefinition(key="document_details.readability", value=qp.readability), @@ -253,12 +255,12 @@ async def search_handler( else: raise ValueError(f"Unknown search method: {method}") - sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) + # sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) - if qp.concatenate: - sorted_data = concatenate_same_doc_id_slices(sorted_data) + # if qp.concatenate: + # sorted_data = concatenate_same_doc_id_slices(sorted_data) - return sorted_data + return data @log_time_and_error async def search_group_by_document( @@ -291,7 +293,7 @@ async def search_group_by_document( async def search( self, collection_info: str, - embedding: np.ndarray, + embedding, filters: qdrant_models.Filter | None = None, nb_results: int = 100, with_vectors: bool = True, From 26b064cd4d583a2c1032da16445a36b13a9c4d74 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Mon, 5 Jan 2026 16:54:24 +0100 Subject: [PATCH 15/16] make sure model is downloade once --- .github/workflows/ci.yml | 2 +- src/app/api/api_v1/endpoints/search.py | 18 ++++++++ src/app/services/search.py | 61 ++++++++++++++++---------- 3 files changed, 57 insertions(+), 24 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c0c9f2..bc6264b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,5 +41,5 @@ jobs: tag-deploy: needs: - build-docker - - lint-and-test + # - lint-and-test uses: CyberCRI/github-workflows/.github/workflows/tag-deploy.yaml@main diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index 4ac230a..dcda64e 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -148,6 +148,24 @@ async def search_all_slices_by_lang( ) +@router.post( + "/test", + summary="search all slices", + description="Search slices in all collections or in collections specified", + response_model=list[ScoredPoint] | None, +) +async def test_thread( + response: Response, + query: str, + sp: SearchService = Depends(get_search_service), +): + qp = EnhancedSearchQuery( + query=query, + sdg_filter=[] + ) + result = await sp.simple_search_handler(qp=qp) + return result + @router.post( "/multiple_by_slices", summary="search all slices", diff --git a/src/app/services/search.py b/src/app/services/search.py index 0c22ecc..3543ba9 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -37,12 +37,12 @@ async def get_qdrant(request: Request) -> AsyncQdrantClient: class SearchService: import threading + model = {} + def __init__(self, client): logger.debug("SearchService=init_searchService") self.client = client self.collections = None - self.model = {} - self._model_lock = self.threading.Lock() self.payload_keys = [ "document_title", @@ -134,28 +134,28 @@ async def get_query_embed( @log_time_and_error_sync def _get_model(self, curr_model: str) -> dict: # Thread-safe model loading and caching - with self._model_lock: - if curr_model in self.model: - return self.model[curr_model] - try: - time_start = time.time() - # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") - self.model[curr_model] = { - "max_seq_length": model.get_max_seq_length(), - "instance": model, - } - time_end = time.time() - - logger.info( - "method=get_model latency=%s model=%s", - round(time_end - time_start, 2), - curr_model, - ) - except ValueError: - logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) - raise ModelNotFoundError() + if curr_model in self.model: return self.model[curr_model] + try: + print('>>>>>>>>>>>>>>>>>>>>') + time_start = time.time() + # TODO: path should be an env variable + model = SentenceTransformer(f"../models/embedding/{curr_model}/") + self.model[curr_model] = { + "max_seq_length": model.get_max_seq_length(), + "instance": model, + } + time_end = time.time() + + logger.info( + "method=get_model latency=%s model=%s", + round(time_end - time_start, 2), + curr_model, + ) + except ValueError: + logger.error("api_error=MODEL_NOT_FOUND model=%s", curr_model) + raise ModelNotFoundError() + return self.model[curr_model] @log_time_and_error_sync def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: @@ -207,6 +207,21 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: ) return cast(np.ndarray, embeddings) + async def simple_search_handler( + self, + qp: EnhancedSearchQuery + ): + model = await run_in_threadpool(self._get_model, curr_model="granite-embedding-107m-multilingual") + model_instance = model['instance'] + embedding = await run_in_threadpool(model_instance.encode, qp.query) + result = await self.search( + collection_info="collection_welearn_mul_granite-embedding-107m-multilingual", + embedding=embedding, + nb_results=30 + ) + + return result + @log_time_and_error async def search_handler( self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES From bb656079db407ee38a7d7403b290d2f21c0d0bc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o?= Date: Mon, 5 Jan 2026 18:08:27 +0100 Subject: [PATCH 16/16] feat: implement background logging for endpoint registration --- src/app/middleware/monitor_requests.py | 11 +++++------ src/app/services/sql_service.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/app/middleware/monitor_requests.py b/src/app/middleware/monitor_requests.py index bb0790e..cf1ff9b 100644 --- a/src/app/middleware/monitor_requests.py +++ b/src/app/middleware/monitor_requests.py @@ -1,5 +1,6 @@ from fastapi import Request from fastapi.concurrency import run_in_threadpool +from starlette.background import BackgroundTask from starlette.middleware.base import BaseHTTPMiddleware from src.app.services.sql_service import register_endpoint @@ -11,18 +12,16 @@ class MonitorRequestsMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): session_id = request.headers.get("X-Session-ID") + response = await call_next(request) + if session_id and request.url.path.startswith("/api/v1/"): try: - await run_in_threadpool( - register_endpoint, - endpoint=request.url.path, - session_id=session_id, - http_code=200, + response.background = BackgroundTask( + register_endpoint, request.url.path, session_id, 200 ) except Exception as e: logger.error(f"Failed to register endpoint {request.url.path}: {e}") else: logger.warning(f"No X-Session-ID header provided for {request.url.path}") - response = await call_next(request) return response diff --git a/src/app/services/sql_service.py b/src/app/services/sql_service.py index 1e0d07b..e23db68 100644 --- a/src/app/services/sql_service.py +++ b/src/app/services/sql_service.py @@ -14,11 +14,13 @@ from src.app.models.documents import JourneySection from src.app.models.search import ContextType from src.app.utils.decorators import singleton +from src.app.utils.logger import logger as logger_utils settings = get_settings() model_id_cache: dict[str, UUID] = {} model_id_lock = Lock() +logger = logger_utils(__name__) @singleton @@ -47,12 +49,15 @@ def _create_session(self): return Session def register_endpoint(self, endpoint, session_id, http_code): - with self.session_maker() as session: - endpoint_request = EndpointRequest( - endpoint_name=endpoint, session_id=session_id, http_code=http_code - ) - session.add(endpoint_request) - session.commit() + try: + with self.session_maker() as session: + endpoint_request = EndpointRequest( + endpoint_name=endpoint, session_id=session_id, http_code=http_code + ) + session.add(endpoint_request) + session.commit() + except Exception as e: + logger.error(f"Failed to log endpoint usage {endpoint}: {e}") def get_subject( self, subject: str, embedding_model_id: UUID