Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/src/cms_backend/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cms_backend.api.routes.http_errors import BadRequestError
from cms_backend.api.routes.staging import router as staging_router
from cms_backend.api.routes.titles import router as titles_router
from cms_backend.api.routes.warehouse import router as warehouse_router
from cms_backend.api.routes.zimfarm_notifications import (
router as zimfarm_notification_router,
)
Expand Down Expand Up @@ -74,6 +75,7 @@ def create_app(*, debug: bool = True):
main_router.include_router(router=auth_router)
main_router.include_router(router=account_router)
main_router.include_router(router=staging_router)
main_router.include_router(router=warehouse_router)

app.include_router(router=main_router)

Expand Down
94 changes: 86 additions & 8 deletions backend/src/cms_backend/api/routes/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,51 @@
import xxhash
from fastapi import APIRouter, Depends, Path, Query
from fastapi.responses import Response
from pydantic import AnyUrl, Field
from sqlalchemy.orm import Session as OrmSession

from cms_backend.api.routes.fields import LimitFieldMax200, SkipField
from cms_backend.api.routes.dependencies import require_permission
from cms_backend.api.routes.fields import LimitFieldMax200, NotEmptyString, SkipField
from cms_backend.api.routes.models import ListResponse, calculate_pagination_metadata
from cms_backend.api.routes.utils import build_library_xml
from cms_backend.db import gen_dbsession
from cms_backend.db.collection import create_collection as db_create_collection
from cms_backend.db.collection import (
get_collection,
create_collection_full_schema,
get_collection_by_name_or_none,
get_latest_books_for_collection,
)
from cms_backend.db.collection import get_collection as db_get_collection
from cms_backend.db.collection import (
get_collection_by_name as db_get_collection_by_name,
)
from cms_backend.db.collection import get_collections as db_get_collections
from cms_backend.db.collection import update_collection as db_update_collection
from cms_backend.db.exceptions import RecordDoesNotExistError
from cms_backend.schemas import BaseModel
from cms_backend.schemas.orms import CollectionLightSchema
from cms_backend.schemas.models import CollectionUpdateSchema
from cms_backend.schemas.orms import CollectionFullSchema, CollectionLightSchema
from cms_backend.utils import is_valid_uuid

router = APIRouter(prefix="/collections", tags=["collections"])


class CollectionsGetSchema(BaseModel):
skip: SkipField = 0
limit: LimitFieldMax200 = 20
name: NotEmptyString | None = None


@router.get("")
async def get_collections(
def get_collections(
params: Annotated[CollectionsGetSchema, Query()],
session: Annotated[OrmSession, Depends(gen_dbsession)],
) -> ListResponse[CollectionLightSchema]:
"""Get a list of collections"""

results = db_get_collections(session, skip=params.skip, limit=params.limit)
results = db_get_collections(
session, skip=params.skip, limit=params.limit, name=params.name
)

return ListResponse[CollectionLightSchema](
meta=calculate_pagination_metadata(
Expand All @@ -49,6 +62,71 @@ async def get_collections(
)


class CollectionCreateSchema(BaseModel):
name: NotEmptyString = Field(min_length=3)
warehouse_name: NotEmptyString = Field(min_length=3)
download_base_url: AnyUrl | None = None
view_base_url: AnyUrl | None = None


@router.post(
"",
dependencies=[Depends(require_permission(namespace="collection", name="create"))],
)
def create_collection(
session: Annotated[OrmSession, Depends(gen_dbsession)],
request: CollectionCreateSchema,
):
"""Create a collection"""
return create_collection_full_schema(
db_create_collection(
session,
name=request.name,
warehouse_name=request.warehouse_name,
download_base_url=(
str(request.download_base_url)
if request.download_base_url is not None
else None
),
view_base_url=(
str(request.view_base_url)
if request.view_base_url is not None
else None
),
)
)


@router.get("/{collection_id_or_name}")
def get_collection(
collection_id_or_name: Annotated[str, Path()],
session: Annotated[OrmSession, Depends(gen_dbsession)],
):
"""Get collection by collection ID (UUID) or name."""
if is_valid_uuid(collection_id_or_name):
collection = db_get_collection(session, UUID(collection_id_or_name))
else:
collection = db_get_collection_by_name(session, collection_id_or_name)
return create_collection_full_schema(collection)


@router.patch(
"/{collection_id_or_name}",
dependencies=[Depends(require_permission(namespace="collection", name="update"))],
)
def update_collection(
collection_id_or_name: Annotated[str, Path()],
collection_data: CollectionUpdateSchema,
session: OrmSession = Depends(gen_dbsession),
) -> CollectionFullSchema:
"""Update a collectio's data"""
return create_collection_full_schema(
db_update_collection(
session, collection_id=collection_id_or_name, request=collection_data
)
)


def _get_catalog_xml_content(
collection_id_or_name: str, session: OrmSession
) -> tuple[str, int]:
Expand All @@ -57,7 +135,7 @@ def _get_catalog_xml_content(
try:
collection_id = UUID(collection_id_or_name)
try:
collection = get_collection(session, collection_id)
collection = db_get_collection(session, collection_id)
except RecordDoesNotExistError:
pass
except ValueError:
Expand All @@ -78,7 +156,7 @@ def _get_catalog_xml_content(


@router.get("/{collection_id_or_name}/catalog.xml")
async def get_library_catalog_xml(
def get_library_catalog_xml(
collection_id_or_name: Annotated[str, Path()],
session: Annotated[OrmSession, Depends(gen_dbsession)],
):
Expand All @@ -95,7 +173,7 @@ async def get_library_catalog_xml(


@router.head("/{collection_id_or_name}/catalog.xml")
async def head_library_catalog_xml(
def head_library_catalog_xml(
collection_id_or_name: Annotated[str, Path()],
session: Annotated[OrmSession, Depends(gen_dbsession)],
):
Expand Down
15 changes: 1 addition & 14 deletions backend/src/cms_backend/api/routes/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Annotated, Any

import regex
from pydantic import (
AfterValidator,
Field,
Expand All @@ -10,16 +9,6 @@
)


class GraphemeStr(str):
def __len__(self) -> int:
# Count the number of grapheme clusters
return len(regex.findall(r"\X", self))


def to_grapheme_str(value: str):
return GraphemeStr(value)


def no_null_char(value: str) -> str:
"""Validate that string value does not contains Unicode null character"""
if "\u0000" in value:
Expand All @@ -46,9 +35,7 @@ def not_empty(value: str) -> str:
return value.strip()


NoNullCharString = Annotated[
str, AfterValidator(no_null_char), AfterValidator(to_grapheme_str)
]
NoNullCharString = Annotated[str, AfterValidator(no_null_char)]

NotEmptyString = Annotated[NoNullCharString, AfterValidator(not_empty)]

Expand Down
2 changes: 2 additions & 0 deletions backend/src/cms_backend/api/routes/titles.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TitlesGetSchema(BaseModel):
skip: SkipField = 0
limit: LimitFieldMax200 = 20
name: NotEmptyString | None = None
collection_name: NotEmptyString | None = None


class BaseTitleCreateUpdateSchema(BaseModel):
Expand Down Expand Up @@ -70,6 +71,7 @@ def get_titles(
skip=params.skip,
limit=params.limit,
name=params.name,
collection_name=params.collection_name,
)
return ListResponse[TitleLightSchema](
meta=calculate_pagination_metadata(
Expand Down
30 changes: 30 additions & 0 deletions backend/src/cms_backend/api/routes/warehouse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session as OrmSession

from cms_backend.api.routes.dependencies import gen_dbsession
from cms_backend.api.routes.fields import LimitFieldMax200, SkipField
from cms_backend.api.routes.models import ListResponse, calculate_pagination_metadata
from cms_backend.db.warehouse import get_warehouses as db_get_warehouses

router = APIRouter(prefix="/warehouses", tags=["warehouses"])


@router.get("")
def get_warehouses(
session: OrmSession = Depends(gen_dbsession),
skip: Annotated[SkipField, Query()] = 0,
limit: Annotated[LimitFieldMax200, Query()] = 200,
):
"""Get a list of warehouses"""
result = db_get_warehouses(session, skip=skip, limit=limit)
return ListResponse(
items=result.records,
meta=calculate_pagination_metadata(
nb_records=result.nb_records,
skip=skip,
limit=limit,
page_size=len(result.records),
),
)
73 changes: 70 additions & 3 deletions backend/src/cms_backend/db/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import NamedTuple, cast
from uuid import UUID

from sqlalchemy import and_, func, select
from sqlalchemy import and_, func, select, update
from sqlalchemy.orm import Session as OrmSession

from cms_backend.db import count_from_stmt
Expand All @@ -14,7 +14,14 @@
CollectionTitle,
Title,
)
from cms_backend.schemas.orms import CollectionLightSchema, ListResult
from cms_backend.db.warehouse import get_warehouse
from cms_backend.schemas.models import CollectionUpdateSchema
from cms_backend.schemas.orms import (
CollectionFullSchema,
CollectionLightSchema,
ListResult,
)
from cms_backend.utils import is_valid_uuid


def get_collection_or_none(session: OrmSession, library_id: UUID) -> Collection | None:
Expand Down Expand Up @@ -125,7 +132,7 @@ def get_latest_books_for_collection(


def get_collections(
session: OrmSession, *, skip: int, limit: int
session: OrmSession, *, name: str | None = None, skip: int, limit: int
) -> ListResult[CollectionLightSchema]:
"""Get the list of collections."""
stmt = (
Expand All @@ -139,8 +146,19 @@ def get_collections(
[],
).label("paths"),
)
.where(
# If a client provides an argument i.e it is not None,
# we compare the corresponding model field against the argument,
# otherwise, we compare the argument to its default which translates
# to a SQL true i.e we don't filter based on this argument (a no-op).
(
Collection.name.ilike(f"%{name if name is not None else ''}%")
| (name is None)
),
)
.outerjoin(CollectionTitle)
.group_by(Collection.id)
.order_by(Collection.name.desc())
)

return ListResult[CollectionLightSchema](
Expand All @@ -156,3 +174,52 @@ def get_collections(
).all()
],
)


def create_collection_full_schema(collection: Collection) -> CollectionFullSchema:
return CollectionFullSchema(
id=collection.id,
warehouse=collection.warehouse.name,
name=collection.name,
download_base_url=collection.download_base_url,
view_base_url=collection.view_base_url,
)


def create_collection(
session: OrmSession,
*,
name: str,
warehouse_name: str,
download_base_url: str | None = None,
view_base_url: str | None,
) -> Collection:
warehouse = get_warehouse(session, warehouse_name)
collection = Collection(
name=name,
download_base_url=download_base_url,
view_base_url=view_base_url,
warehouse_id=warehouse.id,
)
session.add(collection)
session.flush()
return collection


def update_collection(
session: OrmSession, *, collection_id: str, request: CollectionUpdateSchema
) -> Collection:
"""Update a collection"""
if is_valid_uuid(collection_id):
collection = get_collection(session, UUID(collection_id))
else:
collection = get_collection_by_name(session, collection_id)
values = request.model_dump(exclude_unset=True)
if not values:
return collection

session.execute(
update(Collection).values(**values).where(Collection.id == collection.id)
)
session.refresh(collection)
return collection
2 changes: 1 addition & 1 deletion backend/src/cms_backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class Warehouse(Base):
id: Mapped[UUID] = mapped_column(
init=False, primary_key=True, server_default=text("uuid_generate_v4()")
)
name: Mapped[str]
name: Mapped[str] = mapped_column(unique=True, index=True)

collections: Mapped[list["Collection"]] = relationship(
back_populates="warehouse",
Expand Down
Loading
Loading