diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f229b320..efc5f048 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + python-version: ['3.10', '3.11', '3.12', '3.13'] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/agave/chalice/rest_api.py b/agave/chalice/rest_api.py index 0465a80b..8c9edbcb 100644 --- a/agave/chalice/rest_api.py +++ b/agave/chalice/rest_api.py @@ -15,6 +15,11 @@ from pydantic import BaseModel, ValidationError from ..core.blueprints.decorators import copy_attributes +from ..core.query_params import ( + EmptyQueryMapping, + query_params_for_url, + validate_query_params, +) class RestApiBlueprint(Blueprint): @@ -238,9 +243,13 @@ def query(): next_page = } """ - params = self.current_request.query_params or dict() + query_mapping = ( + self.current_request.query_params or EmptyQueryMapping() + ) try: - query_params = cls.query_validator(**params) + query_params = validate_query_params( + query_mapping, cls.query_validator + ) except ValidationError as e: return Response(e.json(), status_code=400) @@ -296,11 +305,11 @@ def _all(query: QueryParams, filters: Q): if wants_more and has_more: query.created_before = item_dicts[-1]['created_at'] path = self.current_request.context['resourcePath'] - params = query.model_dump() + params = query_params_for_url(query) if self.user_id_filter_required(): - params.pop('user_id') + params.pop('user_id', None) if self.platform_id_filter_required(): - params.pop('platform_id') + params.pop('platform_id', None) next_page_uri = f'{path}?{urlencode(params)}' return dict(items=item_dicts, next_page_uri=next_page_uri) diff --git a/agave/core/filters.py b/agave/core/filters.py index f2d10669..ea419fe9 100644 --- a/agave/core/filters.py +++ b/agave/core/filters.py @@ -2,12 +2,22 @@ from mongoengine import Q +def _ids_filter_value(ids: str | list[str]) -> list[str]: + if isinstance(ids, str): + return [part.strip() for part in ids.split(',') if part.strip()] + return list(ids) + + def generic_query(query: QueryParams, excluded: list[str] = []) -> Q: filters = Q() if query.created_before: filters &= Q(created_at__lt=query.created_before) if query.created_after: filters &= Q(created_at__gt=query.created_after) + ids = getattr(query, 'ids', None) + if ids is not None: + id_list = _ids_filter_value(ids) + filters &= Q(id__in=id_list) if id_list else Q(id__in=[]) exclude_fields = { 'created_before', 'created_after', @@ -15,6 +25,7 @@ def generic_query(query: QueryParams, excluded: list[str] = []) -> Q: 'limit', 'page_size', 'key', + 'ids', *excluded, } fields = query.model_dump(exclude=exclude_fields) diff --git a/agave/core/query_params.py b/agave/core/query_params.py new file mode 100644 index 00000000..2b581498 --- /dev/null +++ b/agave/core/query_params.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from types import UnionType +from typing import Any, Iterator, TypeVar, Union, get_args, get_origin + +from pydantic import BaseModel + +ModelT = TypeVar('ModelT', bound=BaseModel) + + +def comma_separated_list(value: str | None) -> list[str]: + if not value: + return [] + return [part.strip() for part in value.split(',') if part.strip()] + + +def _is_list_annotation(annotation: Any) -> bool: + origin = get_origin(annotation) + if origin is list: + return True + if origin in (Union, UnionType): + return any( + _is_list_annotation(arg) + for arg in get_args(annotation) + if arg is not type(None) + ) + return False + + +def build_query_dict( + query_mapping: Any, model_cls: type[BaseModel] +) -> dict[str, Any]: + params: dict[str, Any] = {} + for name in query_mapping: + raw = query_mapping.get(name) + if name in model_cls.model_fields: + field = model_cls.model_fields[name] + if _is_list_annotation(field.annotation): + if raw is None: + continue + if isinstance(raw, str): + params[name] = comma_separated_list(raw) + else: + params[name] = list(raw) + else: + params[name] = raw + else: + params[name] = raw + return params + + +def validate_query_params( + query_mapping: Any, model_cls: type[ModelT] +) -> ModelT: + return model_cls(**build_query_dict(query_mapping, model_cls)) + + +def query_params_for_url(query: BaseModel) -> dict[str, Any]: + params = query.model_dump() + for name, field in type(query).model_fields.items(): + value = params.get(name) + if _is_list_annotation(field.annotation) and isinstance(value, list): + params[name] = ','.join(value) + return params + + +class EmptyQueryMapping: + def __contains__(self, key: str) -> bool: + return False + + def __iter__(self) -> Iterator[str]: + return iter(()) + + def get(self, key: str, default: Any = None) -> Any: + return default diff --git a/agave/fastapi/rest_api.py b/agave/fastapi/rest_api.py index 89b573a7..213fc833 100644 --- a/agave/fastapi/rest_api.py +++ b/agave/fastapi/rest_api.py @@ -23,6 +23,7 @@ from ..core.blueprints.decorators import copy_attributes from ..core.exc import NotFoundError, UnprocessableEntity +from ..core.query_params import query_params_for_url, validate_query_params SAMPLE_404 = { "summary": "Not found item", @@ -358,7 +359,9 @@ class QueryResponse(BaseModel): def validate_params(request: Request): try: - return cls.query_validator(**request.query_params) + return validate_query_params( + request.query_params, cls.query_validator + ) except ValidationError as e: raise UnprocessableEntity(e.json()) @@ -430,11 +433,11 @@ async def _all(query: QueryParams, filters: Q, resource_path: str): next_page_uri: Optional[str] = None if wants_more and has_more: query.created_before = item_dicts[-1]['created_at'] - params = query.model_dump() + params = query_params_for_url(query) if self.user_id_filter_required(): - params.pop('user_id') + params.pop('user_id', None) if self.platform_id_filter_required(): - params.pop('platform_id') + params.pop('platform_id', None) next_page_uri = f'{resource_path}?{urlencode(params)}' return dict(items=item_dicts, next_page_uri=next_page_uri) diff --git a/agave/version.py b/agave/version.py index c24ed73b..a54e32fc 100644 --- a/agave/version.py +++ b/agave/version.py @@ -1 +1 @@ -__version__ = '1.5.4' +__version__ = '1.5.4.dev3' diff --git a/requirements.txt b/requirements.txt index a32f0386..aa2ed858 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ boto3==1.35.74 types-boto3[sqs]==1.35.74 -cuenca-validations==2.1.3 +cuenca-validations==2.1.35.dev2 chalice==1.31.3 mongoengine==0.29.1 fastapi==0.115.11 diff --git a/setup.py b/setup.py index bd8803d0..2bc77786 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ packages=find_packages(), include_package_data=True, package_data=dict(agave=['py.typed']), - python_requires='>=3.9', + python_requires='>=3.10', install_requires=[ 'cuenca-validations>=2.1.0,<3.0.0', 'mongoengine>=0.29.0,<0.30.0', @@ -54,7 +54,6 @@ ], }, classifiers=[ - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', diff --git a/tests/blueprint/test_blueprint.py b/tests/blueprint/test_blueprint.py index 71a06b5d..6ecbf6ab 100644 --- a/tests/blueprint/test_blueprint.py +++ b/tests/blueprint/test_blueprint.py @@ -282,6 +282,43 @@ def test_query_count_resource( assert json_body['count'] == 1 +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +@pytest.mark.usefixtures('accounts') +def test_query_with_comma_separated_ids_param( + client_fixture: str, + request: pytest.FixtureRequest, +) -> None: + client = request.getfixturevalue(client_fixture) + resp = client.get('/accounts?ids=US1,US2') + assert resp.status_code == 200 + assert 'items' in resp.json() + + +@pytest.mark.parametrize( + "client_fixture", ["fastapi_client", "chalice_client"] +) +@pytest.mark.usefixtures('accounts') +def test_query_pagination_preserves_comma_separated_ids( + client_fixture: str, + request: pytest.FixtureRequest, + accounts: list[Account], +) -> None: + client = request.getfixturevalue(client_fixture) + account_ids = [accounts[0].id, accounts[1].id] + ids_param = ','.join(account_ids) + resp = client.get(f'/accounts?ids={ids_param}&page_size=1&limit=10') + assert resp.status_code == 200 + json_body = resp.json() + next_page_uri = json_body['next_page_uri'] + assert next_page_uri is not None + assert f'ids={ids_param}' in next_page_uri.replace('%2C', ',') + + resp = client.get(next_page_uri) + assert resp.status_code == 200 + + @pytest.mark.parametrize( "client_fixture", ["fastapi_client", "chalice_client"] ) diff --git a/tests/core/test_query_params.py b/tests/core/test_query_params.py new file mode 100644 index 00000000..982b8fea --- /dev/null +++ b/tests/core/test_query_params.py @@ -0,0 +1,67 @@ +from typing import Optional + +import pytest +from pydantic import BaseModel, ConfigDict, ValidationError +from starlette.datastructures import QueryParams + +from agave.core.query_params import ( + EmptyQueryMapping, + comma_separated_list, + query_params_for_url, + validate_query_params, +) + + +class SampleQuery(BaseModel): + model_config = ConfigDict(extra='forbid') + + ids: Optional[list[str]] = None + name: Optional[str] = None + active: Optional[bool] = None + + +def test_comma_separated_list() -> None: + assert comma_separated_list('a,b') == ['a', 'b'] + assert comma_separated_list('a, b ,c') == ['a', 'b', 'c'] + assert comma_separated_list('US1') == ['US1'] + assert comma_separated_list('') == [] + assert comma_separated_list(None) == [] + + +def test_validate_query_params_comma_separated_ids() -> None: + query = QueryParams('ids=a,b&name=Frida') + validated = validate_query_params(query, SampleQuery) + assert validated.ids == ['a', 'b'] + assert validated.name == 'Frida' + + +def test_validate_query_params_scalar_field() -> None: + query = QueryParams('name=Frida') + validated = validate_query_params(query, SampleQuery) + assert validated.name == 'Frida' + assert validated.ids is None + + +def test_validate_query_params_rejects_unknown_fields() -> None: + query = QueryParams('wrong_param=value') + with pytest.raises(ValidationError): + validate_query_params(query, SampleQuery) + + +def test_empty_query_mapping() -> None: + mapping = EmptyQueryMapping() + assert 'x' not in mapping + assert mapping.get('x') is None + assert mapping.get('x', 'default') == 'default' + assert list(mapping) == [] + + +def test_validate_query_params_empty_mapping() -> None: + validated = validate_query_params(EmptyQueryMapping(), SampleQuery) + assert validated.ids is None + assert validated.name is None + + +def test_query_params_for_url_serializes_list_as_comma_separated() -> None: + query = SampleQuery(ids=['a', 'b'], name='Frida') + assert query_params_for_url(query)['ids'] == 'a,b'